重构请求处理逻辑,优化请求体解析,提取模型和消息字段,增加日志记录功能

This commit is contained in:
2025-11-09 00:14:01 +08:00
parent e2780a08b3
commit 5bffcc6244

View File

@@ -127,14 +127,13 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
var req ChatCompletionRequest
// 读取并解析请求体
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if err := c.ShouldBindJSON(&req); err != nil {
log.Printf("Failed to bind JSON: %v", err)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
log.Printf("Failed to unmarshal JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
@@ -144,12 +143,31 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
isStream, _ := jsonBody["stream"].(bool)
// 提取 messages 并转换为计算 token 所需的格式
var messages []billing.ChatCompletionMessage
if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
for _, msgRaw := range messagesRaw {
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content, _ := msgMap["content"].(string)
messages = append(messages, billing.ChatCompletionMessage{
Role: role,
Content: content,
})
}
}
}
// 计算请求token数
messages := convertToTikTokenMessages(req.Messages)
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
@@ -160,8 +178,11 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
return
}
// 准备后端请求
requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name)
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
@@ -195,9 +216,9 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
apiKeyID := getAPIKeyID(c)
// 处理流式响应
if req.Stream {
if isStream {
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
apiKeyID, req.Model, backendModel, requestTokenCount,
apiKeyID, modelName, backendModel, requestTokenCount,
string(requestBody), h.DB)
return
}
@@ -225,7 +246,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, req.Model, backendModel,
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
@@ -241,13 +262,12 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
var req ResponsesRequest
// 读取并解析请求体
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if err := c.ShouldBindJSON(&req); err != nil {
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
@@ -257,12 +277,30 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
// 提取 input 并转换为计算 token 所需的格式
var messages []billing.ChatCompletionMessage
if inputRaw, ok := jsonBody["input"].([]interface{}); ok {
messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw))
for _, msgRaw := range inputRaw {
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content, _ := msgMap["content"].(string)
messages = append(messages, billing.ChatCompletionMessage{
Role: role,
Content: content,
})
}
}
}
// 计算请求token数
messages := convertToTikTokenMessages(req.Input)
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
@@ -273,8 +311,11 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
return
}
// 准备后端请求
requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name)
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
@@ -327,7 +368,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
apiKeyID := getAPIKeyID(c)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, req.Model, backendModel,
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)