embeddings请求
This commit is contained in:
@@ -304,3 +304,156 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
c.Status(resp.StatusCode)
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
|
||||
// EmbeddingsRequest embeddings请求结构
|
||||
type EmbeddingsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input interface{} `json:"input"` // 可以是字符串或字符串数组
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// Embeddings 处理 POST /v1/embeddings 请求
|
||||
func (h *APIHandler) Embeddings(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
requestTimestamp := time.Now()
|
||||
|
||||
// 读取请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
// 解析为 map,只解析一次
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
log.Printf("Failed to unmarshal JSON: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid request format",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从 map 中提取需要的字段
|
||||
modelName, _ := jsonBody["model"].(string)
|
||||
|
||||
// 计算输入文本的 token 数
|
||||
var requestTokenCount int
|
||||
if input, ok := jsonBody["input"]; ok {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
// 单个字符串
|
||||
requestTokenCount = billing.CalculateTextTokensSimple(v)
|
||||
case []interface{}:
|
||||
// 字符串数组
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
requestTokenCount += billing.CalculateTextTokensSimple(str)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 修改 model 字段为后端模型名称
|
||||
jsonBody["model"] = backendModel.Name
|
||||
|
||||
// Marshal 请求体用于转发
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to process request",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建后端API URL
|
||||
backendURL := backendModel.Provider.BaseURL + "/v1/embeddings"
|
||||
|
||||
// 转发请求到后端
|
||||
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to connect to backend service",
|
||||
"type": "service_unavailable",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 记录响应时间
|
||||
responseTimestamp := time.Now()
|
||||
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to read backend response",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Backend Response Status: %s", resp.Status)
|
||||
log.Printf("Backend Response Body: %s", string(responseBody))
|
||||
|
||||
// 从响应中提取 token 使用量
|
||||
responseTokenCount := extractEmbeddingsTokenCount(responseBody, requestTokenCount)
|
||||
|
||||
// 计算费用 - embeddings 通常只计算输入 token
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, 0)
|
||||
|
||||
// 获取API Key ID
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
|
||||
// 复制响应头并返回响应
|
||||
copyResponseHeaders(c, resp)
|
||||
c.Status(resp.StatusCode)
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
|
||||
// extractEmbeddingsTokenCount 从 embeddings 响应体中提取 token 数量
|
||||
func extractEmbeddingsTokenCount(responseBody []byte, fallbackCount int) int {
|
||||
var responseData map[string]interface{}
|
||||
if err := json.Unmarshal(responseBody, &responseData); err != nil {
|
||||
return fallbackCount
|
||||
}
|
||||
|
||||
// 尝试从 usage 字段提取
|
||||
if usage, ok := responseData["usage"].(map[string]interface{}); ok {
|
||||
if totalTokens, ok := usage["total_tokens"].(float64); ok {
|
||||
return int(totalTokens)
|
||||
}
|
||||
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
|
||||
return int(promptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 usage 字段,返回请求时计算的值
|
||||
return fallbackCount
|
||||
}
|
||||
|
||||
Binary file not shown.
@@ -57,6 +57,7 @@ func main() {
|
||||
protected.GET("/models", handler.ListModels)
|
||||
protected.POST("/chat/completions", handler.ChatCompletions)
|
||||
protected.POST("/responses", handler.ResponsesCompletions)
|
||||
protected.POST("/embeddings", handler.Embeddings)
|
||||
}
|
||||
|
||||
// 创建API管理路由组
|
||||
|
||||
Reference in New Issue
Block a user