From 5bffcc6244fcc60e3682bd27dc086c836a75534d Mon Sep 17 00:00:00 2001 From: nanako <469449812@qq.com> Date: Sun, 9 Nov 2025 00:14:01 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E8=AF=B7=E6=B1=82=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BC=98=E5=8C=96=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E4=BD=93=E8=A7=A3=E6=9E=90=EF=BC=8C=E6=8F=90=E5=8F=96?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=92=8C=E6=B6=88=E6=81=AF=E5=AD=97=E6=AE=B5?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=8A=A0=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/api/handlers.go | 87 ++++++++++++++++++++++++++++++----------- 1 file changed, 64 insertions(+), 23 deletions(-) diff --git a/backend/api/handlers.go b/backend/api/handlers.go index 486bb28..6412baa 100644 --- a/backend/api/handlers.go +++ b/backend/api/handlers.go @@ -127,14 +127,13 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { // 记录请求开始时间 requestTimestamp := time.Now() - var req ChatCompletionRequest - - // 读取并解析请求体 + // 读取请求体 bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - if err := c.ShouldBindJSON(&req); err != nil { - log.Printf("Failed to bind JSON: %v", err) + // 解析为 map,只解析一次 + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { + log.Printf("Failed to unmarshal JSON: %v", err) c.JSON(http.StatusBadRequest, gin.H{ "error": gin.H{ "message": "Invalid request format", @@ -144,12 +143,31 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { return } + // 从 map 中提取需要的字段 + modelName, _ := jsonBody["model"].(string) + isStream, _ := jsonBody["stream"].(bool) + + // 提取 messages 并转换为计算 token 所需的格式 + var messages []billing.ChatCompletionMessage + if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok { + messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw)) + for _, msgRaw := range messagesRaw { + if msgMap, ok := msgRaw.(map[string]interface{}); ok { + role, _ := msgMap["role"].(string) + content, _ := msgMap["content"].(string) + messages = append(messages, billing.ChatCompletionMessage{ + Role: role, + Content: content, + }) + } + } + } + // 计算请求token数 - messages := convertToTikTokenMessages(req.Messages) requestTokenCount := billing.CalculateMessagesTokensSimple(messages) // 选择后端模型 - backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount) + backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount) if err != nil { c.JSON(http.StatusNotFound, gin.H{ "error": gin.H{ @@ -160,8 +178,11 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { return } - // 准备后端请求 - requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name) + // 修改 model 字段为后端模型名称 + jsonBody["model"] = backendModel.Name + + // Marshal 请求体用于转发 + requestBody, err := json.Marshal(jsonBody) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ @@ -195,9 +216,9 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { apiKeyID := getAPIKeyID(c) // 处理流式响应 - if req.Stream { + if isStream { handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp, - apiKeyID, req.Model, backendModel, requestTokenCount, + apiKeyID, modelName, backendModel, requestTokenCount, string(requestBody), h.DB) return } @@ -225,7 +246,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) // 创建并记录日志 - logEntry := createRequestLog(apiKeyID, req.Model, backendModel, + logEntry := createRequestLog(apiKeyID, modelName, backendModel, requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount, cost, string(requestBody), string(responseBody)) logger.LogRequest(h.DB, logEntry) @@ -241,13 +262,12 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { // 记录请求开始时间 requestTimestamp := time.Now() - var req ResponsesRequest - - // 读取并解析请求体 + // 读取请求体 bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - if err := c.ShouldBindJSON(&req); err != nil { + // 解析为 map,只解析一次 + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": gin.H{ "message": "Invalid request format", @@ -257,12 +277,30 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { return } + // 从 map 中提取需要的字段 + modelName, _ := jsonBody["model"].(string) + + // 提取 input 并转换为计算 token 所需的格式 + var messages []billing.ChatCompletionMessage + if inputRaw, ok := jsonBody["input"].([]interface{}); ok { + messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw)) + for _, msgRaw := range inputRaw { + if msgMap, ok := msgRaw.(map[string]interface{}); ok { + role, _ := msgMap["role"].(string) + content, _ := msgMap["content"].(string) + messages = append(messages, billing.ChatCompletionMessage{ + Role: role, + Content: content, + }) + } + } + } + // 计算请求token数 - messages := convertToTikTokenMessages(req.Input) requestTokenCount := billing.CalculateMessagesTokensSimple(messages) // 选择后端模型 - backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount) + backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount) if err != nil { c.JSON(http.StatusNotFound, gin.H{ "error": gin.H{ @@ -273,8 +311,11 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { return } - // 准备后端请求 - requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name) + // 修改 model 字段为后端模型名称 + jsonBody["model"] = backendModel.Name + + // Marshal 请求体用于转发 + requestBody, err := json.Marshal(jsonBody) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ @@ -327,7 +368,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { apiKeyID := getAPIKeyID(c) // 创建并记录日志 - logEntry := createRequestLog(apiKeyID, req.Model, backendModel, + logEntry := createRequestLog(apiKeyID, modelName, backendModel, requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount, cost, string(requestBody), string(responseBody)) logger.LogRequest(h.DB, logEntry)