133 lines
3.4 KiB
Go
133 lines
3.4 KiB
Go
package db
|
|
|
|
import (
|
|
"ai-gateway/internal/models"
|
|
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// InitDB 初始化数据库连接并执行自动迁移
|
|
func InitDB() (*gorm.DB, error) {
|
|
// 使用SQLite驱动连接到gateway.db数据库文件
|
|
db, err := gorm.Open(sqlite.Open("gateway.db"), &gorm.Config{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 自动迁移所有模型,创建相应的表
|
|
err = db.AutoMigrate(
|
|
&models.Provider{},
|
|
&models.APIKey{},
|
|
&models.VirtualModel{},
|
|
&models.BackendModel{},
|
|
&models.RequestLog{},
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// GetProviders 从数据库中获取所有服务商列表
|
|
func GetProviders(db *gorm.DB) ([]models.Provider, error) {
|
|
var providers []models.Provider
|
|
result := db.Find(&providers)
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
|
|
return providers, nil
|
|
}
|
|
|
|
// GetProviderByID 通过ID获取服务商
|
|
func GetProviderByID(db *gorm.DB, id uint) (*models.Provider, error) {
|
|
var provider models.Provider
|
|
if err := db.First(&provider, id).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &provider, nil
|
|
}
|
|
|
|
// CreateProvider 创建一个新的服务商
|
|
func CreateProvider(db *gorm.DB, provider *models.Provider) error {
|
|
return db.Create(provider).Error
|
|
}
|
|
|
|
// UpdateProvider 更新一个服务商
|
|
func UpdateProvider(db *gorm.DB, provider *models.Provider) error {
|
|
return db.Save(provider).Error
|
|
}
|
|
|
|
// DeleteProvider 删除一个服务商
|
|
func DeleteProvider(db *gorm.DB, id uint) error {
|
|
return db.Delete(&models.Provider{}, id).Error
|
|
}
|
|
|
|
// CreateVirtualModel 创建一个新的虚拟模型
|
|
func CreateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error {
|
|
return db.Create(virtualModel).Error
|
|
}
|
|
|
|
// GetVirtualModels 获取所有虚拟模型
|
|
func GetVirtualModels(db *gorm.DB) ([]models.VirtualModel, error) {
|
|
var virtualModels []models.VirtualModel
|
|
if err := db.Preload("BackendModels").Find(&virtualModels).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return virtualModels, nil
|
|
}
|
|
|
|
// GetVirtualModelByID 通过ID获取虚拟模型
|
|
func GetVirtualModelByID(db *gorm.DB, id uint) (*models.VirtualModel, error) {
|
|
var virtualModel models.VirtualModel
|
|
if err := db.Preload("BackendModels").First(&virtualModel, id).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &virtualModel, nil
|
|
}
|
|
|
|
// UpdateVirtualModel 更新一个虚拟模型
|
|
func UpdateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error {
|
|
// 开始事务
|
|
tx := db.Begin()
|
|
if tx.Error != nil {
|
|
return tx.Error
|
|
}
|
|
|
|
// 删除与此虚拟模型关联的所有旧的后端模型
|
|
result := tx.Where("virtual_model_id = ?", virtualModel.ID).Delete(&models.BackendModel{})
|
|
if result.Error != nil {
|
|
tx.Rollback()
|
|
return result.Error
|
|
}
|
|
|
|
// 添加新的后端模型
|
|
if len(virtualModel.BackendModels) > 0 {
|
|
for i := range virtualModel.BackendModels {
|
|
// 清除 ID 字段,让数据库自动分配新 ID
|
|
virtualModel.BackendModels[i].ID = 0
|
|
virtualModel.BackendModels[i].VirtualModelID = virtualModel.ID
|
|
}
|
|
|
|
if err := tx.Create(&virtualModel.BackendModels).Error; err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
}
|
|
|
|
// 更新虚拟模型本身
|
|
if err := tx.Omit("BackendModels").Save(virtualModel).Error; err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
return tx.Commit().Error
|
|
}
|
|
|
|
// DeleteVirtualModel 删除一个虚拟模型
|
|
func DeleteVirtualModel(db *gorm.DB, id uint) error {
|
|
return db.Delete(&models.VirtualModel{}, id).Error
|
|
}
|