重构请求处理逻辑,优化请求体解析和后端请求转发,增加日志记录功能

This commit is contained in:
2025-11-09 00:10:42 +08:00
parent c7cea51059
commit e2780a08b3

View File

@@ -129,13 +129,12 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
var req ChatCompletionRequest
// 增加日志记录请求体
// 读取并解析请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 把body重新写回去
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 解析请求体
if err := c.ShouldBindJSON(&req); err != nil {
log.Printf("Failed to bind JSON: %v", err) // 增加错误日志
log.Printf("Failed to bind JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
@@ -145,7 +144,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
return
}
// 使用tiktoken精确计算请求token数
// 计算请求token数
messages := convertToTikTokenMessages(req.Messages)
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
@@ -161,23 +160,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
return
}
// 准备转发请求
// 用body创建一个json对象
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 然后修改jsonBody的model字段
jsonBody["model"] = backendModel.Name
// 最后重新marshal回requestBody
requestBody, err := json.Marshal(jsonBody)
// 准备后端请求
requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
@@ -191,32 +175,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
// 创建HTTP请求
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to create backend request",
"type": "internal_error",
},
})
return
}
// 设置请求头
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
// 复制原始请求的其他相关头部
if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
httpReq.Header.Set("User-Agent", userAgent)
}
// 执行请求
client := &http.Client{
Timeout: 120 * time.Second,
}
resp, err := client.Do(httpReq)
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
@@ -231,26 +191,18 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
// 记录响应时间
responseTimestamp := time.Now()
// 从上下文获取API密钥
apiKeyValue, exists := c.Get("apiKey")
var apiKeyID uint
if exists {
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
apiKeyID = apiKey.ID
}
}
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 检查是否是流式请求
// 处理流式响应
if req.Stream {
// 使用流式响应处理函数
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
apiKeyID, req.Model, backendModel, requestTokenCount,
string(requestBody), h.DB)
return
}
// 非流式响应处理
// 读取响应体
// 处理非流式响应
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
@@ -262,54 +214,24 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
return
}
// 增加日志记录后端响应
log.Printf("Backend Response Status: %s", resp.Status)
log.Printf("Backend Response Body: %s", string(responseBody))
// 计算响应token数
responseTokenCount := 0
var responseData map[string]interface{}
if err := json.Unmarshal(responseBody, &responseData); err == nil {
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := choice["message"].(map[string]interface{}); ok {
if content, ok := message["content"].(string); ok {
responseTokenCount = billing.CalculateTextTokensSimple(content)
}
}
}
}
}
responseTokenCount := extractResponseTokenCount(responseBody)
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建日志记录
logEntry := &models.RequestLog{
APIKeyID: apiKeyID,
VirtualModelName: req.Model,
BackendModelName: backendModel.Name,
RequestTimestamp: requestTimestamp,
ResponseTimestamp: responseTimestamp,
RequestTokens: requestTokenCount,
ResponseTokens: responseTokenCount,
Cost: cost,
RequestBody: string(requestBody),
ResponseBody: string(responseBody),
}
// 异步记录日志
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, req.Model, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 设置响应状态码并返回响应体
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
@@ -321,11 +243,10 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
var req ResponsesRequest
// 增加日志记录请求体
// 读取并解析请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 把body重新写回去
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 解析请求体
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
@@ -336,7 +257,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
return
}
// 使用tiktoken精确计算请求token数
// 计算请求token数
messages := convertToTikTokenMessages(req.Input)
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
@@ -352,10 +273,9 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
return
}
// 准备转发请求
// 用body创建一个json对象
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
// 准备后端请求
requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
@@ -364,40 +284,12 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
})
return
}
// 然后修改jsonBody的model字段
jsonBody["model"] = backendModel.Name
// 最后重新marshal回requestBody
requestBody, err := json.Marshal(jsonBody)
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/responses"
// 创建HTTP请求
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to create backend request",
"type": "internal_error",
},
})
return
}
// 设置请求头
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
// 复制原始请求的其他相关头部
if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
httpReq.Header.Set("User-Agent", userAgent)
}
// 执行请求
client := &http.Client{
Timeout: 120 * time.Second,
}
resp, err := client.Do(httpReq)
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
@@ -425,58 +317,23 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
}
// 计算响应token数
responseTokenCount := 0
var responseData map[string]interface{}
if err := json.Unmarshal(responseBody, &responseData); err == nil {
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := choice["message"].(map[string]interface{}); ok {
if content, ok := message["content"].(string); ok {
responseTokenCount = billing.CalculateTextTokensSimple(content)
}
}
}
}
}
responseTokenCount := extractResponseTokenCount(responseBody)
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 从上下文获取API密钥
apiKeyValue, exists := c.Get("apiKey")
var apiKeyID uint
if exists {
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
apiKeyID = apiKey.ID
}
}
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 创建日志记录
logEntry := &models.RequestLog{
APIKeyID: apiKeyID,
VirtualModelName: req.Model,
BackendModelName: backendModel.Name,
RequestTimestamp: requestTimestamp,
ResponseTimestamp: responseTimestamp,
RequestTokens: requestTokenCount,
ResponseTokens: responseTokenCount,
Cost: cost,
RequestBody: string(requestBody),
ResponseBody: string(responseBody),
}
// 异步记录日志
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, req.Model, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 设置响应状态码并返回响应体
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
@@ -493,6 +350,117 @@ func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatC
return result
}
// prepareBackendRequest 准备后端请求体,替换model字段
func prepareBackendRequest(bodyBytes []byte, backendModelName string) ([]byte, error) {
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
return nil, fmt.Errorf("failed to unmarshal request body: %w", err)
}
// 修改model字段为后端模型名称
jsonBody["model"] = backendModelName
// 重新marshal
requestBody, err := json.Marshal(jsonBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
return requestBody, nil
}
// forwardRequest 转发请求到后端API
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
// 创建HTTP请求
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create backend request: %w", err)
}
// 设置请求头
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
// 设置User-Agent
if userAgent != "" {
httpReq.Header.Set("User-Agent", userAgent)
}
// 执行请求
client := &http.Client{
Timeout: 10 * time.Minute,
}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to connect to backend service: %w", err)
}
return resp, nil
}
// extractResponseTokenCount 从响应体中提取token数量
func extractResponseTokenCount(responseBody []byte) int {
var responseData map[string]interface{}
if err := json.Unmarshal(responseBody, &responseData); err != nil {
return 0
}
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := choice["message"].(map[string]interface{}); ok {
if content, ok := message["content"].(string); ok {
return billing.CalculateTextTokensSimple(content)
}
}
}
}
return 0
}
// createRequestLog 创建请求日志记录
func createRequestLog(apiKeyID uint, virtualModelName string, backendModel *models.BackendModel,
requestTimestamp, responseTimestamp time.Time, requestTokenCount, responseTokenCount int,
cost float64, requestBody, responseBody string) *models.RequestLog {
return &models.RequestLog{
APIKeyID: apiKeyID,
VirtualModelName: virtualModelName,
BackendModelName: backendModel.Name,
RequestTimestamp: requestTimestamp,
ResponseTimestamp: responseTimestamp,
RequestTokens: requestTokenCount,
ResponseTokens: responseTokenCount,
Cost: cost,
RequestBody: requestBody,
ResponseBody: responseBody,
}
}
// getAPIKeyID 从gin上下文中提取API Key ID
func getAPIKeyID(c *gin.Context) uint {
apiKeyValue, exists := c.Get("apiKey")
if !exists {
return 0
}
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
return apiKey.ID
}
return 0
}
// copyResponseHeaders 复制响应头到gin context
func copyResponseHeaders(c *gin.Context, resp *http.Response) {
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
}
// handleStreamingResponse 处理流式响应
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,