Files
AIRouter/backend/main.go
2025-11-17 11:42:17 +08:00

153 lines
4.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
	"context"
	"errors"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"

	"ai-gateway/api"
	"ai-gateway/internal/config"
	"ai-gateway/internal/db"
	"ai-gateway/internal/logger"
	"ai-gateway/internal/middleware"
	"ai-gateway/internal/scheduler"
)
// main wires up the AI Gateway process: it loads and validates configuration,
// opens the database, starts the background log cleaner, registers the HTTP
// routes, and serves until SIGINT/SIGTERM arrives or the listener fails.
//
// Shutdown is performed exactly once on a single code path: the HTTP server is
// drained first (bounded by shutdownTimeout), then the log cleaner is stopped.
// A listener failure still exits with status 1, but only after the log cleaner
// has been stopped.
func main() {
	// shutdownTimeout bounds how long in-flight requests may take to finish
	// once an exit signal has been received.
	const shutdownTimeout = 10 * time.Second

	log.Println("🚀 启动 AI Gateway 服务...")

	// Load configuration from the fixed path (convenient for Docker mounts).
	cfg := config.LoadConfigOrDefault(config.ConfigPath)
	cfg.Print()

	// Fall back to the built-in defaults when the loaded configuration is invalid.
	if err := cfg.Validate(); err != nil {
		log.Printf("⚠️ 配置验证失败: %v,使用默认配置", err)
		cfg = config.DefaultConfig()
	}

	logger.SetSaveRequestLog(cfg.App.LogInDB)

	// Open the database at the fixed database path; nothing can run without it.
	database, err := db.InitDB(config.DatabasePath)
	if err != nil {
		log.Fatalf("❌ 数据库初始化失败: %v", err)
	}
	log.Printf("✅ 数据库连接成功: %s", config.DatabasePath)

	// Start the background log cleaner. It is stopped once, at the end of main.
	logCleaner := scheduler.NewLogCleaner(database, cfg.LogCleaner)
	logCleaner.Start()
	log.Println("✅ 日志自动清理器已启动")

	router := gin.Default()

	// Deliberately permissive CORS policy. The variable is not named "config"
	// so that it does not shadow the imported config package.
	corsCfg := cors.DefaultConfig()
	corsCfg.AllowAllOrigins = true
	corsCfg.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
	// "*" allows arbitrary request headers, but per the Fetch specification the
	// wildcard never covers Authorization, so it is listed explicitly.
	corsCfg.AllowHeaders = []string{"*", "Authorization"}
	// NOTE(review): per the gin-contrib/cors docs AllowWildcard concerns wildcard
	// origin patterns, not headers; kept as in the original configuration.
	corsCfg.AllowWildcard = true
	// AllowCredentials must stay false while AllowAllOrigins is true.
	corsCfg.AllowCredentials = false
	// Expose every response header to browser clients.
	corsCfg.ExposeHeaders = []string{"*"}
	router.Use(cors.New(corsCfg))

	handler := &api.APIHandler{
		DB:            database,
		LogCleaner:    logCleaner,
		WebUIPassword: cfg.App.WebUIPassword,
	}

	// Health checks are registered before any auth middleware is attached.
	router.GET("/health", handler.HealthCheckHandler)
	router.GET("/healthz", handler.HealthCheckHandler)

	// Root-level model listing, behind the auth middleware.
	rootGroup := router.Group("/")
	rootGroup.Use(middleware.AuthMiddleware(database, cfg.App.WebUIPassword))
	{
		rootGroup.GET("/models", handler.ListModels)
	}

	// Authenticated endpoints under /v1.
	protected := router.Group("/v1")
	protected.Use(middleware.AuthMiddleware(database, cfg.App.WebUIPassword))
	{
		protected.GET("/models", handler.ListModels)
		protected.POST("/chat/completions", handler.ChatCompletions)
		protected.POST("/responses", handler.ResponsesCompletions)
		protected.POST("/embeddings", handler.Embeddings)
	}

	// Management API. /login is registered before Use so that the auth
	// middleware does not apply to it.
	apiGroup := router.Group("/api")
	apiGroup.POST("/login", handler.LoginHandler)
	apiGroup.Use(middleware.AuthMiddleware(database, cfg.App.WebUIPassword))
	{
		// Providers
		apiGroup.GET("/providers", handler.GetProvidersHandler)
		apiGroup.GET("/providers/:id", handler.GetProviderHandler)
		apiGroup.POST("/providers", handler.CreateProviderHandler)
		apiGroup.PUT("/providers/:id", handler.UpdateProviderHandler)
		apiGroup.DELETE("/providers/:id", handler.DeleteProviderHandler)
		apiGroup.GET("/providers/:id/models", handler.GetProviderModelsHandler)

		// Virtual models
		apiGroup.POST("/virtual-models", handler.CreateVirtualModelHandler)
		apiGroup.GET("/virtual-models", handler.GetVirtualModelsHandler)
		apiGroup.GET("/virtual-models/:id", handler.GetVirtualModelHandler)
		apiGroup.PUT("/virtual-models/:id", handler.UpdateVirtualModelHandler)
		apiGroup.DELETE("/virtual-models/:id", handler.DeleteVirtualModelHandler)

		// Request logs
		apiGroup.GET("/logs", handler.GetRequestLogsHandler)
		apiGroup.GET("/logs/stats", handler.GetRequestLogStatsHandler)
		apiGroup.GET("/logs/:id", handler.GetRequestLogDetailHandler)
		apiGroup.DELETE("/logs", handler.ClearRequestLogsHandler)

		// Log cleaner management
		apiGroup.GET("/log-cleaner/status", handler.GetLogCleanerStatusHandler)
		apiGroup.POST("/log-cleaner/force-cleanup", handler.ForceLogCleanupHandler)

		// API key management
		apiGroup.GET("/api-keys", handler.GetAPIKeysHandler)
		apiGroup.POST("/api-keys", handler.CreateAPIKeyHandler)
		apiGroup.DELETE("/api-keys/:id", handler.DeleteAPIKeyHandler)
	}

	// Subscribe to termination signals before the server starts accepting traffic.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

	// An explicit http.Server (instead of router.Run) is what makes a graceful
	// Shutdown possible later on.
	serverAddr := cfg.Server.Host + ":" + cfg.Server.Port
	srv := &http.Server{
		Addr:    serverAddr,
		Handler: router,
	}

	// Buffered so the serving goroutine never blocks on reporting its error.
	serverErr := make(chan error, 1)
	go func() {
		log.Printf("🌐 HTTP服务器启动在 %s", serverAddr)
		// ErrServerClosed is the expected result of Shutdown, not a failure.
		if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			serverErr <- err
		}
	}()

	// Block until either an exit signal arrives or the listener fails.
	exitCode := 0
	select {
	case <-quit:
		log.Println("🛑 收到退出信号,正在优雅关闭...")
		ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		if err := srv.Shutdown(ctx); err != nil {
			log.Printf("⚠️ HTTP服务器关闭失败: %v", err)
		}
		cancel()
	case err := <-serverErr:
		// Report the failure but still fall through to the cleanup below;
		// log.Fatalf here would exit before the log cleaner is stopped.
		log.Printf("❌ HTTP服务器启动失败: %v", err)
		exitCode = 1
	}

	// Stop the log cleaner exactly once, whichever path led here.
	logCleaner.Stop()
	log.Println("✅ 日志自动清理器已停止")
	log.Println("✅ AI Gateway 服务已关闭")

	if exitCode != 0 {
		os.Exit(exitCode)
	}
}