diff --git a/backend/api/handlers.go b/backend/api/handlers.go index 9f59c10..fbd99ad 100644 --- a/backend/api/handlers.go +++ b/backend/api/handlers.go @@ -5,12 +5,14 @@ import ( "ai-gateway/internal/logger" "ai-gateway/internal/models" "ai-gateway/internal/router" + "bufio" "bytes" "encoding/json" "fmt" "io" "log" "net/http" + "strings" "time" "github.com/gin-gonic/gin" @@ -228,6 +230,25 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { // 记录响应时间 responseTimestamp := time.Now() + // 从上下文获取API密钥 + apiKeyValue, exists := c.Get("apiKey") + var apiKeyID uint + if exists { + if apiKey, ok := apiKeyValue.(models.APIKey); ok { + apiKeyID = apiKey.ID + } + } + + // 检查是否是流式请求 + if req.Stream { + // 使用流式响应处理函数 + handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp, + apiKeyID, req.Model, backendModel, requestTokenCount, + string(requestBody), h.DB) + return + } + + // 非流式响应处理 // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -268,15 +289,6 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) { cost = backendModel.FixedPrice } - // 从上下文获取API密钥 - apiKeyValue, exists := c.Get("apiKey") - var apiKeyID uint - if exists { - if apiKey, ok := apiKeyValue.(models.APIKey); ok { - apiKeyID = apiKey.ID - } - } - // 创建日志记录 logEntry := &models.RequestLog{ APIKeyID: apiKeyID, @@ -502,6 +514,128 @@ func calculateTokenCountFromText(text string) int { return len(encoding.Encode(text, nil, nil)) } +// handleStreamingResponse 处理流式响应 +func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time, + apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int, + requestBody string, database *gorm.DB) { + + // 设置流式响应\u5934 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Transfer-Encoding", "chunked") + + // 复制其他响应头 + for key, values := range resp.Header { + if key != "Content-Length" && key != "Transfer-Encoding" { + for _, value := range values { + c.Header(key, value) + } + } + } + + c.Status(resp.StatusCode) + + // 用于累积完整响应内容 + var fullContent strings.Builder + var responseBody strings.Builder + + // 创建一个 scanner 来逐行读取流 + scanner := bufio.NewScanner(resp.Body) + flusher, ok := c.Writer.(http.Flusher) + + for scanner.Scan() { + line := scanner.Text() + + // 将原始行写入响应体缓冲区 + responseBody.WriteString(line) + responseBody.WriteString("\n") + + // SSE 格式:data: {...} + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + + // 检查是否是结束标记 + if data == "[DONE]" { + _, err := c.Writer.Write([]byte(line + "\n\n")) + if err != nil { + return + } + if ok { + flusher.Flush() + } + break + } + + // 解析 JSON 数据以提取内容 + var chunk map[string]interface{} + if err := json.Unmarshal([]byte(data), &chunk); err == nil { + if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 { + if choice, ok := choices[0].(map[string]interface{}); ok { + if delta, ok := choice["delta"].(map[string]interface{}); ok { + if content, ok := delta["content"].(string); ok { + fullContent.WriteString(content) + } + } + } + } + } + } + + // 转发数据到客户端 + _, err := c.Writer.Write([]byte(line + "\n")) + if err != nil { + return + } + + // 如果是空行(SSE 消息分隔符),刷新 + if line == "" { + if ok { + flusher.Flush() + } + } + } + + // 确保发送最后的数据 + if ok { + flusher.Flush() + } + + // 扫描可能出现的错误 + if err := scanner.Err(); err != nil { + log.Printf("Error reading stream: %v", err) + } + + // 计算响应 token 数 + responseTokenCount := calculateTokenCountFromText(fullContent.String()) + + // 计算费用 + var cost float64 + switch backendModel.BillingMethod { + case models.BillingMethodToken: + cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice + case models.BillingMethodRequest: + cost = backendModel.FixedPrice + } + + // 创建日志记录 + logEntry := &models.RequestLog{ + APIKeyID: apiKeyID, + VirtualModelName: virtualModelName, + BackendModelName: backendModel.Name, + RequestTimestamp: requestTimestamp, + ResponseTimestamp: time.Now(), // 使用实际结束时间 + RequestTokens: requestTokenCount, + ResponseTokens: responseTokenCount, + Cost: cost, + RequestBody: requestBody, + ResponseBody: responseBody.String(), + } + + // 异步记录日志 + logger.LogRequest(database, logEntry) +} + // GetProvidersHandler 处理 GET /api/providers 请求 func (h *APIHandler) GetProvidersHandler(c *gin.Context) { providers, err := db.GetProviders(h.DB)