重构请求处理逻辑,优化请求体解析,提取模型和消息字段,增加日志记录功能
This commit is contained in:
@@ -127,14 +127,13 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
requestTimestamp := time.Now()
|
||||
|
||||
var req ChatCompletionRequest
|
||||
|
||||
// 读取并解析请求体
|
||||
// 读取请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
log.Printf("Failed to bind JSON: %v", err)
|
||||
// 解析为 map,只解析一次
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
log.Printf("Failed to unmarshal JSON: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid request format",
|
||||
@@ -144,12 +143,31 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 从 map 中提取需要的字段
|
||||
modelName, _ := jsonBody["model"].(string)
|
||||
isStream, _ := jsonBody["stream"].(bool)
|
||||
|
||||
// 提取 messages 并转换为计算 token 所需的格式
|
||||
var messages []billing.ChatCompletionMessage
|
||||
if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
|
||||
messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
|
||||
for _, msgRaw := range messagesRaw {
|
||||
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, _ := msgMap["content"].(string)
|
||||
messages = append(messages, billing.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算请求token数
|
||||
messages := convertToTikTokenMessages(req.Messages)
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
|
||||
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -160,8 +178,11 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 准备后端请求
|
||||
requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name)
|
||||
// 修改 model 字段为后端模型名称
|
||||
jsonBody["model"] = backendModel.Name
|
||||
|
||||
// Marshal 请求体用于转发
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -195,9 +216,9 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 处理流式响应
|
||||
if req.Stream {
|
||||
if isStream {
|
||||
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
|
||||
apiKeyID, req.Model, backendModel, requestTokenCount,
|
||||
apiKeyID, modelName, backendModel, requestTokenCount,
|
||||
string(requestBody), h.DB)
|
||||
return
|
||||
}
|
||||
@@ -225,7 +246,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, req.Model, backendModel,
|
||||
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
@@ -241,13 +262,12 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
requestTimestamp := time.Now()
|
||||
|
||||
var req ResponsesRequest
|
||||
|
||||
// 读取并解析请求体
|
||||
// 读取请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// 解析为 map,只解析一次
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid request format",
|
||||
@@ -257,12 +277,30 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 从 map 中提取需要的字段
|
||||
modelName, _ := jsonBody["model"].(string)
|
||||
|
||||
// 提取 input 并转换为计算 token 所需的格式
|
||||
var messages []billing.ChatCompletionMessage
|
||||
if inputRaw, ok := jsonBody["input"].([]interface{}); ok {
|
||||
messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw))
|
||||
for _, msgRaw := range inputRaw {
|
||||
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, _ := msgMap["content"].(string)
|
||||
messages = append(messages, billing.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算请求token数
|
||||
messages := convertToTikTokenMessages(req.Input)
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
|
||||
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -273,8 +311,11 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 准备后端请求
|
||||
requestBody, err := prepareBackendRequest(bodyBytes, backendModel.Name)
|
||||
// 修改 model 字段为后端模型名称
|
||||
jsonBody["model"] = backendModel.Name
|
||||
|
||||
// Marshal 请求体用于转发
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -327,7 +368,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, req.Model, backendModel,
|
||||
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
|
||||
Reference in New Issue
Block a user