307 lines
7.9 KiB
Go
307 lines
7.9 KiB
Go
package api
|
||
|
||
import (
	"encoding/json"
	"io"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

	"ai-gateway/internal/billing"
	"ai-gateway/internal/logger"
	"ai-gateway/internal/models"
	"ai-gateway/internal/router"
)
|
||
|
||
// ListModels handles GET /models requests.
//
// It loads every virtual model from the database and returns them in the
// OpenAI "list models" response shape. Each entry uses the virtual model's
// name as its ID and its creation time as a Unix timestamp. A database error
// yields HTTP 500 with an "internal_error" error envelope.
func (h *APIHandler) ListModels(c *gin.Context) {
	var records []models.VirtualModel
	if result := h.DB.Find(&records); result.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to retrieve models",
				"type":    "internal_error",
			},
		})
		return
	}

	// Always a non-nil slice, so an empty table still encodes as [] (not null).
	entries := make([]ModelData, 0, len(records))
	for i := range records {
		entries = append(entries, ModelData{
			ID:      records[i].Name,
			Object:  "model",
			Created: records[i].CreatedAt.Unix(),
			OwnedBy: "ai-gateway",
		})
	}

	c.JSON(http.StatusOK, ModelListResponse{
		Object: "list",
		Data:   entries,
	})
}
|
||
|
||
// ChatCompletions handles POST /v1/chat/completions requests.
//
// The handler works in these steps:
//   - Parse the client's JSON body and estimate the prompt token count from
//     "messages".
//   - Ask the router for a backend model and rewrite the "model" field to
//     that backend model's name.
//   - Forward the request to the backend provider.
//   - For "stream": true, hand off to handleStreamingResponse.
//   - Otherwise, read the backend response in full, compute its cost, record
//     it via logger.LogRequest, and relay it to the client with the backend's
//     status code and headers.
//
// Errors are reported to the client as an OpenAI-style
// {"error": {"message", "type"}} envelope.
func (h *APIHandler) ChatCompletions(c *gin.Context) {
	// Request arrival time, recorded in the request log.
	requestTimestamp := time.Now()

	// fail writes an OpenAI-style error envelope with the given status.
	fail := func(status int, message, errType string) {
		c.JSON(status, gin.H{
			"error": gin.H{
				"message": message,
				"type":    errType,
			},
		})
	}

	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		log.Printf("Failed to read request body: %v", err)
		fail(http.StatusBadRequest, "Failed to read request body", "invalid_request_error")
		return
	}

	// Decode once into a generic map so that fields this gateway does not
	// interpret are forwarded to the backend unchanged.
	var jsonBody map[string]interface{}
	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
		log.Printf("Failed to unmarshal JSON: %v", err)
		fail(http.StatusBadRequest, "Invalid request format", "invalid_request_error")
		return
	}

	modelName, _ := jsonBody["model"].(string)
	isStream, _ := jsonBody["stream"].(bool)

	// Reading from a nil map (a literal `null` body) is safe and yields "", so
	// this check also guarantees jsonBody is non-nil before it is written below.
	if modelName == "" {
		fail(http.StatusBadRequest, "Missing required parameter: 'model'", "invalid_request_error")
		return
	}

	// Estimate prompt tokens; the router may use this to pick a backend.
	messages := chatBillingMessages(jsonBody["messages"])
	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)

	backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
	if err != nil {
		fail(http.StatusNotFound, err.Error(), "invalid_request_error")
		return
	}

	// The backend must receive its own model name, not the virtual one.
	jsonBody["model"] = backendModel.Name

	requestBody, err := json.Marshal(jsonBody)
	if err != nil {
		log.Printf("Failed to marshal request body: %v", err)
		fail(http.StatusInternalServerError, "Failed to process request", "internal_error")
		return
	}

	// Trim trailing slashes so a BaseURL such as "https://host/" does not
	// produce "https://host//v1/chat/completions".
	backendURL := strings.TrimRight(backendModel.Provider.BaseURL, "/") + "/v1/chat/completions"

	resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
	if err != nil {
		log.Printf("Failed to forward request to %s: %v", backendURL, err)
		fail(http.StatusBadGateway, "Failed to connect to backend service", "service_unavailable")
		return
	}
	defer resp.Body.Close()

	// Time at which forwardRequest returned; recorded in the request log.
	responseTimestamp := time.Now()

	apiKeyID := getAPIKeyID(c)

	if isStream {
		handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
			apiKeyID, modelName, backendModel, requestTokenCount,
			string(requestBody), h.DB)
		return
	}

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Printf("Failed to read backend response: %v", err)
		fail(http.StatusInternalServerError, "Failed to read backend response", "internal_error")
		return
	}

	// Only the status is logged here: the full body (which may contain user
	// content) is already persisted through createRequestLog below.
	log.Printf("Backend Response Status: %s", resp.Status)

	responseTokenCount := extractResponseTokenCount(responseBody)

	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	logEntry := createRequestLog(apiKeyID, modelName, backendModel,
		requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
		cost, string(requestBody), string(responseBody))
	logger.LogRequest(h.DB, logEntry)

	// Relay the backend response verbatim: headers, status code, then body.
	copyResponseHeaders(c, resp)
	c.Status(resp.StatusCode)
	if _, err := c.Writer.Write(responseBody); err != nil {
		log.Printf("Failed to write response to client: %v", err)
	}
}

