diff --git a/backend/api/handlers.go b/backend/api/handlers.go index e29a844..9f59c10 100644 --- a/backend/api/handlers.go +++ b/backend/api/handlers.go @@ -50,6 +50,14 @@ type ChatCompletionMessage struct { Content string `json:"content"` } +// BackendChatCompletionRequest 是实际发送到后端模型的请求结构 +// 它只包含通用字段,以避免发送不被支持的参数 +type BackendChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + Stream bool `json:"stream,omitempty"` +} + // ResponsesRequest /v1/responses 端点请求结构 type ResponsesRequest struct { Model string `json:"model"` @@ -119,8 +127,13 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { var req ChatCompletionRequest + // 增加日志记录请求体 + bodyBytes, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 把body重新写回去 + // 解析请求体 if err := c.ShouldBindJSON(&req); err != nil { + log.Printf("Failed to bind JSON: %v", err) // 增加错误日志 c.JSON(http.StatusBadRequest, gin.H{ "error": gin.H{ "message": "Invalid request format", @@ -146,7 +159,22 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { } // 准备转发请求 - requestBody, err := json.Marshal(req) + // 用body创建一个json对象 + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to process request", + "type": "internal_error", + }, + }) + return + } + // 然后修改jsonBody的model字段 + jsonBody["model"] = backendModel.Name + // 最后重新marshal回requestBody + requestBody, err := json.Marshal(jsonBody) + if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ @@ -212,6 +240,10 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { return } + // 增加日志记录后端响应 + log.Printf("Backend Response Status: %s", resp.Status) + log.Printf("Backend Response Body: %s", string(responseBody)) + // 计算响应token数 responseTokenCount := 0 var responseData map[string]interface{} diff --git a/backend/gateway.db b/backend/gateway.db index b07dd92..e3abd2f 100644 Binary files a/backend/gateway.db and b/backend/gateway.db differ diff --git a/backend/internal/models/schema.go b/backend/internal/models/schema.go index da8ce0d..012cf3c 100644 --- a/backend/internal/models/schema.go +++ b/backend/internal/models/schema.go @@ -15,9 +15,10 @@ const ( // Provider 代表一个AI服务提供商 type Provider struct { gorm.Model - Name string `gorm:"uniqueIndex;not null"` // 服务商名称,唯一索引 - BaseURL string `gorm:"not null"` // API基础URL - ApiKey string `gorm:"not null"` // API密钥 + Name string `gorm:"uniqueIndex;not null" json:"name"` // 服务商名称,唯一索引 + BaseURL string `gorm:"not null" json:"base_url"` // API基础URL + ApiKey string `gorm:"not null" json:"api_key"` // API密钥 + APIVersion string `json:"api_version"` // API版本(可选) } // APIKey 用于网关本身的API认证 diff --git a/backend/internal/router/selector.go b/backend/internal/router/selector.go index a7cf55b..3068c4d 100644 --- a/backend/internal/router/selector.go +++ b/backend/internal/router/selector.go @@ -32,7 +32,7 @@ func SelectBackendModel(db *gorm.DB, virtualModelName string, requestTokenCount // 筛选满足上下文长度要求的模型 var suitableModels []models.BackendModel for _, backendModel := range virtualModel.BackendModels { - if backendModel.MaxContextLength >= requestTokenCount { + if backendModel.MaxContextLength == 0 || backendModel.MaxContextLength >= requestTokenCount { suitableModels = append(suitableModels, backendModel) } } @@ -67,11 +67,14 @@ func SelectBackendModel(db *gorm.DB, virtualModelName string, requestTokenCount estimatedCost = model.FixedPrice } - // 如果该模型设置了成本阈值,检查成本是否超过阈值 - if model.CostThreshold > 0 { - // 如果成本超过该模型的阈值,跳过该模型 - if estimatedCost > model.CostThreshold { - continue + // 如果CostThreshold不为0,则表示设置了成本阈值 + if model.CostThreshold != 0 { + // 如果该模型设置了成本阈值,检查成本是否超过阈值 + if model.CostThreshold > 0 { + // 如果成本超过该模型的阈值,跳过该模型 + if estimatedCost > model.CostThreshold { + continue + } } } diff --git a/backend/main.go b/backend/main.go index 1c7d073..696ab94 100644 --- a/backend/main.go +++ b/backend/main.go @@ -39,13 +39,20 @@ func main() { DB: database, } + // 创建根组 + root_ := router.Group("/") + root_.Use(middleware.AuthMiddleware(database)) + { + root_.GET("/models", handler.ListModels) + } + // 创建受保护的路由组 - protected := router.Group("/") + protected := router.Group("/v1") protected.Use(middleware.AuthMiddleware(database)) { protected.GET("/models", handler.ListModels) - protected.POST("/v1/chat/completions", handler.ChatCompletions) - protected.POST("/v1/responses", handler.ResponsesCompletions) + protected.POST("/chat/completions", handler.ChatCompletions) + protected.POST("/responses", handler.ResponsesCompletions) } // 创建API管理路由组 diff --git a/frontend/src/features/providers/components/ProviderForm.jsx b/frontend/src/features/providers/components/ProviderForm.jsx index 7d17cb7..b0813d3 100644 --- a/frontend/src/features/providers/components/ProviderForm.jsx +++ b/frontend/src/features/providers/components/ProviderForm.jsx @@ -8,20 +8,20 @@ const ProviderForm = ({ provider, onSave, onCancel }) => { useEffect(() => { if (provider) { - setName(provider.Name || ''); - setBaseURL(provider.BaseURL || ''); - setAPIKey(provider.APIKey || ''); - setAPIVersion(provider.APIVersion || ''); + setName(provider.name || ''); + setBaseURL(provider.base_url || ''); + setAPIKey(provider.api_key || ''); + setAPIVersion(provider.api_version || ''); } }, [provider]); const handleSubmit = (e) => { e.preventDefault(); onSave({ - Name: name, - BaseURL: baseURL, - APIKey: apiKey, - APIVersion: apiVersion, + name: name, + base_url: baseURL, + api_key: apiKey, + api_version: apiVersion, }); }; diff --git a/frontend/src/features/providers/components/ProviderList.jsx b/frontend/src/features/providers/components/ProviderList.jsx index e5e0acc..f0a9768 100644 --- a/frontend/src/features/providers/components/ProviderList.jsx +++ b/frontend/src/features/providers/components/ProviderList.jsx @@ -138,12 +138,12 @@ const ProviderList = () => {
- {provider.Name || 'N/A'} + {provider.name || 'N/A'}
- {provider.BaseURL || 'N/A'} + {provider.base_url || 'N/A'} diff --git a/frontend/src/features/virtual-models/components/VirtualModelForm.jsx b/frontend/src/features/virtual-models/components/VirtualModelForm.jsx index 7231ff1..1a9152a 100644 --- a/frontend/src/features/virtual-models/components/VirtualModelForm.jsx +++ b/frontend/src/features/virtual-models/components/VirtualModelForm.jsx @@ -79,11 +79,13 @@ const VirtualModelForm = ({ model, onSave, onCancel }) => { className="mt-1 block w-full rounded-md border-gray-300 shadow-sm" > - {providers.map((p) => ( - - ))} + {providers.map((p) => { + return ( + + ); + })}