重构流式响应处理逻辑,新增流式数据解析和日志记录功能,优化token计数和费用计算

This commit is contained in:
2025-11-09 01:49:08 +08:00
parent 46096160fc
commit 1e2cf83ff0
2 changed files with 205 additions and 120 deletions

View File

@@ -5,16 +5,13 @@ import (
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"bufio"
"encoding/json"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// ListModels 处理 GET /models 请求
@@ -307,120 +304,3 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// handleStreamingResponse proxies an upstream SSE (text/event-stream) response
// to the client line by line, accumulating the assistant delta content for
// token estimation and cost calculation, then records the request log
// asynchronously.
//
// Parameters:
//   - c: gin context whose writer receives the proxied stream
//   - resp: backend HTTP response carrying the SSE body
//   - requestTimestamp: when the client request arrived
//   - responseTimestamp: when the backend responded (currently unused; the
//     log records time.Now() at end of stream instead — TODO confirm intent)
//   - apiKeyID, virtualModelName, backendModel: identifiers for the log entry
//   - requestTokenCount: precomputed prompt token count
//   - requestBody: raw request payload stored in the log
//   - database: gorm handle passed to the async logger
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
	apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
	requestBody string, database *gorm.DB) {
	// Standard SSE headers so the client keeps the connection open.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")
	// Mirror remaining upstream headers, excluding framing-related ones.
	for key, values := range resp.Header {
		if key != "Content-Length" && key != "Transfer-Encoding" {
			for _, value := range values {
				c.Header(key, value)
			}
		}
	}
	c.Status(resp.StatusCode)
	// fullContent accumulates the extracted assistant text; responseBody keeps
	// the raw SSE lines (including framing) for the request log.
	var fullContent strings.Builder
	var responseBody strings.Builder
	// Read the upstream body line by line.
	scanner := bufio.NewScanner(resp.Body)
	flusher, ok := c.Writer.(http.Flusher)
	for scanner.Scan() {
		line := scanner.Text()
		// Record the raw line for logging.
		responseBody.WriteString(line)
		responseBody.WriteString("\n")
		// SSE data lines have the form: "data: {...}"
		if strings.HasPrefix(line, "data: ") {
			data := strings.TrimPrefix(line, "data: ")
			// "[DONE]" marks the end of the stream; echo it with the closing
			// blank line and stop reading.
			if data == "[DONE]" {
				_, err := c.Writer.Write([]byte(line + "\n\n"))
				if err != nil {
					return
				}
				if ok {
					flusher.Flush()
				}
				break
			}
			// Best-effort parse of the JSON payload to extract the delta
			// content; parse failures are ignored and the line still forwarded.
			var chunk map[string]interface{}
			if err := json.Unmarshal([]byte(data), &chunk); err == nil {
				if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
					if choice, ok := choices[0].(map[string]interface{}); ok {
						if delta, ok := choice["delta"].(map[string]interface{}); ok {
							if content, ok := delta["content"].(string); ok {
								fullContent.WriteString(content)
							}
						}
					}
				}
			}
		}
		// Forward the raw line to the client.
		_, err := c.Writer.Write([]byte(line + "\n"))
		if err != nil {
			return
		}
		// A blank line is the SSE message separator; flush the full message.
		if line == "" {
			if ok {
				flusher.Flush()
			}
		}
	}
	// Flush any remaining buffered data.
	if ok {
		flusher.Flush()
	}
	// Surface scanner errors (broken connection, oversized line, ...).
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading stream: %v", err)
	}
	// Estimate the response token count from the accumulated content.
	responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
	// Compute the billed cost for this request.
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
	// Build the request log entry.
	logEntry := &models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  virtualModelName,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		ResponseTimestamp: time.Now(), // actual end-of-stream time
		RequestTokens:     requestTokenCount,
		ResponseTokens:    responseTokenCount,
		Cost:              cost,
		RequestBody:       requestBody,
		ResponseBody:      responseBody.String(),
	}
	// Persist the log asynchronously.
	logger.LogRequest(database, logEntry)
}