// chatBillingMessages converts the raw "messages" value of a chat completion
// request into the role/content pairs used for token estimation. Entries that
// are not JSON objects are skipped. A value that is not a JSON array yields nil.
func chatBillingMessages(raw interface{}) []billing.ChatCompletionMessage {
	items, ok := raw.([]interface{})
	if !ok {
		return nil
	}
	messages := make([]billing.ChatCompletionMessage, 0, len(items))
	for _, item := range items {
		msg, ok := item.(map[string]interface{})
		if !ok {
			continue
		}
		role, _ := msg["role"].(string)
		messages = append(messages, billing.ChatCompletionMessage{
			Role:    role,
			Content: chatContentText(msg["content"]),
		})
	}
	return messages
}

// chatContentText returns the text carried by a message "content" value.
// The value may be a plain string or an array of content parts such as
// {"type": "text", "text": "..."}. The "text" fields of the parts are joined
// with newlines. Parts without a string "text" field (e.g. images) contribute
// nothing.
func chatContentText(content interface{}) string {
	switch v := content.(type) {
	case string:
		return v
	case []interface{}:
		texts := make([]string, 0, len(v))
		for _, part := range v {
			p, ok := part.(map[string]interface{})
			if !ok {
				continue
			}
			if text, ok := p["text"].(string); ok && text != "" {
				texts = append(texts, text)
			}
		}
		return strings.Join(texts, "\n")
	default:
		return ""
	}
}
|
||
|
||
// ResponsesCompletions handles POST /v1/responses requests (the Responses API).
//
// The handler works in these steps:
//   - Parse the client's JSON body and estimate the prompt token count from
//     "instructions" and "input".
//   - Ask the router for a backend model and rewrite the "model" field to
//     that backend model's name.
//   - Forward the request to the backend provider and read the response in
//     full.
//   - Compute the cost, record the exchange via logger.LogRequest, and relay
//     the response to the client with the backend's status code and headers.
//
// Errors are reported to the client as an OpenAI-style
// {"error": {"message", "type"}} envelope.
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
	// Request arrival time, recorded in the request log.
	requestTimestamp := time.Now()

	// fail writes an OpenAI-style error envelope with the given status.
	fail := func(status int, message, errType string) {
		c.JSON(status, gin.H{
			"error": gin.H{
				"message": message,
				"type":    errType,
			},
		})
	}

	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		log.Printf("Failed to read request body: %v", err)
		fail(http.StatusBadRequest, "Failed to read request body", "invalid_request_error")
		return
	}

	// Decode once into a generic map so that fields this gateway does not
	// interpret are forwarded to the backend unchanged.
	var jsonBody map[string]interface{}
	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
		log.Printf("Failed to unmarshal JSON: %v", err)
		fail(http.StatusBadRequest, "Invalid request format", "invalid_request_error")
		return
	}

	modelName, _ := jsonBody["model"].(string)

	// Reading from a nil map (a literal `null` body) is safe and yields "", so
	// this check also guarantees jsonBody is non-nil before it is written below.
	if modelName == "" {
		fail(http.StatusBadRequest, "Missing required parameter: 'model'", "invalid_request_error")
		return
	}

	// Estimate prompt tokens; the router may use this to pick a backend.
	messages := responsesBillingMessages(jsonBody)
	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)

	backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
	if err != nil {
		fail(http.StatusNotFound, err.Error(), "invalid_request_error")
		return
	}

	// The backend must receive its own model name, not the virtual one.
	jsonBody["model"] = backendModel.Name

	requestBody, err := json.Marshal(jsonBody)
	if err != nil {
		log.Printf("Failed to marshal request body: %v", err)
		fail(http.StatusInternalServerError, "Failed to process request", "internal_error")
		return
	}

	// Trim trailing slashes so a BaseURL such as "https://host/" does not
	// produce "https://host//v1/responses".
	backendURL := strings.TrimRight(backendModel.Provider.BaseURL, "/") + "/v1/responses"

	resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
	if err != nil {
		log.Printf("Failed to forward request to %s: %v", backendURL, err)
		fail(http.StatusBadGateway, "Failed to connect to backend service", "service_unavailable")
		return
	}
	defer resp.Body.Close()

	// Time at which forwardRequest returned; recorded in the request log.
	responseTimestamp := time.Now()

	// NOTE(review): "stream": true is not special-cased here. The whole backend
	// response is buffered before it is relayed, so clients receive no
	// incremental events. The token count extracted from an SSE body may also
	// be wrong. TODO confirm and route streaming requests through a streaming
	// handler.
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Printf("Failed to read backend response: %v", err)
		fail(http.StatusInternalServerError, "Failed to read backend response", "internal_error")
		return
	}

	// NOTE(review): Responses API usage is reported as input_tokens/output_tokens.
	// Confirm that extractResponseTokenCount understands this body shape.
	responseTokenCount := extractResponseTokenCount(responseBody)

	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	apiKeyID := getAPIKeyID(c)

	logEntry := createRequestLog(apiKeyID, modelName, backendModel,
		requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
		cost, string(requestBody), string(responseBody))
	logger.LogRequest(h.DB, logEntry)

	// Relay the backend response verbatim: headers, status code, then body.
	copyResponseHeaders(c, resp)
	c.Status(resp.StatusCode)
	if _, err := c.Writer.Write(responseBody); err != nil {
		log.Printf("Failed to write response to client: %v", err)
	}
}

