Files
AIRouter/backend/internal/db/database.go

133 lines
3.4 KiB
Go

package db
import (
"ai-gateway/internal/models"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// InitDB 初始化数据库连接并执行自动迁移
func InitDB() (*gorm.DB, error) {
// 使用SQLite驱动连接到gateway.db数据库文件
db, err := gorm.Open(sqlite.Open("gateway.db"), &gorm.Config{})
if err != nil {
return nil, err
}
// 自动迁移所有模型,创建相应的表
err = db.AutoMigrate(
&models.Provider{},
&models.APIKey{},
&models.VirtualModel{},
&models.BackendModel{},
&models.RequestLog{},
)
if err != nil {
return nil, err
}
return db, nil
}
// GetProviders 从数据库中获取所有服务商列表
func GetProviders(db *gorm.DB) ([]models.Provider, error) {
var providers []models.Provider
result := db.Find(&providers)
if result.Error != nil {
return nil, result.Error
}
return providers, nil
}
// GetProviderByID 通过ID获取服务商
func GetProviderByID(db *gorm.DB, id uint) (*models.Provider, error) {
var provider models.Provider
if err := db.First(&provider, id).Error; err != nil {
return nil, err
}
return &provider, nil
}
// CreateProvider 创建一个新的服务商
func CreateProvider(db *gorm.DB, provider *models.Provider) error {
return db.Create(provider).Error
}
// UpdateProvider 更新一个服务商
func UpdateProvider(db *gorm.DB, provider *models.Provider) error {
return db.Save(provider).Error
}
// DeleteProvider 删除一个服务商
func DeleteProvider(db *gorm.DB, id uint) error {
return db.Delete(&models.Provider{}, id).Error
}
// CreateVirtualModel 创建一个新的虚拟模型
func CreateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error {
return db.Create(virtualModel).Error
}
// GetVirtualModels 获取所有虚拟模型
func GetVirtualModels(db *gorm.DB) ([]models.VirtualModel, error) {
var virtualModels []models.VirtualModel
if err := db.Preload("BackendModels").Find(&virtualModels).Error; err != nil {
return nil, err
}
return virtualModels, nil
}
// GetVirtualModelByID 通过ID获取虚拟模型
func GetVirtualModelByID(db *gorm.DB, id uint) (*models.VirtualModel, error) {
var virtualModel models.VirtualModel
if err := db.Preload("BackendModels").First(&virtualModel, id).Error; err != nil {
return nil, err
}
return &virtualModel, nil
}
// UpdateVirtualModel 更新一个虚拟模型
func UpdateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error {
// 开始事务
tx := db.Begin()
if tx.Error != nil {
return tx.Error
}
// 删除与此虚拟模型关联的所有旧的后端模型
result := tx.Where("virtual_model_id = ?", virtualModel.ID).Delete(&models.BackendModel{})
if result.Error != nil {
tx.Rollback()
return result.Error
}
// 添加新的后端模型
if len(virtualModel.BackendModels) > 0 {
for i := range virtualModel.BackendModels {
// 清除 ID 字段,让数据库自动分配新 ID
virtualModel.BackendModels[i].ID = 0
virtualModel.BackendModels[i].VirtualModelID = virtualModel.ID
}
if err := tx.Create(&virtualModel.BackendModels).Error; err != nil {
tx.Rollback()
return err
}
}
// 更新虚拟模型本身
if err := tx.Omit("BackendModels").Save(virtualModel).Error; err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// DeleteVirtualModel 删除一个虚拟模型
func DeleteVirtualModel(db *gorm.DB, id uint) error {
return db.Delete(&models.VirtualModel{}, id).Error
}