Files
AIRouter/backend/api/openai_handlers.go
2025-11-11 01:27:45 +08:00

460 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"encoding/json"
"io"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
// ListModels 处理 GET /models 请求
func (h *APIHandler) ListModels(c *gin.Context) {
var virtualModels []models.VirtualModel
// 查询所有虚拟模型
if err := h.DB.Find(&virtualModels).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to retrieve models",
"type": "internal_error",
},
})
return
}
// 格式化为OpenAI API响应格式
modelData := make([]ModelData, len(virtualModels))
for i, vm := range virtualModels {
modelData[i] = ModelData{
ID: vm.Name,
Object: "model",
Created: vm.CreatedAt.Unix(),
OwnedBy: "ai-gateway",
}
}
response := ModelListResponse{
Object: "list",
Data: modelData,
}
c.JSON(http.StatusOK, response)
}
// ChatCompletions 处理 POST /v1/chat/completions 请求
func (h *APIHandler) ChatCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
log.Printf("Failed to unmarshal JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
isStream, _ := jsonBody["stream"].(bool)
// 提取 messages 并转换为计算 token 所需的格式
var messages []billing.ChatCompletionMessage
if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
for _, msgRaw := range messagesRaw {
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content, _ := msgMap["content"].(string)
messages = append(messages, billing.ChatCompletionMessage{
Role: role,
Content: content,
})
}
}
}
// 计算请求token数
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 处理流式响应
if isStream {
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
apiKeyID, modelName, backendModel, requestTokenCount,
string(requestBody), h.DB)
return
}
// 处理非流式响应
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
log.Printf("Backend Response Status: %s", resp.Status)
log.Printf("Backend Response Body: %s", string(responseBody))
// 计算响应token数
responseTokenCount := extractResponseTokenCount(responseBody)
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// ResponsesCompletions 处理 POST /v1/responses 请求
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
// 提取 input 并转换为计算 token 所需的格式
var messages []billing.ChatCompletionMessage
if inputRaw, ok := jsonBody["input"].([]interface{}); ok {
messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw))
for _, msgRaw := range inputRaw {
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content, _ := msgMap["content"].(string)
messages = append(messages, billing.ChatCompletionMessage{
Role: role,
Content: content,
})
}
}
}
// 计算请求token数
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/responses"
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
// 计算响应token数
responseTokenCount := extractResponseTokenCount(responseBody)
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// EmbeddingsRequest embeddings请求结构
type EmbeddingsRequest struct {
Model string `json:"model"`
Input interface{} `json:"input"` // 可以是字符串或字符串数组
EncodingFormat string `json:"encoding_format,omitempty"`
User string `json:"user,omitempty"`
}
// Embeddings 处理 POST /v1/embeddings 请求
func (h *APIHandler) Embeddings(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
log.Printf("Failed to unmarshal JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
// 计算输入文本的 token 数
var requestTokenCount int
if input, ok := jsonBody["input"]; ok {
switch v := input.(type) {
case string:
// 单个字符串
requestTokenCount = billing.CalculateTextTokensSimple(v)
case []interface{}:
// 字符串数组
for _, item := range v {
if str, ok := item.(string); ok {
requestTokenCount += billing.CalculateTextTokensSimple(str)
}
}
}
}
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/embeddings"
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
log.Printf("Backend Response Status: %s", resp.Status)
log.Printf("Backend Response Body: %s", string(responseBody))
// 从响应中提取 token 使用量
responseTokenCount := extractEmbeddingsTokenCount(responseBody, requestTokenCount)
// 计算费用 - embeddings 通常只计算输入 token
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, 0)
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// extractEmbeddingsTokenCount 从 embeddings 响应体中提取 token 数量
func extractEmbeddingsTokenCount(responseBody []byte, fallbackCount int) int {
var responseData map[string]interface{}
if err := json.Unmarshal(responseBody, &responseData); err != nil {
return fallbackCount
}
// 尝试从 usage 字段提取
if usage, ok := responseData["usage"].(map[string]interface{}); ok {
if totalTokens, ok := usage["total_tokens"].(float64); ok {
return int(totalTokens)
}
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
return int(promptTokens)
}
}
// 如果没有 usage 字段,返回请求时计算的值
return fallbackCount
}