197 lines
5.3 KiB
Go
197 lines
5.3 KiB
Go
package api
|
||
|
||
import (
|
||
"ai-gateway/internal/billing"
|
||
"ai-gateway/internal/logger"
|
||
"ai-gateway/internal/models"
|
||
"bufio"
|
||
"encoding/json"
|
||
"log"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type Usage struct {
|
||
PromptTokens int `json:"prompt_tokens"`
|
||
CompletionTokens int `json:"completion_tokens"`
|
||
TotalTokens int `json:"total_tokens"`
|
||
PromptTokenDetails *UsagePromptTokenDetails `json:"prompt_token_details,omitempty"`
|
||
CompletionTokenDetails *UsageCompletionTokenDetails `json:"completion_token_details,omitempty"`
|
||
InputTokens int `json:"input_tokens,omitempty"`
|
||
OutputTokens int `json:"output_tokens,omitempty"`
|
||
}
|
||
|
||
// UsagePromptTokenDetails is the optional per-category breakdown attached to
// Usage.PromptTokenDetails. Every counter is always serialized, even when zero.
type UsagePromptTokenDetails struct {
	CachedTokens int `json:"cached_tokens"` // serialized as "cached_tokens"
	TextTokens   int `json:"text_tokens"`   // serialized as "text_tokens"
	AudioTokens  int `json:"audio_tokens"`  // serialized as "audio_tokens"
	ImageTokens  int `json:"image_tokens"`  // serialized as "image_tokens"
}
|
||
|
||
// UsageCompletionTokenDetails is the optional per-category breakdown attached
// to Usage.CompletionTokenDetails. Every counter is always serialized, even
// when zero.
type UsageCompletionTokenDetails struct {
	TextTokens      int `json:"text_tokens"`      // serialized as "text_tokens"
	AudioTokens     int `json:"audio_tokens"`     // serialized as "audio_tokens"
	ReasoningTokens int `json:"reasoning_tokens"` // serialized as "reasoning_tokens"
}
|
||
|
||
// StreamDelta is the incremental payload of one streamed choice. Content and
// Reasoning are pointers so that an absent field (nil) can be told apart from
// an empty string; all three fields are omitted from JSON output when unset.
type StreamDelta struct {
	Role      string  `json:"role,omitempty"`
	Content   *string `json:"content,omitempty"`
	Reasoning *string `json:"reasoning,omitempty"`
}
|
||
|
||
type StreamChoice struct {
|
||
Delta StreamDelta `json:"delta"`
|
||
}
|
||
|
||
type StreamChunk struct {
|
||
ID string `json:"id"`
|
||
Object string `json:"object"`
|
||
Created int64 `json:"created"`
|
||
Model string `json:"model"`
|
||
Choices []StreamChoice `json:"choices"`
|
||
Usage *Usage `json:"usage"`
|
||
}
|
||
|
||
// handleStreamingResponse 处理流式响应
|
||
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
|
||
apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
|
||
requestBody string, database *gorm.DB) {
|
||
|
||
// 设置流式响应头
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("Transfer-Encoding", "chunked")
|
||
|
||
// 复制其他响应头
|
||
for key, values := range resp.Header {
|
||
if key != "Content-Length" && key != "Transfer-Encoding" {
|
||
for _, value := range values {
|
||
c.Header(key, value)
|
||
}
|
||
}
|
||
}
|
||
|
||
c.Status(resp.StatusCode)
|
||
|
||
// 用于累积完整响应内容
|
||
var fullContent strings.Builder
|
||
|
||
// 创建一个 scanner 来逐行读取流
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
|
||
var promptTokens int
|
||
var completionTokens int
|
||
var totalTokens int
|
||
|
||
for scanner.Scan() {
|
||
chunk, done := processStreamingResponse(c, scanner, ok, flusher)
|
||
|
||
// 仅在 chunk.ID 不为空时处理 token 计数和内容累积
|
||
if chunk.ID != "" {
|
||
// 更新 token 计数
|
||
if chunk.Usage != nil {
|
||
promptTokens = chunk.Usage.PromptTokens
|
||
completionTokens = chunk.Usage.CompletionTokens
|
||
totalTokens = chunk.Usage.TotalTokens
|
||
}
|
||
|
||
// 记录完整响应内容
|
||
for _, choice := range chunk.Choices {
|
||
if choice.Delta.Reasoning != nil {
|
||
fullContent.WriteString(*choice.Delta.Reasoning)
|
||
}
|
||
if choice.Delta.Content != nil {
|
||
fullContent.WriteString(*choice.Delta.Content)
|
||
}
|
||
}
|
||
}
|
||
|
||
if done {
|
||
return
|
||
}
|
||
}
|
||
|
||
// 确保发送最后的数据
|
||
if ok {
|
||
flusher.Flush()
|
||
}
|
||
|
||
// 扫描可能出现的错误
|
||
if err := scanner.Err(); err != nil {
|
||
log.Printf("Error reading stream: %v", err)
|
||
}
|
||
|
||
// 计算响应 token 数
|
||
responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
|
||
if completionTokens > 0 {
|
||
// 如果后端返回了 total_tokens,则使用它
|
||
responseTokenCount = completionTokens
|
||
}
|
||
if promptTokens > 0 {
|
||
requestTokenCount = promptTokens
|
||
}
|
||
log.Printf("Streamed Response Tokens - Prompt: %d, Completion: %d, Total: %d", promptTokens, completionTokens, totalTokens)
|
||
|
||
// 计算费用
|
||
costCalculator := billing.NewCostCalculator()
|
||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||
|
||
// 创建日志记录,只记录实际的响应内容
|
||
logEntry := createRequestLog(apiKeyID, virtualModelName, backendModel,
|
||
requestTimestamp, time.Now(), requestTokenCount, responseTokenCount,
|
||
cost, requestBody, fullContent.String())
|
||
|
||
// 异步记录日志
|
||
logger.LogRequest(database, logEntry)
|
||
}
|
||
|
||
func processStreamingResponse(c *gin.Context, scanner *bufio.Scanner, ok bool, flusher http.Flusher) (StreamChunk, bool) {
|
||
line := scanner.Text()
|
||
chunk := StreamChunk{}
|
||
|
||
// SSE 格式:data: {...}
|
||
if strings.HasPrefix(line, "data: ") {
|
||
data := strings.TrimPrefix(line, "data: ")
|
||
|
||
// 检查是否是结束标记
|
||
if data == "[DONE]" {
|
||
_, err := c.Writer.Write([]byte(line + "\n\n"))
|
||
if err != nil {
|
||
return chunk, true
|
||
}
|
||
if ok {
|
||
flusher.Flush()
|
||
}
|
||
return chunk, false
|
||
}
|
||
|
||
// 解析 JSON 数据以提取内容
|
||
err := json.Unmarshal([]byte(data), &chunk)
|
||
if err != nil {
|
||
log.Printf("Error unmarshaling stream chunk: %v", err)
|
||
return chunk, true
|
||
}
|
||
}
|
||
|
||
// 转发数据到客户端
|
||
_, err := c.Writer.Write([]byte(line + "\n"))
|
||
if err != nil {
|
||
return chunk, true
|
||
}
|
||
|
||
// 如果是空行(SSE 消息分隔符),刷新
|
||
if line == "" {
|
||
if ok {
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
return chunk, false
|
||
}
|