162 lines
4.1 KiB
Go
162 lines
4.1 KiB
Go
package db
|
||
|
||
import (
|
||
"ai-gateway/internal/models"
|
||
"log"
|
||
|
||
"gorm.io/driver/sqlite"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// InitDB 初始化数据库连接并执行自动迁移
|
||
// 参数 dbPath: 数据库文件路径
|
||
func InitDB(dbPath string) (*gorm.DB, error) {
|
||
// 使用SQLite驱动连接到数据库文件
|
||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 自动迁移所有模型,创建相应的表
|
||
err = db.AutoMigrate(
|
||
&models.Provider{},
|
||
&models.APIKey{},
|
||
&models.VirtualModel{},
|
||
&models.BackendModel{},
|
||
&models.RequestLog{},
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 初始化默认数据
|
||
if err := initializeDefaultData(db); err != nil {
|
||
log.Printf("Warning: Failed to initialize default data: %v", err)
|
||
}
|
||
|
||
return db, nil
|
||
}
|
||
|
||
// initializeDefaultData 初始化默认数据(如默认API密钥)
|
||
func initializeDefaultData(db *gorm.DB) error {
|
||
// 检查是否已存在API密钥
|
||
var count int64
|
||
if err := db.Model(&models.APIKey{}).Count(&count).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 如果没有API密钥,创建默认密钥
|
||
if count == 0 {
|
||
defaultKey := &models.APIKey{
|
||
Key: "sk-dev-key-789012", // 与前端配置保持一致
|
||
}
|
||
if err := db.Create(defaultKey).Error; err != nil {
|
||
return err
|
||
}
|
||
log.Println("✓ Created default API key: sk-dev-key-789012")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetProviders 从数据库中获取所有服务商列表
|
||
func GetProviders(db *gorm.DB) ([]models.Provider, error) {
|
||
var providers []models.Provider
|
||
result := db.Find(&providers)
|
||
if result.Error != nil {
|
||
return nil, result.Error
|
||
}
|
||
|
||
return providers, nil
|
||
}
|
||
|
||
// GetProviderByID 通过ID获取服务商
|
||
func GetProviderByID(db *gorm.DB, id uint) (*models.Provider, error) {
|
||
var provider models.Provider
|
||
if err := db.First(&provider, id).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return &provider, nil
|
||
}
|
||
|
||
// CreateProvider 创建一个新的服务商
|
||
func CreateProvider(db *gorm.DB, provider *models.Provider) error {
|
||
return db.Create(provider).Error
|
||
}
|
||
|
||
// UpdateProvider 更新一个服务商
|
||
func UpdateProvider(db *gorm.DB, provider *models.Provider) error {
|
||
return db.Save(provider).Error
|
||
}
|
||
|
||
// DeleteProvider 删除一个服务商
|
||
func DeleteProvider(db *gorm.DB, id uint) error {
|
||
return db.Delete(&models.Provider{}, id).Error
|
||
}
|
||
|
||
// CreateVirtualModel 创建一个新的虚拟模型
|
||
func CreateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error {
|
||
return db.Create(virtualModel).Error
|
||
}
|
||
|
||
// GetVirtualModels 获取所有虚拟模型
|
||
func GetVirtualModels(db *gorm.DB) ([]models.VirtualModel, error) {
|
||
var virtualModels []models.VirtualModel
|
||
if err := db.Preload("BackendModels").Find(&virtualModels).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return virtualModels, nil
|
||
}
|
||
|
||
// GetVirtualModelByID 通过ID获取虚拟模型
|
||
func GetVirtualModelByID(db *gorm.DB, id uint) (*models.VirtualModel, error) {
|
||
var virtualModel models.VirtualModel
|
||
if err := db.Preload("BackendModels").First(&virtualModel, id).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return &virtualModel, nil
|
||
}
|
||
|
||
// UpdateVirtualModel 更新一个虚拟模型
|
||
func UpdateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error {
|
||
// 开始事务
|
||
tx := db.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
|
||
// 删除与此虚拟模型关联的所有旧的后端模型
|
||
result := tx.Where("virtual_model_id = ?", virtualModel.ID).Delete(&models.BackendModel{})
|
||
if result.Error != nil {
|
||
tx.Rollback()
|
||
return result.Error
|
||
}
|
||
|
||
// 添加新的后端模型
|
||
if len(virtualModel.BackendModels) > 0 {
|
||
for i := range virtualModel.BackendModels {
|
||
// 清除 ID 字段,让数据库自动分配新 ID
|
||
virtualModel.BackendModels[i].ID = 0
|
||
virtualModel.BackendModels[i].VirtualModelID = virtualModel.ID
|
||
}
|
||
|
||
if err := tx.Create(&virtualModel.BackendModels).Error; err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 更新虚拟模型本身
|
||
if err := tx.Omit("BackendModels").Save(virtualModel).Error; err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
return tx.Commit().Error
|
||
}
|
||
|
||
// DeleteVirtualModel 删除一个虚拟模型
|
||
func DeleteVirtualModel(db *gorm.DB, id uint) error {
|
||
return db.Delete(&models.VirtualModel{}, id).Error
|
||
}
|