diff --git a/backend/api/openai_handlers.go b/backend/api/openai_handlers.go index d80094a..6ed32f1 100644 --- a/backend/api/openai_handlers.go +++ b/backend/api/openai_handlers.go @@ -5,16 +5,13 @@ import ( "ai-gateway/internal/logger" "ai-gateway/internal/models" "ai-gateway/internal/router" - "bufio" "encoding/json" "io" "log" "net/http" - "strings" "time" "github.com/gin-gonic/gin" - "gorm.io/gorm" ) // ListModels 处理 GET /models 请求 @@ -307,120 +304,3 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) { c.Status(resp.StatusCode) c.Writer.Write(responseBody) } - -// handleStreamingResponse 处理流式响应 -func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time, - apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int, - requestBody string, database *gorm.DB) { - - // 设置流式响应头 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Transfer-Encoding", "chunked") - - // 复制其他响应头 - for key, values := range resp.Header { - if key != "Content-Length" && key != "Transfer-Encoding" { - for _, value := range values { - c.Header(key, value) - } - } - } - - c.Status(resp.StatusCode) - - // 用于累积完整响应内容 - var fullContent strings.Builder - var responseBody strings.Builder - - // 创建一个 scanner 来逐行读取流 - scanner := bufio.NewScanner(resp.Body) - flusher, ok := c.Writer.(http.Flusher) - - for scanner.Scan() { - line := scanner.Text() - - // 将原始行写入响应体缓冲区 - responseBody.WriteString(line) - responseBody.WriteString("\n") - - // SSE 格式:data: {...} - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") - - // 检查是否是结束标记 - if data == "[DONE]" { - _, err := c.Writer.Write([]byte(line + "\n\n")) - if err != nil { - return - } - if ok { - flusher.Flush() - } - break - } - - // 解析 JSON 数据以提取内容 - var chunk map[string]interface{} - if err := json.Unmarshal([]byte(data), &chunk); err == 
nil { - if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 { - if choice, ok := choices[0].(map[string]interface{}); ok { - if delta, ok := choice["delta"].(map[string]interface{}); ok { - if content, ok := delta["content"].(string); ok { - fullContent.WriteString(content) - } - } - } - } - } - } - - // 转发数据到客户端 - _, err := c.Writer.Write([]byte(line + "\n")) - if err != nil { - return - } - - // 如果是空行(SSE 消息分隔符),刷新 - if line == "" { - if ok { - flusher.Flush() - } - } - } - - // 确保发送最后的数据 - if ok { - flusher.Flush() - } - - // 扫描可能出现的错误 - if err := scanner.Err(); err != nil { - log.Printf("Error reading stream: %v", err) - } - - // 计算响应 token 数 - responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String()) - - // 计算费用 - costCalculator := billing.NewCostCalculator() - cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) - - // 创建日志记录 - logEntry := &models.RequestLog{ - APIKeyID: apiKeyID, - VirtualModelName: virtualModelName, - BackendModelName: backendModel.Name, - RequestTimestamp: requestTimestamp, - ResponseTimestamp: time.Now(), // 使用实际结束时间 - RequestTokens: requestTokenCount, - ResponseTokens: responseTokenCount, - Cost: cost, - RequestBody: requestBody, - ResponseBody: responseBody.String(), - } - - // 异步记录日志 - logger.LogRequest(database, logEntry) -} diff --git a/backend/api/openai_stream_handlers.go b/backend/api/openai_stream_handlers.go new file mode 100644 index 0000000..7ada168 --- /dev/null +++ b/backend/api/openai_stream_handlers.go @@ -0,0 +1,205 @@ +package api + +import ( + "ai-gateway/internal/billing" + "ai-gateway/internal/logger" + "ai-gateway/internal/models" + "bufio" + "encoding/json" + "log" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokenDetails 
*UsagePromptTokenDetails `json:"prompt_tokens_details,omitempty"` + CompletionTokenDetails *UsageCompletionTokenDetails `json:"completion_tokens_details,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` +} + +type UsagePromptTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` +} + +type UsageCompletionTokenDetails struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` +} + +type StreamDelta struct { + Role string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` + Reasoning *string `json:"reasoning,omitempty"` +} + +type StreamChoice struct { + Delta StreamDelta `json:"delta"` +} + +type StreamChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []StreamChoice `json:"choices"` + Usage *Usage `json:"usage"` +} + +// handleStreamingResponse 处理流式响应 +func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time, + apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int, + requestBody string, database *gorm.DB) { + + // 设置流式响应头 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Transfer-Encoding", "chunked") + + // 复制其他响应头 + for key, values := range resp.Header { + if key != "Content-Length" && key != "Transfer-Encoding" { + for _, value := range values { + c.Header(key, value) + } + } + } + + c.Status(resp.StatusCode) + + // 用于累积完整响应内容 + var fullContent strings.Builder + + // 创建一个 scanner 来逐行读取流 + scanner := bufio.NewScanner(resp.Body) + flusher, ok := c.Writer.(http.Flusher) + + var promptTokens int + var
completionTokens int + var totalTokens int + + for scanner.Scan() { + chunk, done := processStreamingResponse(c, scanner, ok, flusher) + + // 仅在 chunk.ID 不为空时处理 token 计数和内容累积 + if chunk.ID != "" { + // 更新 token 计数 + if chunk.Usage != nil { + promptTokens = chunk.Usage.PromptTokens + completionTokens = chunk.Usage.CompletionTokens + totalTokens = chunk.Usage.TotalTokens + } + + // 记录完整响应内容 + for _, choice := range chunk.Choices { + if choice.Delta.Reasoning != nil { + fullContent.WriteString(*choice.Delta.Reasoning) + } + if choice.Delta.Content != nil { + fullContent.WriteString(*choice.Delta.Content) + } + } + } + + if done { + return + } + } + + // 确保发送最后的数据 + if ok { + flusher.Flush() + } + + // 扫描可能出现的错误 + if err := scanner.Err(); err != nil { + log.Printf("Error reading stream: %v", err) + } + + // 计算响应 token 数 + responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String()) + if completionTokens > 0 { + // 如果后端返回了 completion_tokens,则使用它 + responseTokenCount = completionTokens + } + if promptTokens > 0 { + requestTokenCount = promptTokens + } + log.Printf("Streamed Response Tokens - Prompt: %d, Completion: %d, Total: %d", promptTokens, completionTokens, totalTokens) + + // 计算费用 + costCalculator := billing.NewCostCalculator() + cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) + + // 创建日志记录,只记录实际的响应内容 + logEntry := &models.RequestLog{ + APIKeyID: apiKeyID, + VirtualModelName: virtualModelName, + BackendModelName: backendModel.Name, + RequestTimestamp: requestTimestamp, + ResponseTimestamp: time.Now(), // 使用实际结束时间 + RequestTokens: requestTokenCount, + ResponseTokens: responseTokenCount, + Cost: cost, + RequestBody: requestBody, + ResponseBody: fullContent.String(), // 只记录提取的实际内容 + } + + // 异步记录日志 + logger.LogRequest(database, logEntry) +} + +func processStreamingResponse(c *gin.Context, scanner *bufio.Scanner, ok bool, flusher http.Flusher) (StreamChunk, bool) { + line := scanner.Text() + chunk :=
StreamChunk{} + + // SSE 格式:data: {...} + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + + // 检查是否是结束标记 + if data == "[DONE]" { + _, err := c.Writer.Write([]byte(line + "\n\n")) + if err != nil { + return chunk, true + } + if ok { + flusher.Flush() + } + return chunk, false + } + + // 解析 JSON 数据以提取内容 + err := json.Unmarshal([]byte(data), &chunk) + if err != nil { + log.Printf("Error unmarshaling stream chunk: %v", err) + return chunk, true + } + } + + // 转发数据到客户端 + _, err := c.Writer.Write([]byte(line + "\n")) + if err != nil { + return chunk, true + } + + // 如果是空行(SSE 消息分隔符),刷新 + if line == "" { + if ok { + flusher.Flush() + } + } + return chunk, false +}