重构请求处理逻辑,优化请求体解析和后端请求转发,增加日志记录功能
This commit is contained in:
@@ -129,13 +129,12 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
var req ChatCompletionRequest
|
||||
|
||||
// 增加日志记录请求体
|
||||
// 读取并解析请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 把body重新写回去
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 解析请求体
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
log.Printf("Failed to bind JSON: %v", err) // 增加错误日志
|
||||
log.Printf("Failed to bind JSON: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid request format",
|
||||
@@ -145,7 +144,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用tiktoken精确计算请求token数
|
||||
// 计算请求token数
|
||||
messages := convertToTikTokenMessages(req.Messages)
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
@@ -161,23 +160,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 准备转发请求
|
||||
// 用body创建一个json对象
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to process request",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
// 然后修改jsonBody的model字段
|
||||
jsonBody["model"] = backendModel.Name
|
||||
// 最后重新marshal回requestBody
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
|
||||
// 准备后端请求
|
||||
requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -191,32 +175,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
// 构建后端API URL
|
||||
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
|
||||
|
||||
// 创建HTTP请求
|
||||
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create backend request",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
|
||||
|
||||
// 复制原始请求的其他相关头部
|
||||
if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
|
||||
httpReq.Header.Set("User-Agent", userAgent)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
resp, err := client.Do(httpReq)
|
||||
// 转发请求到后端
|
||||
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -231,26 +191,18 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
// 记录响应时间
|
||||
responseTimestamp := time.Now()
|
||||
|
||||
// 从上下文获取API密钥
|
||||
apiKeyValue, exists := c.Get("apiKey")
|
||||
var apiKeyID uint
|
||||
if exists {
|
||||
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
||||
apiKeyID = apiKey.ID
|
||||
}
|
||||
}
|
||||
// 获取API Key ID
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 检查是否是流式请求
|
||||
// 处理流式响应
|
||||
if req.Stream {
|
||||
// 使用流式响应处理函数
|
||||
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
|
||||
apiKeyID, req.Model, backendModel, requestTokenCount,
|
||||
string(requestBody), h.DB)
|
||||
return
|
||||
}
|
||||
|
||||
// 非流式响应处理
|
||||
// 读取响应体
|
||||
// 处理非流式响应
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
@@ -262,54 +214,24 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 增加日志记录后端响应
|
||||
log.Printf("Backend Response Status: %s", resp.Status)
|
||||
log.Printf("Backend Response Body: %s", string(responseBody))
|
||||
|
||||
// 计算响应token数
|
||||
responseTokenCount := 0
|
||||
var responseData map[string]interface{}
|
||||
if err := json.Unmarshal(responseBody, &responseData); err == nil {
|
||||
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if message, ok := choice["message"].(map[string]interface{}); ok {
|
||||
if content, ok := message["content"].(string); ok {
|
||||
responseTokenCount = billing.CalculateTextTokensSimple(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
responseTokenCount := extractResponseTokenCount(responseBody)
|
||||
|
||||
// 计算费用
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 创建日志记录
|
||||
logEntry := &models.RequestLog{
|
||||
APIKeyID: apiKeyID,
|
||||
VirtualModelName: req.Model,
|
||||
BackendModelName: backendModel.Name,
|
||||
RequestTimestamp: requestTimestamp,
|
||||
ResponseTimestamp: responseTimestamp,
|
||||
RequestTokens: requestTokenCount,
|
||||
ResponseTokens: responseTokenCount,
|
||||
Cost: cost,
|
||||
RequestBody: string(requestBody),
|
||||
ResponseBody: string(responseBody),
|
||||
}
|
||||
|
||||
// 异步记录日志
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, req.Model, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
|
||||
// 复制响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置响应状态码并返回响应体
|
||||
// 复制响应头并返回响应
|
||||
copyResponseHeaders(c, resp)
|
||||
c.Status(resp.StatusCode)
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
@@ -321,11 +243,10 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
|
||||
var req ResponsesRequest
|
||||
|
||||
// 增加日志记录请求体
|
||||
// 读取并解析请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // 把body重新写回去
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 解析请求体
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -336,7 +257,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用tiktoken精确计算请求token数
|
||||
// 计算请求token数
|
||||
messages := convertToTikTokenMessages(req.Input)
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
@@ -352,10 +273,9 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 准备转发请求
|
||||
// 用body创建一个json对象
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
// 准备后端请求
|
||||
requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to process request",
|
||||
@@ -364,40 +284,12 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// 然后修改jsonBody的model字段
|
||||
jsonBody["model"] = backendModel.Name
|
||||
// 最后重新marshal回requestBody
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
|
||||
// 构建后端API URL
|
||||
backendURL := backendModel.Provider.BaseURL + "/v1/responses"
|
||||
|
||||
// 创建HTTP请求
|
||||
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create backend request",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
|
||||
|
||||
// 复制原始请求的其他相关头部
|
||||
if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
|
||||
httpReq.Header.Set("User-Agent", userAgent)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
resp, err := client.Do(httpReq)
|
||||
// 转发请求到后端
|
||||
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -425,58 +317,23 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算响应token数
|
||||
responseTokenCount := 0
|
||||
var responseData map[string]interface{}
|
||||
if err := json.Unmarshal(responseBody, &responseData); err == nil {
|
||||
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if message, ok := choice["message"].(map[string]interface{}); ok {
|
||||
if content, ok := message["content"].(string); ok {
|
||||
responseTokenCount = billing.CalculateTextTokensSimple(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
responseTokenCount := extractResponseTokenCount(responseBody)
|
||||
|
||||
// 计算费用
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 从上下文获取API密钥
|
||||
apiKeyValue, exists := c.Get("apiKey")
|
||||
var apiKeyID uint
|
||||
if exists {
|
||||
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
||||
apiKeyID = apiKey.ID
|
||||
}
|
||||
}
|
||||
// 获取API Key ID
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 创建日志记录
|
||||
logEntry := &models.RequestLog{
|
||||
APIKeyID: apiKeyID,
|
||||
VirtualModelName: req.Model,
|
||||
BackendModelName: backendModel.Name,
|
||||
RequestTimestamp: requestTimestamp,
|
||||
ResponseTimestamp: responseTimestamp,
|
||||
RequestTokens: requestTokenCount,
|
||||
ResponseTokens: responseTokenCount,
|
||||
Cost: cost,
|
||||
RequestBody: string(requestBody),
|
||||
ResponseBody: string(responseBody),
|
||||
}
|
||||
|
||||
// 异步记录日志
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, req.Model, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
|
||||
// 复制响应体
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置响应状态码并返回响应体
|
||||
// 复制响应头并返回响应
|
||||
copyResponseHeaders(c, resp)
|
||||
c.Status(resp.StatusCode)
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
@@ -493,6 +350,117 @@ func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatC
|
||||
return result
|
||||
}
|
||||
|
||||
// prepareBackendRequest 准备后端请求体,替换model字段
|
||||
func prepareBackendRequest(bodyBytes []byte, backendModelName string) ([]byte, error) {
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal request body: %w", err)
|
||||
}
|
||||
|
||||
// 修改model字段为后端模型名称
|
||||
jsonBody["model"] = backendModelName
|
||||
|
||||
// 重新marshal
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
// forwardRequest 转发请求到后端API
|
||||
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
|
||||
// 创建HTTP请求
|
||||
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create backend request: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
// 设置User-Agent
|
||||
if userAgent != "" {
|
||||
httpReq.Header.Set("User-Agent", userAgent)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to backend service: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// extractResponseTokenCount 从响应体中提取token数量
|
||||
func extractResponseTokenCount(responseBody []byte) int {
|
||||
var responseData map[string]interface{}
|
||||
if err := json.Unmarshal(responseBody, &responseData); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if message, ok := choice["message"].(map[string]interface{}); ok {
|
||||
if content, ok := message["content"].(string); ok {
|
||||
return billing.CalculateTextTokensSimple(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// createRequestLog 创建请求日志记录
|
||||
func createRequestLog(apiKeyID uint, virtualModelName string, backendModel *models.BackendModel,
|
||||
requestTimestamp, responseTimestamp time.Time, requestTokenCount, responseTokenCount int,
|
||||
cost float64, requestBody, responseBody string) *models.RequestLog {
|
||||
|
||||
return &models.RequestLog{
|
||||
APIKeyID: apiKeyID,
|
||||
VirtualModelName: virtualModelName,
|
||||
BackendModelName: backendModel.Name,
|
||||
RequestTimestamp: requestTimestamp,
|
||||
ResponseTimestamp: responseTimestamp,
|
||||
RequestTokens: requestTokenCount,
|
||||
ResponseTokens: responseTokenCount,
|
||||
Cost: cost,
|
||||
RequestBody: requestBody,
|
||||
ResponseBody: responseBody,
|
||||
}
|
||||
}
|
||||
|
||||
// getAPIKeyID 从gin上下文中提取API Key ID
|
||||
func getAPIKeyID(c *gin.Context) uint {
|
||||
apiKeyValue, exists := c.Get("apiKey")
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
|
||||
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
||||
return apiKey.ID
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// copyResponseHeaders 复制响应头到gin context
|
||||
func copyResponseHeaders(c *gin.Context, resp *http.Response) {
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingResponse 处理流式响应
|
||||
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
|
||||
apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
|
||||
|
||||
Reference in New Issue
Block a user