Files
AIRouter/backend/api/openai_stream_handlers.go

197 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"bufio"
"encoding/json"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// Usage holds the token accounting decoded from the "usage" object of a
// stream chunk.
//
// NOTE(review): OpenAI-compatible APIs usually name the detail objects
// "prompt_tokens_details" / "completion_tokens_details" (plural "tokens").
// The tags below use the singular form; confirm against the payloads the
// backends actually send before relying on the detail fields.
type Usage struct {
	// Chat-completions style counters.
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`

	// Optional per-modality breakdowns; nil when the object is absent.
	PromptTokenDetails     *UsagePromptTokenDetails     `json:"prompt_token_details,omitempty"`
	CompletionTokenDetails *UsageCompletionTokenDetails `json:"completion_token_details,omitempty"`

	// Alternative counter names ("input_tokens" / "output_tokens"); omitted
	// from encoded output when zero.
	InputTokens  int `json:"input_tokens,omitempty"`
	OutputTokens int `json:"output_tokens,omitempty"`
}
// UsagePromptTokenDetails breaks the prompt token count down by category.
// Every field is always emitted when encoding (no omitempty).
type UsagePromptTokenDetails struct {
	CachedTokens int `json:"cached_tokens"` // decoded from "cached_tokens"
	TextTokens   int `json:"text_tokens"`   // decoded from "text_tokens"
	AudioTokens  int `json:"audio_tokens"`  // decoded from "audio_tokens"
	ImageTokens  int `json:"image_tokens"`  // decoded from "image_tokens"
}
// UsageCompletionTokenDetails breaks the completion token count down by
// category. Every field is always emitted when encoding (no omitempty).
type UsageCompletionTokenDetails struct {
	TextTokens      int `json:"text_tokens"`      // decoded from "text_tokens"
	AudioTokens     int `json:"audio_tokens"`     // decoded from "audio_tokens"
	ReasoningTokens int `json:"reasoning_tokens"` // decoded from "reasoning_tokens"
}
// StreamDelta is the incremental payload carried by one streamed choice.
// Content and Reasoning are pointers so that an absent key (nil) can be told
// apart from an explicit empty string.
type StreamDelta struct {
	Role      string  `json:"role,omitempty"`
	Content   *string `json:"content,omitempty"`
	Reasoning *string `json:"reasoning,omitempty"`
}
// StreamChoice is a single entry of a stream chunk's "choices" array; only
// its "delta" object is decoded.
type StreamChoice struct {
	Delta StreamDelta `json:"delta"`
}
// StreamChunk is the JSON payload of one SSE "data:" line in a streamed
// completion response.
type StreamChunk struct {
	// Envelope fields.
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`

	// Incremental output carried by this chunk.
	Choices []StreamChoice `json:"choices"`

	// Usage is nil when the chunk has no "usage" object (or it is null).
	Usage *Usage `json:"usage"`
}
// handleStreamingResponse relays an upstream SSE (text/event-stream) response
// to the client line by line. While relaying it accumulates the streamed text
// and the usage counters so that, once the stream ends, the request cost can
// be computed and a request log entry recorded.
//
// Parameters:
//   - c: gin context whose writer receives the relayed stream.
//   - resp: upstream response; its headers, status code and body are relayed.
//   - requestTimestamp: stored in the log entry as the request time.
//   - responseTimestamp: currently unused; the log entry records time.Now()
//     taken after the stream has ended.
//   - apiKeyID, virtualModelName, backendModel, requestBody: passed through to
//     the log entry; backendModel is also used for cost calculation.
//   - requestTokenCount: caller-side prompt token estimate, replaced by the
//     upstream-reported prompt_tokens when that is positive.
//   - database: handle passed to logger.LogRequest.
//
// The stream is billed and logged even when relaying stops early (for example
// because the client went away), using whatever content and usage had been
// received up to that point.
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
	apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
	requestBody string, database *gorm.DB) {
	// Headers for a streamed response.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	// Copy the remaining upstream headers. Content-Length and
	// Transfer-Encoding are skipped because the body is re-framed here.
	for key, values := range resp.Header {
		if key != "Content-Length" && key != "Transfer-Encoding" {
			for _, value := range values {
				c.Header(key, value)
			}
		}
	}
	c.Status(resp.StatusCode)

	// Accumulates the reasoning and content text of every chunk.
	var fullContent strings.Builder

	// bufio.Scanner rejects lines longer than 64 KiB by default, which would
	// end the relay with ErrTooLong on a large SSE payload. Allow bigger
	// lines while still bounding memory use.
	const maxSSELineBytes = 4 << 20
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64<<10), maxSSELineBytes)

	flusher, ok := c.Writer.(http.Flusher)

	var promptTokens, completionTokens, totalTokens int
	for scanner.Scan() {
		chunk, done := processStreamingResponse(c, scanner, ok, flusher)

		// Only chunks that decoded with a non-empty ID contribute to the
		// token counters and the accumulated content.
		if chunk.ID != "" {
			// The latest usage object wins.
			if chunk.Usage != nil {
				promptTokens = chunk.Usage.PromptTokens
				completionTokens = chunk.Usage.CompletionTokens
				totalTokens = chunk.Usage.TotalTokens
			}
			for _, choice := range chunk.Choices {
				if choice.Delta.Reasoning != nil {
					fullContent.WriteString(*choice.Delta.Reasoning)
				}
				if choice.Delta.Content != nil {
					fullContent.WriteString(*choice.Delta.Content)
				}
			}
		}

		if done {
			// Relaying cannot continue, but the tokens streamed so far were
			// still produced upstream: fall through to billing and logging
			// instead of returning without a record.
			break
		}
	}

	// Make sure the tail of the stream reaches the client.
	if ok {
		flusher.Flush()
	}

	// Report a read failure on the upstream body, if any.
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading stream: %v", err)
	}

	// Response tokens: estimate from the accumulated text, but prefer the
	// upstream-reported completion_tokens when it is available.
	responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
	if completionTokens > 0 {
		responseTokenCount = completionTokens
	}
	// Likewise prefer the upstream-reported prompt_tokens.
	if promptTokens > 0 {
		requestTokenCount = promptTokens
	}
	log.Printf("Streamed Response Tokens - Prompt: %d, Completion: %d, Total: %d", promptTokens, completionTokens, totalTokens)

	// Compute the cost of the request.
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	// Build the log entry; only the accumulated response text is stored, not
	// the raw SSE framing.
	logEntry := createRequestLog(apiKeyID, virtualModelName, backendModel,
		requestTimestamp, time.Now(), requestTokenCount, responseTokenCount,
		cost, requestBody, fullContent.String())

	// Hand the entry to the request logger.
	logger.LogRequest(database, logEntry)
}
// processStreamingResponse forwards the scanner's current line to the client
// and, when that line is an SSE "data:" field carrying JSON, decodes it.
//
// Parameters:
//   - c: gin context whose writer receives the forwarded line.
//   - scanner: positioned on the line to process (Scan has already been called).
//   - ok, flusher: result of asserting the writer to http.Flusher; flusher is
//     only used when ok is true.
//
// Returns the decoded chunk and a stop flag. The chunk is the zero value for
// non-data lines, for the "[DONE]" marker and for payloads that fail to
// decode. The stop flag is true only when writing to the client failed, in
// which case the caller should stop relaying.
func processStreamingResponse(c *gin.Context, scanner *bufio.Scanner, ok bool, flusher http.Flusher) (StreamChunk, bool) {
	line := scanner.Text()
	var chunk StreamChunk

	// SSE data field: "data: {...}". The event-stream format makes the single
	// space after the colon optional, so both spellings are accepted.
	if strings.HasPrefix(line, "data:") {
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")

		// End-of-stream marker: forward it together with the blank line that
		// terminates the event and flush immediately.
		if data == "[DONE]" {
			if _, err := c.Writer.Write([]byte(line + "\n\n")); err != nil {
				return chunk, true
			}
			if ok {
				flusher.Flush()
			}
			return chunk, false
		}

		// Decode the payload to extract content and usage. A payload that
		// cannot be decoded must not cut the client's stream short: log it,
		// discard any partially filled fields so they are not accounted for,
		// and still forward the raw line below.
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			log.Printf("Error unmarshaling stream chunk: %v", err)
			chunk = StreamChunk{}
		}
	}

	// Forward the line to the client unchanged.
	if _, err := c.Writer.Write([]byte(line + "\n")); err != nil {
		return chunk, true
	}

	// An empty line separates SSE events; flush so the completed event is
	// delivered right away.
	if line == "" && ok {
		flusher.Flush()
	}
	return chunk, false
}