diff --git a/backend/api/handlers.go b/backend/api/handlers.go index 68a8471..486bb28 100644 --- a/backend/api/handlers.go +++ b/backend/api/handlers.go @@ -129,13 +129,12 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { var req ChatCompletionRequest - // 增加日志记录请求体 + // 读取并解析请求体 bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 把body重新写回去 + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - // 解析请求体 if err := c.ShouldBindJSON(&req); err != nil { - log.Printf("Failed to bind JSON: %v", err) // 增加错误日志 + log.Printf("Failed to bind JSON: %v", err) c.JSON(http.StatusBadRequest, gin.H{ "error": gin.H{ "message": "Invalid request format", @@ -145,7 +144,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { return } - // 使用tiktoken精确计算请求token数 + // 计算请求token数 messages := convertToTikTokenMessages(req.Messages) requestTokenCount := billing.CalculateMessagesTokensSimple(messages) @@ -161,23 +160,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { return } - // 准备转发请求 - // 用body创建一个json对象 - var jsonBody map[string]interface{} - if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to process request", - "type": "internal_error", - }, - }) - return - } - // 然后修改jsonBody的model字段 - jsonBody["model"] = backendModel.Name - // 最后重新marshal回requestBody - requestBody, err := json.Marshal(jsonBody) - + // 准备后端请求 + requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ @@ -191,32 +175,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { // 构建后端API URL backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions" - // 创建HTTP请求 - httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody)) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create backend request", - "type": "internal_error", - }, - }) - return - } - - // 设置请求头 - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey) - - // 复制原始请求的其他相关头部 - if userAgent := c.GetHeader("User-Agent"); userAgent != "" { - httpReq.Header.Set("User-Agent", userAgent) - } - - // 执行请求 - client := &http.Client{ - Timeout: 120 * time.Second, - } - resp, err := client.Do(httpReq) + // 转发请求到后端 + resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent")) if err != nil { c.JSON(http.StatusBadGateway, gin.H{ "error": gin.H{ @@ -231,26 +191,18 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { // 记录响应时间 responseTimestamp := time.Now() - // 从上下文获取API密钥 - apiKeyValue, exists := c.Get("apiKey") - var apiKeyID uint - if exists { - if apiKey, ok := apiKeyValue.(models.APIKey); ok { - apiKeyID = apiKey.ID - } - } + // 获取API Key ID + apiKeyID := getAPIKeyID(c) - // 检查是否是流式请求 + // 处理流式响应 if req.Stream { - // 使用流式响应处理函数 handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp, apiKeyID, req.Model, backendModel, requestTokenCount, string(requestBody), h.DB) return } - // 非流式响应处理 - // 读取响应体 + // 处理非流式响应 responseBody, err := io.ReadAll(resp.Body) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -262,54 +214,24 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { return } - // 增加日志记录后端响应 log.Printf("Backend Response Status: %s", resp.Status) log.Printf("Backend Response Body: %s", string(responseBody)) // 计算响应token数 - responseTokenCount := 0 - var responseData map[string]interface{} - if err := json.Unmarshal(responseBody, &responseData); err == nil { - if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 { - if choice, ok := choices[0].(map[string]interface{}); ok { - if message, ok := choice["message"].(map[string]interface{}); ok { - if content, ok := message["content"].(string); ok { - responseTokenCount = billing.CalculateTextTokensSimple(content) - } - } - } - } - } + responseTokenCount := extractResponseTokenCount(responseBody) // 计算费用 costCalculator := billing.NewCostCalculator() cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) - // 创建日志记录 - logEntry := &models.RequestLog{ - APIKeyID: apiKeyID, - VirtualModelName: req.Model, - BackendModelName: backendModel.Name, - RequestTimestamp: requestTimestamp, - ResponseTimestamp: responseTimestamp, - RequestTokens: requestTokenCount, - ResponseTokens: responseTokenCount, - Cost: cost, - RequestBody: string(requestBody), - ResponseBody: string(responseBody), - } - - // 异步记录日志 + // 创建并记录日志 + logEntry := createRequestLog(apiKeyID, req.Model, backendModel, + requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount, + cost, string(requestBody), string(responseBody)) logger.LogRequest(h.DB, logEntry) - // 复制响应头 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - // 设置响应状态码并返回响应体 + // 复制响应头并返回响应 + copyResponseHeaders(c, resp) c.Status(resp.StatusCode) c.Writer.Write(responseBody) } @@ -321,11 +243,10 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { var req ResponsesRequest - // 增加日志记录请求体 + // 读取并解析请求体 bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 把body重新写回去 + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - // 解析请求体 if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": gin.H{ @@ -336,7 +257,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { return } - // 使用tiktoken精确计算请求token数 + // 计算请求token数 messages := convertToTikTokenMessages(req.Input) requestTokenCount := billing.CalculateMessagesTokensSimple(messages) @@ -352,10 +273,9 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { return } - // 准备转发请求 - // 用body创建一个json对象 - var jsonBody map[string]interface{} - if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { + // 准备后端请求 + requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name) + if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": "Failed to process request", @@ -364,40 +284,12 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { }) return } - // 然后修改jsonBody的model字段 - jsonBody["model"] = backendModel.Name - // 最后重新marshal回requestBody - requestBody, err := json.Marshal(jsonBody) // 构建后端API URL backendURL := backendModel.Provider.BaseURL + "/v1/responses" - // 创建HTTP请求 - httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody)) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create backend request", - "type": "internal_error", - }, - }) - return - } - - // 设置请求头 - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey) - - // 复制原始请求的其他相关头部 - if userAgent := c.GetHeader("User-Agent"); userAgent != "" { - httpReq.Header.Set("User-Agent", userAgent) - } - - // 执行请求 - client := &http.Client{ - Timeout: 120 * time.Second, - } - resp, err := client.Do(httpReq) + // 转发请求到后端 + resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent")) if err != nil { c.JSON(http.StatusBadGateway, gin.H{ "error": gin.H{ @@ -425,58 +317,23 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { } // 计算响应token数 - responseTokenCount := 0 - var responseData map[string]interface{} - if err := json.Unmarshal(responseBody, &responseData); err == nil { - if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 { - if choice, ok := choices[0].(map[string]interface{}); ok { - if message, ok := choice["message"].(map[string]interface{}); ok { - if content, ok := message["content"].(string); ok { - responseTokenCount = billing.CalculateTextTokensSimple(content) - } - } - } - } - } + responseTokenCount := extractResponseTokenCount(responseBody) // 计算费用 costCalculator := billing.NewCostCalculator() cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) - // 从上下文获取API密钥 - apiKeyValue, exists := c.Get("apiKey") - var apiKeyID uint - if exists { - if apiKey, ok := apiKeyValue.(models.APIKey); ok { - apiKeyID = apiKey.ID - } - } + // 获取API Key ID + apiKeyID := getAPIKeyID(c) - // 创建日志记录 - logEntry := &models.RequestLog{ - APIKeyID: apiKeyID, - VirtualModelName: req.Model, - BackendModelName: backendModel.Name, - RequestTimestamp: requestTimestamp, - ResponseTimestamp: responseTimestamp, - RequestTokens: requestTokenCount, - ResponseTokens: responseTokenCount, - Cost: cost, - RequestBody: string(requestBody), - ResponseBody: string(responseBody), - } - - // 异步记录日志 + // 创建并记录日志 + logEntry := createRequestLog(apiKeyID, req.Model, backendModel, + requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount, + cost, string(requestBody), string(responseBody)) logger.LogRequest(h.DB, logEntry) - // 复制响应体 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - // 设置响应状态码并返回响应体 + // 复制响应头并返回响应 + copyResponseHeaders(c, resp) c.Status(resp.StatusCode) c.Writer.Write(responseBody) } @@ -493,6 +350,117 @@ func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatC return result } +// prepareBackendRequest 准备后端请求体,替换model字段 +func prepareBackendRequest(bodyBytes []byte, backendModelName string) ([]byte, error) { + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { + return nil, fmt.Errorf("failed to unmarshal request body: %w", err) + } + + // 修改model字段为后端模型名称 + jsonBody["model"] = backendModelName + + // 重新marshal + requestBody, err := json.Marshal(jsonBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + return requestBody, nil +} + +// forwardRequest 转发请求到后端API +func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) { + // 创建HTTP请求 + httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create backend request: %w", err) + } + + // 设置请求头 + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+apiKey) + + // 设置User-Agent + if userAgent != "" { + httpReq.Header.Set("User-Agent", userAgent) + } + + // 执行请求 + client := &http.Client{ + Timeout: 10 * time.Minute, + } + + resp, err := client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("failed to connect to backend service: %w", err) + } + + return resp, nil +} + +// extractResponseTokenCount 从响应体中提取token数量 +func extractResponseTokenCount(responseBody []byte) int { + var responseData map[string]interface{} + if err := json.Unmarshal(responseBody, &responseData); err != nil { + return 0 + } + + if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 { + if choice, ok := choices[0].(map[string]interface{}); ok { + if message, ok := choice["message"].(map[string]interface{}); ok { + if content, ok := message["content"].(string); ok { + return billing.CalculateTextTokensSimple(content) + } + } + } + } + + return 0 +} + +// createRequestLog 创建请求日志记录 +func createRequestLog(apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, + requestTimestamp, responseTimestamp time.Time, requestTokenCount, responseTokenCount int, + cost float64, requestBody, responseBody string) *models.RequestLog { + + return &models.RequestLog{ + APIKeyID: apiKeyID, + VirtualModelName: virtualModelName, + BackendModelName: backendModel.Name, + RequestTimestamp: requestTimestamp, + ResponseTimestamp: responseTimestamp, + RequestTokens: requestTokenCount, + ResponseTokens: responseTokenCount, + Cost: cost, + RequestBody: requestBody, + ResponseBody: responseBody, + } +} + +// getAPIKeyID 从gin上下文中提取API Key ID +func getAPIKeyID(c *gin.Context) uint { + apiKeyValue, exists := c.Get("apiKey") + if !exists { + return 0 + } + + if apiKey, ok := apiKeyValue.(models.APIKey); ok { + return apiKey.ID + } + + return 0 +} + +// copyResponseHeaders 复制响应头到gin context +func copyResponseHeaders(c *gin.Context, resp *http.Response) { + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } +} + // handleStreamingResponse 处理流式响应 func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time, apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,