Files
AIRouter/backend/internal/db/database.go

162 lines
4.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"ai-gateway/internal/models"
"log"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// InitDB 初始化数据库连接并执行自动迁移
// 参数 dbPath: 数据库文件路径
func InitDB(dbPath string) (*gorm.DB, error) {
// 使用SQLite驱动连接到数据库文件
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
return nil, err
}
// 自动迁移所有模型,创建相应的表
err = db.AutoMigrate(
&models.Provider{},
&models.APIKey{},
&models.VirtualModel{},
&models.BackendModel{},
&models.RequestLog{},
)
if err != nil {
return nil, err
}
// 初始化默认数据
if err := initializeDefaultData(db); err != nil {
log.Printf("Warning: Failed to initialize default data: %v", err)
}
return db, nil
}
// initializeDefaultData 初始化默认数据如默认API密钥
func initializeDefaultData(db *gorm.DB) error {
// 检查是否已存在API密钥
var count int64
if err := db.Model(&models.APIKey{}).Count(&count).Error; err != nil {
return err
}
// 如果没有API密钥创建默认密钥
if count == 0 {
defaultKey := &models.APIKey{
Key: "sk-dev-key-789012", // 与前端配置保持一致
}
if err := db.Create(defaultKey).Error; err != nil {
return err
}
log.Println("✓ Created default API key: sk-dev-key-789012")
}
return nil
}
// GetProviders 从数据库中获取所有服务商列表
func GetProviders(db *gorm.DB) ([]models.Provider, error) {
var providers []models.Provider
result := db.Find(&providers)
if result.Error != nil {
return nil, result.Error
}
return providers, nil
}
// GetProviderByID 通过ID获取服务商
func GetProviderByID(db *gorm.DB, id uint) (*models.Provider, error) {
var provider models.Provider
if err := db.First(&provider, id).Error; err != nil {
return nil, err
}
return &provider, nil
}
// CreateProvider 创建一个新的服务商
func CreateProvider(db *gorm.DB, provider *models.Provider) error {
return db.Create(provider).Error
}
// UpdateProvider 更新一个服务商
func UpdateProvider(db *gorm.DB, provider *models.Provider) error {
return db.Save(provider).Error
}
// DeleteProvider 删除一个服务商
func DeleteProvider(db *gorm.DB, id uint) error {
return db.Delete(&models.Provider{}, id).Error
}
// CreateVirtualModel 创建一个新的虚拟模型
func CreateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error {
return db.Create(virtualModel).Error
}
// GetVirtualModels 获取所有虚拟模型
func GetVirtualModels(db *gorm.DB) ([]models.VirtualModel, error) {
var virtualModels []models.VirtualModel
if err := db.Preload("BackendModels").Find(&virtualModels).Error; err != nil {
return nil, err
}
return virtualModels, nil
}
// GetVirtualModelByID 通过ID获取虚拟模型
func GetVirtualModelByID(db *gorm.DB, id uint) (*models.VirtualModel, error) {
var virtualModel models.VirtualModel
if err := db.Preload("BackendModels").First(&virtualModel, id).Error; err != nil {
return nil, err
}
return &virtualModel, nil
}
// UpdateVirtualModel 更新一个虚拟模型
func UpdateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error {
// 开始事务
tx := db.Begin()
if tx.Error != nil {
return tx.Error
}
// 删除与此虚拟模型关联的所有旧的后端模型
result := tx.Where("virtual_model_id = ?", virtualModel.ID).Delete(&models.BackendModel{})
if result.Error != nil {
tx.Rollback()
return result.Error
}
// 添加新的后端模型
if len(virtualModel.BackendModels) > 0 {
for i := range virtualModel.BackendModels {
// 清除 ID 字段,让数据库自动分配新 ID
virtualModel.BackendModels[i].ID = 0
virtualModel.BackendModels[i].VirtualModelID = virtualModel.ID
}
if err := tx.Create(&virtualModel.BackendModels).Error; err != nil {
tx.Rollback()
return err
}
}
// 更新虚拟模型本身
if err := tx.Omit("BackendModels").Save(virtualModel).Error; err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// DeleteVirtualModel 删除一个虚拟模型
func DeleteVirtualModel(db *gorm.DB, id uint) error {
return db.Delete(&models.VirtualModel{}, id).Error
}