diff --git a/backend/api/openai_handlers.go b/backend/api/openai_handlers.go
index 6ed32f1..2af18b3 100644
--- a/backend/api/openai_handlers.go
+++ b/backend/api/openai_handlers.go
@@ -304,3 +304,178 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
 	c.Status(resp.StatusCode)
 	c.Writer.Write(responseBody)
 }
+
+// EmbeddingsRequest is the embeddings request structure.
+//
+// NOTE: the Embeddings handler decodes the request body into a generic map and
+// does not use this type.
+type EmbeddingsRequest struct {
+	Model          string      `json:"model"`
+	Input          interface{} `json:"input"` // a string or an array of strings
+	EncodingFormat string      `json:"encoding_format,omitempty"`
+	User           string      `json:"user,omitempty"`
+}
+
+// Embeddings handles POST /v1/embeddings.
+//
+// It estimates the input token count, selects a backend model through
+// router.SelectBackendModel, rewrites the "model" field to the backend model
+// name, forwards the request, records a request log, and relays the backend
+// status code, headers and body to the client.
+func (h *APIHandler) Embeddings(c *gin.Context) {
+	// Start of request handling; passed to the request log.
+	requestTimestamp := time.Now()
+
+	// A body that cannot be read is rejected instead of being treated as empty.
+	bodyBytes, err := io.ReadAll(c.Request.Body)
+	if err != nil {
+		log.Printf("Failed to read request body: %v", err)
+		respondEmbeddingsError(c, http.StatusBadRequest, "Failed to read request body", "invalid_request_error")
+		return
+	}
+
+	// Decode once into a map so the body can be re-marshaled with only the
+	// "model" field changed.
+	var jsonBody map[string]interface{}
+	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
+		log.Printf("Failed to unmarshal JSON: %v", err)
+		respondEmbeddingsError(c, http.StatusBadRequest, "Invalid request format", "invalid_request_error")
+		return
+	}
+	// A literal JSON null decodes without error but leaves the map nil, and
+	// assigning the "model" key further down would panic.
+	if jsonBody == nil {
+		respondEmbeddingsError(c, http.StatusBadRequest, "Invalid request format", "invalid_request_error")
+		return
+	}
+
+	// A missing or non-string "model" yields the empty string.
+	modelName, _ := jsonBody["model"].(string)
+
+	// Estimated input size; passed to backend selection and cost calculation.
+	requestTokenCount := estimateEmbeddingsInputTokens(jsonBody["input"])
+
+	backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
+	if err != nil {
+		respondEmbeddingsError(c, http.StatusNotFound, err.Error(), "invalid_request_error")
+		return
+	}
+
+	// Replace the client-supplied model with the selected backend model's name.
+	jsonBody["model"] = backendModel.Name
+
+	requestBody, err := json.Marshal(jsonBody)
+	if err != nil {
+		respondEmbeddingsError(c, http.StatusInternalServerError, "Failed to process request", "internal_error")
+		return
+	}
+
+	backendURL := backendModel.Provider.BaseURL + "/v1/embeddings"
+
+	resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
+	if err != nil {
+		respondEmbeddingsError(c, http.StatusBadGateway, "Failed to connect to backend service", "service_unavailable")
+		return
+	}
+	defer resp.Body.Close()
+
+	// Taken after forwardRequest returns, before the response body is read.
+	responseTimestamp := time.Now()
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		respondEmbeddingsError(c, http.StatusInternalServerError, "Failed to read backend response", "internal_error")
+		return
+	}
+
+	log.Printf("Backend Response Status: %s", resp.Status)
+	// Successful responses carry the embedding vectors and can be very large,
+	// so only their size is logged; error bodies are logged in full.
+	if resp.StatusCode >= http.StatusBadRequest {
+		log.Printf("Backend Response Body: %s", string(responseBody))
+	} else {
+		log.Printf("Backend Response Body: %d bytes", len(responseBody))
+	}
+
+	// Token usage reported by the backend, or the estimate if none is reported.
+	responseTokenCount := extractEmbeddingsTokenCount(responseBody, requestTokenCount)
+
+	// Cost uses the estimated input tokens only; the last argument (presumably
+	// the output token count - confirm against billing) is 0.
+	costCalculator := billing.NewCostCalculator()
+	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, 0)
+
+	apiKeyID := getAPIKeyID(c)
+
+	logEntry := createRequestLog(apiKeyID, modelName, backendModel,
+		requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
+		cost, string(requestBody), string(responseBody))
+	logger.LogRequest(h.DB, logEntry)
+
+	// Relay the backend response headers, status code and body.
+	copyResponseHeaders(c, resp)
+	c.Status(resp.StatusCode)
+	if _, err := c.Writer.Write(responseBody); err != nil {
+		log.Printf("Failed to write embeddings response: %v", err)
+	}
+}
+
+// respondEmbeddingsError writes a JSON error of the form
+// {"error": {"message": message, "type": errType}} with the given HTTP status.
+func respondEmbeddingsError(c *gin.Context, status int, message, errType string) {
+	c.JSON(status, gin.H{
+		"error": gin.H{
+			"message": message,
+			"type":    errType,
+		},
+	})
+}
+
+// estimateEmbeddingsInputTokens estimates the token count of the "input" value
+// of an embeddings request, as decoded by encoding/json.
+//
+// A string is measured with billing.CalculateTextTokensSimple, a number counts
+// as one token (a pre-tokenized token ID), and an array is the sum of its
+// elements, which covers arrays of strings, of token IDs and of token arrays.
+// Any other value, including a missing input, counts as 0.
+func estimateEmbeddingsInputTokens(input interface{}) int {
+	switch v := input.(type) {
+	case string:
+		return billing.CalculateTextTokensSimple(v)
+	case float64:
+		return 1
+	case []interface{}:
+		total := 0
+		for _, item := range v {
+			total += estimateEmbeddingsInputTokens(item)
+		}
+		return total
+	default:
+		return 0
+	}
+}
+
+// extractEmbeddingsTokenCount returns the token count reported in an
+// embeddings response body.
+//
+// It returns usage.total_tokens when that field is a JSON number, otherwise
+// usage.prompt_tokens when that is a JSON number. If the body is not a JSON
+// object, or neither field is usable, it returns fallbackCount.
+func extractEmbeddingsTokenCount(responseBody []byte, fallbackCount int) int {
+	var responseData map[string]interface{}
+	if err := json.Unmarshal(responseBody, &responseData); err != nil {
+		return fallbackCount
+	}
+
+	usage, ok := responseData["usage"].(map[string]interface{})
+	if !ok {
+		return fallbackCount
+	}
+	// encoding/json decodes JSON numbers stored in an interface{} as float64.
+	if totalTokens, ok := usage["total_tokens"].(float64); ok {
+		return int(totalTokens)
+	}
+	if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
+		return int(promptTokens)
+	}
+	return fallbackCount
+}
diff --git a/backend/gateway.db b/backend/gateway.db
index 0d1823a..0bfbfc1 100644
Binary files a/backend/gateway.db and b/backend/gateway.db differ
diff --git a/backend/main.go b/backend/main.go
index 03fb760..9df5ece 100644
--- a/backend/main.go
+++ b/backend/main.go
@@ -57,6 +57,7 @@ func main() {
 		protected.GET("/models", handler.ListModels)
 		protected.POST("/chat/completions", handler.ChatCompletions)
 		protected.POST("/responses", handler.ResponsesCompletions)
+		protected.POST("/embeddings", handler.Embeddings)
 	}
 
 	// 创建API管理路由组