package db import ( "ai-gateway/internal/models" "log" "gorm.io/driver/sqlite" "gorm.io/gorm" ) // InitDB 初始化数据库连接并执行自动迁移 // 参数 dbPath: 数据库文件路径 func InitDB(dbPath string) (*gorm.DB, error) { // 使用SQLite驱动连接到数据库文件 db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) if err != nil { return nil, err } // 自动迁移所有模型,创建相应的表 err = db.AutoMigrate( &models.Provider{}, &models.APIKey{}, &models.VirtualModel{}, &models.BackendModel{}, &models.RequestLog{}, ) if err != nil { return nil, err } // 初始化默认数据 if err := initializeDefaultData(db); err != nil { log.Printf("Warning: Failed to initialize default data: %v", err) } return db, nil } // initializeDefaultData 初始化默认数据(如默认API密钥) func initializeDefaultData(db *gorm.DB) error { // 检查是否已存在API密钥 var count int64 if err := db.Model(&models.APIKey{}).Count(&count).Error; err != nil { return err } // 如果没有API密钥,创建默认密钥 if count == 0 { defaultKey := &models.APIKey{ Key: "sk-dev-key-789012", // 与前端配置保持一致 } if err := db.Create(defaultKey).Error; err != nil { return err } log.Println("✓ Created default API key: sk-dev-key-789012") } return nil } // GetProviders 从数据库中获取所有服务商列表 func GetProviders(db *gorm.DB) ([]models.Provider, error) { var providers []models.Provider result := db.Find(&providers) if result.Error != nil { return nil, result.Error } return providers, nil } // GetProviderByID 通过ID获取服务商 func GetProviderByID(db *gorm.DB, id uint) (*models.Provider, error) { var provider models.Provider if err := db.First(&provider, id).Error; err != nil { return nil, err } return &provider, nil } // CreateProvider 创建一个新的服务商 func CreateProvider(db *gorm.DB, provider *models.Provider) error { return db.Create(provider).Error } // UpdateProvider 更新一个服务商 func UpdateProvider(db *gorm.DB, provider *models.Provider) error { return db.Save(provider).Error } // DeleteProvider 删除一个服务商 func DeleteProvider(db *gorm.DB, id uint) error { return db.Delete(&models.Provider{}, id).Error } // CreateVirtualModel 创建一个新的虚拟模型 func CreateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error { return db.Create(virtualModel).Error } // GetVirtualModels 获取所有虚拟模型 func GetVirtualModels(db *gorm.DB) ([]models.VirtualModel, error) { var virtualModels []models.VirtualModel if err := db.Preload("BackendModels").Find(&virtualModels).Error; err != nil { return nil, err } return virtualModels, nil } // GetVirtualModelByID 通过ID获取虚拟模型 func GetVirtualModelByID(db *gorm.DB, id uint) (*models.VirtualModel, error) { var virtualModel models.VirtualModel if err := db.Preload("BackendModels").First(&virtualModel, id).Error; err != nil { return nil, err } return &virtualModel, nil } // UpdateVirtualModel 更新一个虚拟模型 func UpdateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error { // 开始事务 tx := db.Begin() if tx.Error != nil { return tx.Error } // 删除与此虚拟模型关联的所有旧的后端模型 result := tx.Where("virtual_model_id = ?", virtualModel.ID).Delete(&models.BackendModel{}) if result.Error != nil { tx.Rollback() return result.Error } // 添加新的后端模型 if len(virtualModel.BackendModels) > 0 { for i := range virtualModel.BackendModels { // 清除 ID 字段,让数据库自动分配新 ID virtualModel.BackendModels[i].ID = 0 virtualModel.BackendModels[i].VirtualModelID = virtualModel.ID } if err := tx.Create(&virtualModel.BackendModels).Error; err != nil { tx.Rollback() return err } } // 更新虚拟模型本身 if err := tx.Omit("BackendModels").Save(virtualModel).Error; err != nil { tx.Rollback() return err } return tx.Commit().Error } // DeleteVirtualModel 删除一个虚拟模型 func DeleteVirtualModel(db *gorm.DB, id uint) error { return db.Delete(&models.VirtualModel{}, id).Error }