// responsesBillingMessages converts the prompt-bearing fields of a Responses
// API request body into role/content pairs for token estimation:
//   - a non-empty string "instructions" becomes a "system" message;
//   - a string "input" becomes a single "user" message;
//   - an array "input" contributes one message per JSON-object item, using the
//     item's "role" and the text of its "content".
//
// It returns nil when neither field carries anything usable.
func responsesBillingMessages(body map[string]interface{}) []billing.ChatCompletionMessage {
	var messages []billing.ChatCompletionMessage
	if instructions, ok := body["instructions"].(string); ok && instructions != "" {
		messages = append(messages, billing.ChatCompletionMessage{
			Role:    "system",
			Content: instructions,
		})
	}
	switch input := body["input"].(type) {
	case string:
		messages = append(messages, billing.ChatCompletionMessage{
			Role:    "user",
			Content: input,
		})
	case []interface{}:
		for _, item := range input {
			entry, ok := item.(map[string]interface{})
			if !ok {
				continue
			}
			role, _ := entry["role"].(string)
			messages = append(messages, billing.ChatCompletionMessage{
				Role:    role,
				Content: responsesContentText(entry["content"]),
			})
		}
	}
	return messages
}

// responsesContentText returns the text carried by an input item's "content".
// The value may be a plain string or an array of parts such as
// {"type": "input_text", "text": "..."}. The "text" fields of the parts are
// joined with newlines. Parts without a string "text" field contribute nothing.
func responsesContentText(content interface{}) string {
	switch v := content.(type) {
	case string:
		return v
	case []interface{}:
		texts := make([]string, 0, len(v))
		for _, part := range v {
			p, ok := part.(map[string]interface{})
			if !ok {
				continue
			}
			if text, ok := p["text"].(string); ok && text != "" {
				texts = append(texts, text)
			}
		}
		return strings.Join(texts, "\n")
	default:
		return ""
	}
}
|