embeddings请求

This commit is contained in:
2025-11-11 01:27:45 +08:00
parent 4eab8dea65
commit 806b5c6846
3 changed files with 154 additions and 0 deletions

View File

@@ -304,3 +304,156 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// EmbeddingsRequest embeddings请求结构
type EmbeddingsRequest struct {
Model string `json:"model"`
Input interface{} `json:"input"` // 可以是字符串或字符串数组
EncodingFormat string `json:"encoding_format,omitempty"`
User string `json:"user,omitempty"`
}
// Embeddings 处理 POST /v1/embeddings 请求
func (h *APIHandler) Embeddings(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
log.Printf("Failed to unmarshal JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
// 计算输入文本的 token 数
var requestTokenCount int
if input, ok := jsonBody["input"]; ok {
switch v := input.(type) {
case string:
// 单个字符串
requestTokenCount = billing.CalculateTextTokensSimple(v)
case []interface{}:
// 字符串数组
for _, item := range v {
if str, ok := item.(string); ok {
requestTokenCount += billing.CalculateTextTokensSimple(str)
}
}
}
}
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/embeddings"
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
log.Printf("Backend Response Status: %s", resp.Status)
log.Printf("Backend Response Body: %s", string(responseBody))
// 从响应中提取 token 使用量
responseTokenCount := extractEmbeddingsTokenCount(responseBody, requestTokenCount)
// 计算费用 - embeddings 通常只计算输入 token
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, 0)
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// extractEmbeddingsTokenCount 从 embeddings 响应体中提取 token 数量
func extractEmbeddingsTokenCount(responseBody []byte, fallbackCount int) int {
var responseData map[string]interface{}
if err := json.Unmarshal(responseBody, &responseData); err != nil {
return fallbackCount
}
// 尝试从 usage 字段提取
if usage, ok := responseData["usage"].(map[string]interface{}); ok {
if totalTokens, ok := usage["total_tokens"].(float64); ok {
return int(totalTokens)
}
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
return int(promptTokens)
}
}
// 如果没有 usage 字段,返回请求时计算的值
return fallbackCount
}

Binary file not shown.

View File

@@ -57,6 +57,7 @@ func main() {
protected.GET("/models", handler.ListModels)
protected.POST("/chat/completions", handler.ChatCompletions)
protected.POST("/responses", handler.ResponsesCompletions)
protected.POST("/embeddings", handler.Embeddings)
}
// 创建API管理路由组