添加流式响应处理功能,优化API密钥获取和日志记录
This commit is contained in:
@@ -5,12 +5,14 @@ import (
|
||||
"ai-gateway/internal/logger"
|
||||
"ai-gateway/internal/models"
|
||||
"ai-gateway/internal/router"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -228,6 +230,25 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
// 记录响应时间
|
||||
responseTimestamp := time.Now()
|
||||
|
||||
// 从上下文获取API密钥
|
||||
apiKeyValue, exists := c.Get("apiKey")
|
||||
var apiKeyID uint
|
||||
if exists {
|
||||
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
||||
apiKeyID = apiKey.ID
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否是流式请求
|
||||
if req.Stream {
|
||||
// 使用流式响应处理函数
|
||||
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
|
||||
apiKeyID, req.Model, backendModel, requestTokenCount,
|
||||
string(requestBody), h.DB)
|
||||
return
|
||||
}
|
||||
|
||||
// 非流式响应处理
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -268,15 +289,6 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
cost = backendModel.FixedPrice
|
||||
}
|
||||
|
||||
// 从上下文获取API密钥
|
||||
apiKeyValue, exists := c.Get("apiKey")
|
||||
var apiKeyID uint
|
||||
if exists {
|
||||
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
||||
apiKeyID = apiKey.ID
|
||||
}
|
||||
}
|
||||
|
||||
// 创建日志记录
|
||||
logEntry := &models.RequestLog{
|
||||
APIKeyID: apiKeyID,
|
||||
@@ -502,6 +514,128 @@ func calculateTokenCountFromText(text string) int {
|
||||
return len(encoding.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
// handleStreamingResponse 处理流式响应
|
||||
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
|
||||
apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
|
||||
requestBody string, database *gorm.DB) {
|
||||
|
||||
// 设置流式响应\u5934
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// 复制其他响应头
|
||||
for key, values := range resp.Header {
|
||||
if key != "Content-Length" && key != "Transfer-Encoding" {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
// 用于累积完整响应内容
|
||||
var fullContent strings.Builder
|
||||
var responseBody strings.Builder
|
||||
|
||||
// 创建一个 scanner 来逐行读取流
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// 将原始行写入响应体缓冲区
|
||||
responseBody.WriteString(line)
|
||||
responseBody.WriteString("\n")
|
||||
|
||||
// SSE 格式:data: {...}
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
// 检查是否是结束标记
|
||||
if data == "[DONE]" {
|
||||
_, err := c.Writer.Write([]byte(line + "\n\n"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 解析 JSON 数据以提取内容
|
||||
var chunk map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
|
||||
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if delta, ok := choice["delta"].(map[string]interface{}); ok {
|
||||
if content, ok := delta["content"].(string); ok {
|
||||
fullContent.WriteString(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 转发数据到客户端
|
||||
_, err := c.Writer.Write([]byte(line + "\n"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是空行(SSE 消息分隔符),刷新
|
||||
if line == "" {
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 确保发送最后的数据
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// 扫描可能出现的错误
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error reading stream: %v", err)
|
||||
}
|
||||
|
||||
// 计算响应 token 数
|
||||
responseTokenCount := calculateTokenCountFromText(fullContent.String())
|
||||
|
||||
// 计算费用
|
||||
var cost float64
|
||||
switch backendModel.BillingMethod {
|
||||
case models.BillingMethodToken:
|
||||
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
|
||||
case models.BillingMethodRequest:
|
||||
cost = backendModel.FixedPrice
|
||||
}
|
||||
|
||||
// 创建日志记录
|
||||
logEntry := &models.RequestLog{
|
||||
APIKeyID: apiKeyID,
|
||||
VirtualModelName: virtualModelName,
|
||||
BackendModelName: backendModel.Name,
|
||||
RequestTimestamp: requestTimestamp,
|
||||
ResponseTimestamp: time.Now(), // 使用实际结束时间
|
||||
RequestTokens: requestTokenCount,
|
||||
ResponseTokens: responseTokenCount,
|
||||
Cost: cost,
|
||||
RequestBody: requestBody,
|
||||
ResponseBody: responseBody.String(),
|
||||
}
|
||||
|
||||
// 异步记录日志
|
||||
logger.LogRequest(database, logEntry)
|
||||
}
|
||||
|
||||
// GetProvidersHandler 处理 GET /api/providers 请求
|
||||
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
|
||||
providers, err := db.GetProviders(h.DB)
|
||||
|
||||
Reference in New Issue
Block a user