Files
AIRouter/backend/api/openai_handlers.go

427 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"bufio"
"encoding/json"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// ListModels handles GET /models.
//
// Every row of the virtual-model table is reported as one OpenAI "model"
// object: its name becomes the model ID, its creation time (Unix seconds)
// becomes "created", and "owned_by" is always "ai-gateway". A database error
// yields a 500 response in the OpenAI error envelope.
func (h *APIHandler) ListModels(c *gin.Context) {
	var virtualModels []models.VirtualModel
	if dbErr := h.DB.Find(&virtualModels).Error; dbErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to retrieve models",
				"type":    "internal_error",
			},
		})
		return
	}

	// Always a non-nil slice, so an empty table is reported as an empty list.
	entries := make([]ModelData, 0, len(virtualModels))
	for _, vm := range virtualModels {
		entries = append(entries, ModelData{
			ID:      vm.Name,
			Object:  "model",
			Created: vm.CreatedAt.Unix(),
			OwnedBy: "ai-gateway",
		})
	}

	c.JSON(http.StatusOK, ModelListResponse{
		Object: "list",
		Data:   entries,
	})
}
// ChatCompletions handles POST /v1/chat/completions.
//
// The client body is parsed once into a generic map so that fields the
// gateway does not interpret are forwarded untouched. The prompt size is
// estimated from "messages", and a backend model is chosen with
// router.SelectBackendModel. The "model" field is rewritten to that backend
// model's name, and the request is forwarded to
// <provider BaseURL>/v1/chat/completions.
//
// When "stream" is true, the backend response is relayed by
// handleStreamingResponse. Otherwise the full backend body is read and its
// cost computed. The request is then recorded with logger.LogRequest, and the
// backend's status, headers and body are copied to the client.
//
// Errors are reported in the OpenAI error envelope:
//   - 400 if the body cannot be read or is not a JSON object,
//   - 404 if router.SelectBackendModel fails,
//   - 500 if the forwarded body cannot be marshalled or the backend body
//     cannot be read,
//   - 502 if forwardRequest fails.
func (h *APIHandler) ChatCompletions(c *gin.Context) {
	requestTimestamp := time.Now()

	// invalidRequest writes the 400 response shared by every malformed-body case.
	invalidRequest := func() {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Invalid request format",
				"type":    "invalid_request_error",
			},
		})
	}

	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		log.Printf("Failed to read request body: %v", err)
		invalidRequest()
		return
	}

	// Parse into a map exactly once; it is both inspected and re-marshalled.
	var jsonBody map[string]interface{}
	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
		log.Printf("Failed to unmarshal JSON: %v", err)
		invalidRequest()
		return
	}
	// A literal `null` body decodes without error into a nil map. Assigning the
	// "model" key below would then panic, so reject it as malformed.
	if jsonBody == nil {
		log.Printf("Request body is JSON null, expected an object")
		invalidRequest()
		return
	}

	modelName, _ := jsonBody["model"].(string)
	isStream, _ := jsonBody["stream"].(bool)

	// contentText flattens a message "content" value to plain text for token
	// estimation. Content may be a string or an array of parts. Only parts that
	// carry a string "text" field contribute (image parts, for example, are
	// skipped).
	contentText := func(raw interface{}) string {
		switch v := raw.(type) {
		case string:
			return v
		case []interface{}:
			var sb strings.Builder
			for _, part := range v {
				partMap, ok := part.(map[string]interface{})
				if !ok {
					continue
				}
				if text, ok := partMap["text"].(string); ok {
					if sb.Len() > 0 {
						sb.WriteString("\n")
					}
					sb.WriteString(text)
				}
			}
			return sb.String()
		default:
			return ""
		}
	}

	// Convert "messages" into the shape the billing token counter expects.
	var messages []billing.ChatCompletionMessage
	if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
		messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
		for _, msgRaw := range messagesRaw {
			msgMap, ok := msgRaw.(map[string]interface{})
			if !ok {
				continue
			}
			role, _ := msgMap["role"].(string)
			messages = append(messages, billing.ChatCompletionMessage{
				Role:    role,
				Content: contentText(msgMap["content"]),
			})
		}
	}

	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)

	backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error": gin.H{
				"message": err.Error(),
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// The backend must see its own model name, not the virtual one.
	jsonBody["model"] = backendModel.Name

	requestBody, err := json.Marshal(jsonBody)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to process request",
				"type":    "internal_error",
			},
		})
		return
	}

	backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
	resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
	if err != nil {
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"message": "Failed to connect to backend service",
				"type":    "service_unavailable",
			},
		})
		return
	}
	defer resp.Body.Close()

	// Taken when the backend's response headers arrive.
	responseTimestamp := time.Now()
	apiKeyID := getAPIKeyID(c)

	if isStream {
		handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
			apiKeyID, modelName, backendModel, requestTokenCount,
			string(requestBody), h.DB)
		return
	}

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to read backend response",
				"type":    "internal_error",
			},
		})
		return
	}
	log.Printf("Backend Response Status: %s", resp.Status)
	log.Printf("Backend Response Body: %s", string(responseBody))

	responseTokenCount := extractResponseTokenCount(responseBody)
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	logEntry := createRequestLog(apiKeyID, modelName, backendModel,
		requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
		cost, string(requestBody), string(responseBody))
	logger.LogRequest(h.DB, logEntry)

	// Relay the backend's status, headers and body unchanged. A write error
	// here means the client went away; there is nothing left to report to it.
	copyResponseHeaders(c, resp)
	c.Status(resp.StatusCode)
	c.Writer.Write(responseBody)
}
// ResponsesCompletions handles POST /v1/responses (the OpenAI Responses API).
//
// It follows the same flow as ChatCompletions:
//  1. Parse the body into a generic map.
//  2. Estimate the prompt size.
//  3. Select a backend model and rewrite "model".
//  4. Forward to <provider BaseURL>/v1/responses.
//  5. Bill and log the request, then copy the backend response back to the
//     client.
//
// NOTE(review): the backend response is always read in full before it is
// relayed. A request with "stream": true is therefore delivered only after
// the backend finishes, and extractResponseTokenCount receives the raw event
// stream. Confirm whether streaming needs a dedicated relay here.
//
// Errors use the OpenAI error envelope:
//   - 400 for an unreadable body or a body that is not a JSON object,
//   - 404 when router.SelectBackendModel fails,
//   - 500 when the forwarded body cannot be built or the backend body cannot
//     be read,
//   - 502 when forwardRequest fails.
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
	requestTimestamp := time.Now()

	// invalidRequest writes the 400 response shared by every malformed-body case.
	invalidRequest := func() {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Invalid request format",
				"type":    "invalid_request_error",
			},
		})
	}

	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		invalidRequest()
		return
	}

	// Parse into a map exactly once; it is both inspected and re-marshalled.
	// A literal `null` decodes without error into a nil map, which would panic
	// when "model" is rewritten below, so it is rejected too.
	var jsonBody map[string]interface{}
	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil || jsonBody == nil {
		invalidRequest()
		return
	}

	modelName, _ := jsonBody["model"].(string)

	// contentText flattens a "content" value to plain text for token
	// estimation. Content may be a string or an array of parts. Only parts that
	// carry a string "text" field contribute.
	contentText := func(raw interface{}) string {
		switch v := raw.(type) {
		case string:
			return v
		case []interface{}:
			var sb strings.Builder
			for _, part := range v {
				partMap, ok := part.(map[string]interface{})
				if !ok {
					continue
				}
				if text, ok := partMap["text"].(string); ok {
					if sb.Len() > 0 {
						sb.WriteString("\n")
					}
					sb.WriteString(text)
				}
			}
			return sb.String()
		default:
			return ""
		}
	}

	// Build the message list used for token estimation:
	//   - A top-level "instructions" string acts as a system prompt.
	//   - "input" may be a plain string (a single user turn).
	//   - "input" may instead be an array of items, each with a "role" and a
	//     "content".
	var messages []billing.ChatCompletionMessage
	if instructions, ok := jsonBody["instructions"].(string); ok && instructions != "" {
		messages = append(messages, billing.ChatCompletionMessage{
			Role:    "system",
			Content: instructions,
		})
	}
	switch input := jsonBody["input"].(type) {
	case string:
		messages = append(messages, billing.ChatCompletionMessage{
			Role:    "user",
			Content: input,
		})
	case []interface{}:
		for _, itemRaw := range input {
			item, ok := itemRaw.(map[string]interface{})
			if !ok {
				continue
			}
			role, _ := item["role"].(string)
			messages = append(messages, billing.ChatCompletionMessage{
				Role:    role,
				Content: contentText(item["content"]),
			})
		}
	}

	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)

	backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error": gin.H{
				"message": err.Error(),
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// The backend must see its own model name, not the virtual one.
	jsonBody["model"] = backendModel.Name

	requestBody, err := json.Marshal(jsonBody)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to process request",
				"type":    "internal_error",
			},
		})
		return
	}

	backendURL := backendModel.Provider.BaseURL + "/v1/responses"
	resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
	if err != nil {
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"message": "Failed to connect to backend service",
				"type":    "service_unavailable",
			},
		})
		return
	}
	defer resp.Body.Close()

	// Taken when the backend's response headers arrive.
	responseTimestamp := time.Now()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to read backend response",
				"type":    "internal_error",
			},
		})
		return
	}

	responseTokenCount := extractResponseTokenCount(responseBody)
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	apiKeyID := getAPIKeyID(c)
	logEntry := createRequestLog(apiKeyID, modelName, backendModel,
		requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
		cost, string(requestBody), string(responseBody))
	logger.LogRequest(h.DB, logEntry)

	// Relay the backend's status, headers and body unchanged.
	copyResponseHeaders(c, resp)
	c.Status(resp.StatusCode)
	c.Writer.Write(responseBody)
}
// handleStreamingResponse relays a server-sent-events (SSE) response from the
// backend to the client line by line, then bills and logs the request.
//
// While relaying, it accumulates two things:
//   - the raw stream, for the request log;
//   - the text of choices[0].delta.content from each "data:" JSON chunk, for
//     token estimation.
//
// If a chunk carries usage.completion_tokens, that backend-reported value
// replaces the local estimate.
//
// Relaying stops at the "[DONE]" marker, at end of stream, on a read error,
// or when a write to the client fails. In every case the portion received so
// far is still billed and logged, so a client disconnect does not drop the
// log entry.
//
// responseTimestamp is kept for signature compatibility. The log records the
// time relaying finished instead.
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
	apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
	requestBody string, database *gorm.DB) {
	// bufio.Scanner fails with bufio.ErrTooLong on lines longer than its buffer
	// (64 KiB by default), which would cut the stream short on a large chunk.
	const maxSSELineSize = 10 << 20

	// SSE response headers. Upstream headers copied afterwards may override
	// them, except the framing headers that are skipped below.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")
	for key, values := range resp.Header {
		if key == "Content-Length" || key == "Transfer-Encoding" {
			continue
		}
		for _, value := range values {
			c.Header(key, value)
		}
	}
	c.Status(resp.StatusCode)

	var fullContent strings.Builder  // concatenated delta text, for token estimation
	var responseBody strings.Builder // raw stream, for the request log
	reportedCompletionTokens := -1   // -1 means the backend reported no usage

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxSSELineSize)
	flusher, canFlush := c.Writer.(http.Flusher)
	flush := func() {
		if canFlush {
			flusher.Flush()
		}
	}

	for scanner.Scan() {
		line := scanner.Text()
		responseBody.WriteString(line)
		responseBody.WriteString("\n")

		// Per the SSE spec, "data:" may be followed by one optional space.
		if strings.HasPrefix(line, "data:") {
			data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")

			if data == "[DONE]" {
				if _, err := c.Writer.Write([]byte(line + "\n\n")); err != nil {
					log.Printf("Client write failed, stopping stream relay: %v", err)
					break
				}
				flush()
				break
			}

			var chunk map[string]interface{}
			if err := json.Unmarshal([]byte(data), &chunk); err == nil {
				if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
					if choice, ok := choices[0].(map[string]interface{}); ok {
						if delta, ok := choice["delta"].(map[string]interface{}); ok {
							if content, ok := delta["content"].(string); ok {
								fullContent.WriteString(content)
							}
						}
					}
				}
				// Chunks may carry backend-reported usage. When usage appears
				// on several chunks, the last value seen is kept.
				if usage, ok := chunk["usage"].(map[string]interface{}); ok {
					if n, ok := usage["completion_tokens"].(float64); ok {
						reportedCompletionTokens = int(n)
					}
				}
			}
		}

		// Forward the line unchanged. On failure, stop relaying but fall
		// through so the request is still billed and logged.
		if _, err := c.Writer.Write([]byte(line + "\n")); err != nil {
			log.Printf("Client write failed, stopping stream relay: %v", err)
			break
		}
		// An empty line terminates an SSE event; push it to the client now.
		if line == "" {
			flush()
		}
	}
	flush()

	if err := scanner.Err(); err != nil {
		log.Printf("Error reading stream: %v", err)
	}

	responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
	if reportedCompletionTokens >= 0 {
		responseTokenCount = reportedCompletionTokens
	}

	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	logEntry := &models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  virtualModelName,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		ResponseTimestamp: time.Now(), // when relaying actually finished
		RequestTokens:     requestTokenCount,
		ResponseTokens:    responseTokenCount,
		Cost:              cost,
		RequestBody:       requestBody,
		ResponseBody:      responseBody.String(),
	}
	logger.LogRequest(database, logEntry)
}