View File

@@ -0,0 +1,205 @@
package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"bufio"
"encoding/json"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// Usage mirrors the OpenAI-style usage object carried on a stream chunk.
// Both the chat-completions field names (prompt/completion) and the
// responses-API names (input/output) are declared so either backend shape
// can be decoded.
type Usage struct {
	PromptTokens           int                          `json:"prompt_tokens"`
	CompletionTokens       int                          `json:"completion_tokens"`
	TotalTokens            int                          `json:"total_tokens"`
	PromptTokenDetails     *UsagePromptTokenDetails     `json:"prompt_token_details,omitempty"`
	CompletionTokenDetails *UsageCompletionTokenDetails `json:"completion_token_details,omitempty"`
	InputTokens            int                          `json:"input_tokens,omitempty"`
	OutputTokens           int                          `json:"output_tokens,omitempty"`
}
// UsagePromptTokenDetails breaks down the prompt token count by modality and
// cache status, as reported by the backend.
type UsagePromptTokenDetails struct {
	CachedTokens int `json:"cached_tokens"`
	TextTokens   int `json:"text_tokens"`
	AudioTokens  int `json:"audio_tokens"`
	ImageTokens  int `json:"image_tokens"`
}
// UsageCompletionTokenDetails breaks down the completion token count by
// modality, including reasoning tokens, as reported by the backend.
type UsageCompletionTokenDetails struct {
	TextTokens      int `json:"text_tokens"`
	AudioTokens     int `json:"audio_tokens"`
	ReasoningTokens int `json:"reasoning_tokens"`
}
// StreamDelta is the incremental payload of one streamed choice. Content and
// Reasoning are pointers so an absent field can be distinguished from an
// explicit empty string.
type StreamDelta struct {
	Role      string  `json:"role,omitempty"`
	Content   *string `json:"content,omitempty"`
	Reasoning *string `json:"reasoning,omitempty"`
}
// StreamChoice wraps the delta of a single choice in a stream chunk.
type StreamChoice struct {
	Delta StreamDelta `json:"delta"`
}
// StreamChunk is one decoded SSE "data:" payload of a chat-completions
// stream. Usage is nil except on chunks that carry usage statistics
// (typically the final one).
type StreamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
	Usage   *Usage         `json:"usage"`
}
// handleStreamingResponse proxies an upstream SSE (text/event-stream)
// response to the client while extracting per-chunk deltas and usage
// figures, then records tokens, cost, and the reconstructed response
// content asynchronously.
//
// Parameters:
//   - c: gin context whose writer receives the proxied stream
//   - resp: backend HTTP response carrying the SSE body
//   - requestTimestamp: when the client request arrived
//   - responseTimestamp: when the backend responded (currently unused; the
//     log records time.Now() at end of stream instead — TODO confirm intent)
//   - apiKeyID, virtualModelName, backendModel: identifiers for the log entry
//   - requestTokenCount: precomputed prompt token count (overridden by
//     backend-reported prompt_tokens when available)
//   - requestBody: raw request payload stored in the log
//   - database: gorm handle passed to the async logger
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
	apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
	requestBody string, database *gorm.DB) {
	// Standard SSE headers so the client keeps the connection open.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")
	// Mirror remaining upstream headers, excluding framing-related ones.
	for key, values := range resp.Header {
		if key != "Content-Length" && key != "Transfer-Encoding" {
			for _, value := range values {
				c.Header(key, value)
			}
		}
	}
	c.Status(resp.StatusCode)
	// Accumulates the extracted delta content (no SSE framing).
	var fullContent strings.Builder
	// Read the upstream body line by line.
	scanner := bufio.NewScanner(resp.Body)
	// bufio.Scanner defaults to a 64 KiB max token; a larger SSE data line
	// (e.g. a big tool-call or base64 payload) would abort the stream with
	// bufio.ErrTooLong. Allow lines up to 1 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		// Without a Flusher the stream still works but may be buffered.
		log.Printf("response writer does not support flushing; SSE output may be buffered")
	}
	var promptTokens int
	var completionTokens int
	var totalTokens int
	for scanner.Scan() {
		chunk, done := processStreamingResponse(c, scanner, ok, flusher)
		// Only chunks that decoded successfully carry a non-empty ID.
		if chunk.ID != "" {
			// Usage typically arrives on the final chunk; keep the latest
			// values seen.
			if chunk.Usage != nil {
				promptTokens = chunk.Usage.PromptTokens
				completionTokens = chunk.Usage.CompletionTokens
				totalTokens = chunk.Usage.TotalTokens
			}
			// Accumulate reasoning and content deltas for logging/billing.
			for _, choice := range chunk.Choices {
				if choice.Delta.Reasoning != nil {
					fullContent.WriteString(*choice.Delta.Reasoning)
				}
				if choice.Delta.Content != nil {
					fullContent.WriteString(*choice.Delta.Content)
				}
			}
		}
		if done {
			// Client write failed; nothing more we can deliver.
			return
		}
	}
	// Flush any remaining buffered data.
	if ok {
		flusher.Flush()
	}
	// Surface scanner errors (broken connection, oversized line, ...).
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading stream: %v", err)
	}
	// Prefer backend-reported usage over the local estimate.
	responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
	if completionTokens > 0 {
		// The backend reported completion_tokens; trust it over the estimate.
		responseTokenCount = completionTokens
	}
	if promptTokens > 0 {
		requestTokenCount = promptTokens
	}
	log.Printf("Streamed Response Tokens - Prompt: %d, Completion: %d, Total: %d", promptTokens, completionTokens, totalTokens)
	// Compute the billed cost for this request.
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
	// Build the request log entry; only the extracted content is stored.
	logEntry := &models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  virtualModelName,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		ResponseTimestamp: time.Now(), // actual end-of-stream time
		RequestTokens:     requestTokenCount,
		ResponseTokens:    responseTokenCount,
		Cost:              cost,
		RequestBody:       requestBody,
		ResponseBody:      fullContent.String(), // extracted content only, no SSE framing
	}
	// Persist the log asynchronously.
	logger.LogRequest(database, logEntry)
}
// processStreamingResponse handles the single line currently buffered in the
// scanner: it decodes an SSE "data:" payload into a StreamChunk and forwards
// the raw line to the client.
//
// Returns the decoded chunk (zero-valued when the line carried no parseable
// payload) and done=true only when writing to the client failed and the
// caller must stop streaming.
//
// Fix over the previous version: a JSON decode failure no longer aborts the
// whole stream — a malformed or non-chat data line (keep-alive, comment,
// vendor extension) is logged and still forwarded verbatim.
func processStreamingResponse(c *gin.Context, scanner *bufio.Scanner, ok bool, flusher http.Flusher) (StreamChunk, bool) {
	line := scanner.Text()
	chunk := StreamChunk{}
	// SSE data lines have the form: "data: {...}"
	if strings.HasPrefix(line, "data: ") {
		data := strings.TrimPrefix(line, "data: ")
		// "[DONE]" terminates the stream; echo it with the closing blank line.
		if data == "[DONE]" {
			_, err := c.Writer.Write([]byte(line + "\n\n"))
			if err != nil {
				return chunk, true
			}
			if ok {
				flusher.Flush()
			}
			return chunk, false
		}
		// Best-effort decode; on failure discard any partially-populated
		// fields and fall through so the raw line is still forwarded.
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			log.Printf("Error unmarshaling stream chunk: %v", err)
			chunk = StreamChunk{}
		}
	}
	// Forward the raw line (data, event, comment, or blank) to the client.
	_, err := c.Writer.Write([]byte(line + "\n"))
	if err != nil {
		return chunk, true
	}
	// A blank line terminates an SSE message; flush so the client sees it now.
	if line == "" {
		if ok {
			flusher.Flush()
		}
	}
	return chunk, false
}