94 lines
2.8 KiB
Go
94 lines
2.8 KiB
Go
package main
|
||
|
||
import (
|
||
"ai-gateway/api"
|
||
"ai-gateway/internal/db"
|
||
"ai-gateway/internal/middleware"
|
||
"log"
|
||
|
||
"github.com/gin-contrib/cors"
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
func main() {
|
||
// 初始化数据库连接
|
||
database, err := db.InitDB()
|
||
if err != nil {
|
||
log.Fatalf("Failed to initialize database: %v", err)
|
||
}
|
||
|
||
// 创建Gin路由器实例
|
||
router := gin.Default()
|
||
|
||
// 配置CORS中间件 - 采用非常宽松的策略,允许所有请求头
|
||
config := cors.DefaultConfig()
|
||
config.AllowAllOrigins = true
|
||
config.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
|
||
// 允许所有请求头(使用通配符)
|
||
config.AllowHeaders = []string{"*"}
|
||
config.AllowWildcard = true
|
||
// 注意:当AllowAllOrigins为true时,AllowCredentials必须为false
|
||
config.AllowCredentials = false
|
||
// 暴露所有响应头
|
||
config.ExposeHeaders = []string{"*"}
|
||
|
||
router.Use(cors.New(config))
|
||
|
||
// 创建API处理器
|
||
handler := &api.APIHandler{
|
||
DB: database,
|
||
}
|
||
|
||
// 添加健康检查端点(无需认证)
|
||
router.GET("/health", handler.HealthCheckHandler)
|
||
router.GET("/healthz", handler.HealthCheckHandler)
|
||
|
||
// 创建根组
|
||
root_ := router.Group("/")
|
||
root_.Use(middleware.AuthMiddleware(database))
|
||
{
|
||
root_.GET("/models", handler.ListModels)
|
||
}
|
||
|
||
// 创建受保护的路由组
|
||
protected := router.Group("/v1")
|
||
protected.Use(middleware.AuthMiddleware(database))
|
||
{
|
||
protected.GET("/models", handler.ListModels)
|
||
protected.POST("/chat/completions", handler.ChatCompletions)
|
||
protected.POST("/responses", handler.ResponsesCompletions)
|
||
protected.POST("/embeddings", handler.Embeddings)
|
||
}
|
||
|
||
// 创建API管理路由组
|
||
api_ := router.Group("/api")
|
||
api_.Use(middleware.AuthMiddleware(database))
|
||
{
|
||
// Providers
|
||
api_.GET("/providers", handler.GetProvidersHandler)
|
||
api_.GET("/providers/:id", handler.GetProviderHandler)
|
||
api_.POST("/providers", handler.CreateProviderHandler)
|
||
api_.PUT("/providers/:id", handler.UpdateProviderHandler)
|
||
api_.DELETE("/providers/:id", handler.DeleteProviderHandler)
|
||
|
||
// Virtual Models
|
||
api_.POST("/virtual-models", handler.CreateVirtualModelHandler)
|
||
api_.GET("/virtual-models", handler.GetVirtualModelsHandler)
|
||
api_.GET("/virtual-models/:id", handler.GetVirtualModelHandler)
|
||
api_.PUT("/virtual-models/:id", handler.UpdateVirtualModelHandler)
|
||
api_.DELETE("/virtual-models/:id", handler.DeleteVirtualModelHandler)
|
||
|
||
// Request Logs
|
||
api_.GET("/logs", handler.GetRequestLogsHandler)
|
||
api_.GET("/logs/stats", handler.GetRequestLogStatsHandler)
|
||
api_.GET("/logs/:id", handler.GetRequestLogDetailHandler)
|
||
api_.DELETE("/logs", handler.ClearRequestLogsHandler)
|
||
}
|
||
|
||
// 启动HTTP服务器
|
||
log.Println("Starting AI Gateway server on :8080")
|
||
if err := router.Run(":8080"); err != nil {
|
||
log.Fatalf("Failed to start server: %v", err)
|
||
}
|
||
}
|