重构API处理逻辑,新增provider和virtual model处理功能,优化请求体解析和日志记录

This commit is contained in:
2025-11-09 00:22:23 +08:00
parent 5bffcc6244
commit 38889be072
4 changed files with 623 additions and 623 deletions

View File

@@ -2,18 +2,11 @@ package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/db"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -88,328 +81,6 @@ type UpdateVirtualModelRequest struct {
BackendModels []BackendModelAssociation `json:"backend_models"`
}
// ListModels 处理 GET /models 请求
func (h *APIHandler) ListModels(c *gin.Context) {
var virtualModels []models.VirtualModel
// 查询所有虚拟模型
if err := h.DB.Find(&virtualModels).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to retrieve models",
"type": "internal_error",
},
})
return
}
// 格式化为OpenAI API响应格式
modelData := make([]ModelData, len(virtualModels))
for i, vm := range virtualModels {
modelData[i] = ModelData{
ID: vm.Name,
Object: "model",
Created: vm.CreatedAt.Unix(),
OwnedBy: "ai-gateway",
}
}
response := ModelListResponse{
Object: "list",
Data: modelData,
}
c.JSON(http.StatusOK, response)
}
// ChatCompletions 处理 POST /v1/chat/completions 请求
func (h *APIHandler) ChatCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
log.Printf("Failed to unmarshal JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
isStream, _ := jsonBody["stream"].(bool)
// 提取 messages 并转换为计算 token 所需的格式
var messages []billing.ChatCompletionMessage
if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
for _, msgRaw := range messagesRaw {
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content, _ := msgMap["content"].(string)
messages = append(messages, billing.ChatCompletionMessage{
Role: role,
Content: content,
})
}
}
}
// 计算请求token数
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 处理流式响应
if isStream {
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
apiKeyID, modelName, backendModel, requestTokenCount,
string(requestBody), h.DB)
return
}
// 处理非流式响应
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
log.Printf("Backend Response Status: %s", resp.Status)
log.Printf("Backend Response Body: %s", string(responseBody))
// 计算响应token数
responseTokenCount := extractResponseTokenCount(responseBody)
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// ResponsesCompletions 处理 POST /v1/responses 请求
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
// 提取 input 并转换为计算 token 所需的格式
var messages []billing.ChatCompletionMessage
if inputRaw, ok := jsonBody["input"].([]interface{}); ok {
messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw))
for _, msgRaw := range inputRaw {
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content, _ := msgMap["content"].(string)
messages = append(messages, billing.ChatCompletionMessage{
Role: role,
Content: content,
})
}
}
}
// 计算请求token数
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/responses"
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
// 计算响应token数
responseTokenCount := extractResponseTokenCount(responseBody)
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// convertToTikTokenMessages 将ChatCompletionMessage转换为billing包的消息格式
func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatCompletionMessage {
result := make([]billing.ChatCompletionMessage, len(messages))
for i, msg := range messages {
result[i] = billing.ChatCompletionMessage{
Role: msg.Role,
Content: msg.Content,
}
}
return result
}
// prepareBackendRequest 准备后端请求体,替换model字段
func prepareBackendRequest(bodyBytes []byte, backendModelName string) ([]byte, error) {
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
return nil, fmt.Errorf("failed to unmarshal request body: %w", err)
}
// 修改model字段为后端模型名称
jsonBody["model"] = backendModelName
// 重新marshal
requestBody, err := json.Marshal(jsonBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
return requestBody, nil
}
// forwardRequest 转发请求到后端API
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
// 创建HTTP请求
@@ -501,297 +172,3 @@ func copyResponseHeaders(c *gin.Context, resp *http.Response) {
}
}
}
// handleStreamingResponse 处理流式响应
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
requestBody string, database *gorm.DB) {
// 设置流式响应\u5934
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Transfer-Encoding", "chunked")
// 复制其他响应头
for key, values := range resp.Header {
if key != "Content-Length" && key != "Transfer-Encoding" {
for _, value := range values {
c.Header(key, value)
}
}
}
c.Status(resp.StatusCode)
// 用于累积完整响应内容
var fullContent strings.Builder
var responseBody strings.Builder
// 创建一个 scanner 来逐行读取流
scanner := bufio.NewScanner(resp.Body)
flusher, ok := c.Writer.(http.Flusher)
for scanner.Scan() {
line := scanner.Text()
// 将原始行写入响应体缓冲区
responseBody.WriteString(line)
responseBody.WriteString("\n")
// SSE 格式data: {...}
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
// 检查是否是结束标记
if data == "[DONE]" {
_, err := c.Writer.Write([]byte(line + "\n\n"))
if err != nil {
return
}
if ok {
flusher.Flush()
}
break
}
// 解析 JSON 数据以提取内容
var chunk map[string]interface{}
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if delta, ok := choice["delta"].(map[string]interface{}); ok {
if content, ok := delta["content"].(string); ok {
fullContent.WriteString(content)
}
}
}
}
}
}
// 转发数据到客户端
_, err := c.Writer.Write([]byte(line + "\n"))
if err != nil {
return
}
// 如果是空行SSE 消息分隔符),刷新
if line == "" {
if ok {
flusher.Flush()
}
}
}
// 确保发送最后的数据
if ok {
flusher.Flush()
}
// 扫描可能出现的错误
if err := scanner.Err(); err != nil {
log.Printf("Error reading stream: %v", err)
}
// 计算响应 token 数
responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建日志记录
logEntry := &models.RequestLog{
APIKeyID: apiKeyID,
VirtualModelName: virtualModelName,
BackendModelName: backendModel.Name,
RequestTimestamp: requestTimestamp,
ResponseTimestamp: time.Now(), // 使用实际结束时间
RequestTokens: requestTokenCount,
ResponseTokens: responseTokenCount,
Cost: cost,
RequestBody: requestBody,
ResponseBody: responseBody.String(),
}
// 异步记录日志
logger.LogRequest(database, logEntry)
}
// GetProvidersHandler 处理 GET /api/providers 请求
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
providers, err := db.GetProviders(h.DB)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to retrieve providers",
"type": "internal_error",
},
})
return
}
c.JSON(http.StatusOK, providers)
}
// GetProviderHandler 处理 GET /api/providers/:id 请求
func (h *APIHandler) GetProviderHandler(c *gin.Context) {
id := c.Param("id")
var providerID uint
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
return
}
provider, err := db.GetProviderByID(h.DB, providerID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
return
}
c.JSON(http.StatusOK, provider)
}
// CreateProviderHandler 处理 POST /api/providers 请求
func (h *APIHandler) CreateProviderHandler(c *gin.Context) {
var provider models.Provider
if err := c.ShouldBindJSON(&provider); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := db.CreateProvider(h.DB, &provider); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"})
return
}
c.JSON(http.StatusCreated, provider)
}
// UpdateProviderHandler 处理 PUT /api/providers/:id 请求
func (h *APIHandler) UpdateProviderHandler(c *gin.Context) {
id := c.Param("id")
var providerID uint
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
return
}
var provider models.Provider
if err := c.ShouldBindJSON(&provider); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
provider.ID = providerID
if err := db.UpdateProvider(h.DB, &provider); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
return
}
c.JSON(http.StatusOK, provider)
}
// DeleteProviderHandler 处理 DELETE /api/providers/:id 请求
func (h *APIHandler) DeleteProviderHandler(c *gin.Context) {
id := c.Param("id")
var providerID uint
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
return
}
if err := db.DeleteProvider(h.DB, providerID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"})
return
}
c.JSON(http.StatusNoContent, nil)
}
// CreateVirtualModelHandler 处理 POST /api/virtual-models 请求
func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) {
var vm models.VirtualModel
if err := c.ShouldBindJSON(&vm); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := db.CreateVirtualModel(h.DB, &vm); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create virtual model"})
return
}
c.JSON(http.StatusCreated, vm)
}
// GetVirtualModelsHandler 处理 GET /api/virtual-models 请求
func (h *APIHandler) GetVirtualModelsHandler(c *gin.Context) {
virtualModels, err := db.GetVirtualModels(h.DB)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve virtual models"})
return
}
c.JSON(http.StatusOK, virtualModels)
}
// GetVirtualModelHandler 处理 GET /api/virtual-models/:id 请求
func (h *APIHandler) GetVirtualModelHandler(c *gin.Context) {
id := c.Param("id")
var virtualModelID uint
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
return
}
virtualModel, err := db.GetVirtualModelByID(h.DB, virtualModelID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Virtual model not found"})
return
}
c.JSON(http.StatusOK, virtualModel)
}
// UpdateVirtualModelHandler 处理 PUT /api/virtual-models/:id 请求
func (h *APIHandler) UpdateVirtualModelHandler(c *gin.Context) {
id := c.Param("id")
var virtualModelID uint
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
return
}
var vm models.VirtualModel
if err := c.ShouldBindJSON(&vm); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
vm.ID = virtualModelID
if err := db.UpdateVirtualModel(h.DB, &vm); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update virtual model"})
return
}
c.JSON(http.StatusOK, vm)
}
// DeleteVirtualModelHandler 处理 DELETE /api/virtual-models/:id 请求
func (h *APIHandler) DeleteVirtualModelHandler(c *gin.Context) {
id := c.Param("id")
var virtualModelID uint
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
return
}
if err := db.DeleteVirtualModel(h.DB, virtualModelID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete virtual model"})
return
}
c.JSON(http.StatusNoContent, nil)
}

View File

@@ -0,0 +1,426 @@
package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"bufio"
"encoding/json"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// ListModels 处理 GET /models 请求
func (h *APIHandler) ListModels(c *gin.Context) {
var virtualModels []models.VirtualModel
// 查询所有虚拟模型
if err := h.DB.Find(&virtualModels).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to retrieve models",
"type": "internal_error",
},
})
return
}
// 格式化为OpenAI API响应格式
modelData := make([]ModelData, len(virtualModels))
for i, vm := range virtualModels {
modelData[i] = ModelData{
ID: vm.Name,
Object: "model",
Created: vm.CreatedAt.Unix(),
OwnedBy: "ai-gateway",
}
}
response := ModelListResponse{
Object: "list",
Data: modelData,
}
c.JSON(http.StatusOK, response)
}
// ChatCompletions 处理 POST /v1/chat/completions 请求
func (h *APIHandler) ChatCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
log.Printf("Failed to unmarshal JSON: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
isStream, _ := jsonBody["stream"].(bool)
// 提取 messages 并转换为计算 token 所需的格式
var messages []billing.ChatCompletionMessage
if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
for _, msgRaw := range messagesRaw {
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content, _ := msgMap["content"].(string)
messages = append(messages, billing.ChatCompletionMessage{
Role: role,
Content: content,
})
}
}
}
// 计算请求token数
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 处理流式响应
if isStream {
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
apiKeyID, modelName, backendModel, requestTokenCount,
string(requestBody), h.DB)
return
}
// 处理非流式响应
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
log.Printf("Backend Response Status: %s", resp.Status)
log.Printf("Backend Response Body: %s", string(responseBody))
// 计算响应token数
responseTokenCount := extractResponseTokenCount(responseBody)
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// ResponsesCompletions 处理 POST /v1/responses 请求
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
// 读取请求体
bodyBytes, _ := io.ReadAll(c.Request.Body)
// 解析为 map只解析一次
var jsonBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 从 map 中提取需要的字段
modelName, _ := jsonBody["model"].(string)
// 提取 input 并转换为计算 token 所需的格式
var messages []billing.ChatCompletionMessage
if inputRaw, ok := jsonBody["input"].([]interface{}); ok {
messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw))
for _, msgRaw := range inputRaw {
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
role, _ := msgMap["role"].(string)
content, _ := msgMap["content"].(string)
messages = append(messages, billing.ChatCompletionMessage{
Role: role,
Content: content,
})
}
}
}
// 计算请求token数
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 修改 model 字段为后端模型名称
jsonBody["model"] = backendModel.Name
// Marshal 请求体用于转发
requestBody, err := json.Marshal(jsonBody)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/responses"
// 转发请求到后端
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
// 计算响应token数
responseTokenCount := extractResponseTokenCount(responseBody)
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 获取API Key ID
apiKeyID := getAPIKeyID(c)
// 创建并记录日志
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
cost, string(requestBody), string(responseBody))
logger.LogRequest(h.DB, logEntry)
// 复制响应头并返回响应
copyResponseHeaders(c, resp)
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// handleStreamingResponse 处理流式响应
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
requestBody string, database *gorm.DB) {
// 设置流式响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Transfer-Encoding", "chunked")
// 复制其他响应头
for key, values := range resp.Header {
if key != "Content-Length" && key != "Transfer-Encoding" {
for _, value := range values {
c.Header(key, value)
}
}
}
c.Status(resp.StatusCode)
// 用于累积完整响应内容
var fullContent strings.Builder
var responseBody strings.Builder
// 创建一个 scanner 来逐行读取流
scanner := bufio.NewScanner(resp.Body)
flusher, ok := c.Writer.(http.Flusher)
for scanner.Scan() {
line := scanner.Text()
// 将原始行写入响应体缓冲区
responseBody.WriteString(line)
responseBody.WriteString("\n")
// SSE 格式data: {...}
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
// 检查是否是结束标记
if data == "[DONE]" {
_, err := c.Writer.Write([]byte(line + "\n\n"))
if err != nil {
return
}
if ok {
flusher.Flush()
}
break
}
// 解析 JSON 数据以提取内容
var chunk map[string]interface{}
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if delta, ok := choice["delta"].(map[string]interface{}); ok {
if content, ok := delta["content"].(string); ok {
fullContent.WriteString(content)
}
}
}
}
}
}
// 转发数据到客户端
_, err := c.Writer.Write([]byte(line + "\n"))
if err != nil {
return
}
// 如果是空行SSE 消息分隔符),刷新
if line == "" {
if ok {
flusher.Flush()
}
}
}
// 确保发送最后的数据
if ok {
flusher.Flush()
}
// 扫描可能出现的错误
if err := scanner.Err(); err != nil {
log.Printf("Error reading stream: %v", err)
}
// 计算响应 token 数
responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
// 计算费用
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建日志记录
logEntry := &models.RequestLog{
APIKeyID: apiKeyID,
VirtualModelName: virtualModelName,
BackendModelName: backendModel.Name,
RequestTimestamp: requestTimestamp,
ResponseTimestamp: time.Now(), // 使用实际结束时间
RequestTokens: requestTokenCount,
ResponseTokens: responseTokenCount,
Cost: cost,
RequestBody: requestBody,
ResponseBody: responseBody.String(),
}
// 异步记录日志
logger.LogRequest(database, logEntry)
}

View File

@@ -0,0 +1,101 @@
package api
import (
"ai-gateway/internal/db"
"ai-gateway/internal/models"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
)
// GetProvidersHandler 处理 GET /api/providers 请求
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
providers, err := db.GetProviders(h.DB)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to retrieve providers",
"type": "internal_error",
},
})
return
}
c.JSON(http.StatusOK, providers)
}
// GetProviderHandler 处理 GET /api/providers/:id 请求
func (h *APIHandler) GetProviderHandler(c *gin.Context) {
id := c.Param("id")
var providerID uint
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
return
}
provider, err := db.GetProviderByID(h.DB, providerID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
return
}
c.JSON(http.StatusOK, provider)
}
// CreateProviderHandler 处理 POST /api/providers 请求
func (h *APIHandler) CreateProviderHandler(c *gin.Context) {
var provider models.Provider
if err := c.ShouldBindJSON(&provider); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := db.CreateProvider(h.DB, &provider); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"})
return
}
c.JSON(http.StatusCreated, provider)
}
// UpdateProviderHandler 处理 PUT /api/providers/:id 请求
func (h *APIHandler) UpdateProviderHandler(c *gin.Context) {
id := c.Param("id")
var providerID uint
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
return
}
var provider models.Provider
if err := c.ShouldBindJSON(&provider); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
provider.ID = providerID
if err := db.UpdateProvider(h.DB, &provider); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
return
}
c.JSON(http.StatusOK, provider)
}
// DeleteProviderHandler 处理 DELETE /api/providers/:id 请求
func (h *APIHandler) DeleteProviderHandler(c *gin.Context) {
id := c.Param("id")
var providerID uint
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
return
}
if err := db.DeleteProvider(h.DB, providerID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"})
return
}
c.JSON(http.StatusNoContent, nil)
}

View File

@@ -0,0 +1,96 @@
package api
import (
"ai-gateway/internal/db"
"ai-gateway/internal/models"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
)
// CreateVirtualModelHandler 处理 POST /api/virtual-models 请求
func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) {
var vm models.VirtualModel
if err := c.ShouldBindJSON(&vm); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := db.CreateVirtualModel(h.DB, &vm); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create virtual model"})
return
}
c.JSON(http.StatusCreated, vm)
}
// GetVirtualModelsHandler 处理 GET /api/virtual-models 请求
func (h *APIHandler) GetVirtualModelsHandler(c *gin.Context) {
virtualModels, err := db.GetVirtualModels(h.DB)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve virtual models"})
return
}
c.JSON(http.StatusOK, virtualModels)
}
// GetVirtualModelHandler 处理 GET /api/virtual-models/:id 请求
func (h *APIHandler) GetVirtualModelHandler(c *gin.Context) {
id := c.Param("id")
var virtualModelID uint
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
return
}
virtualModel, err := db.GetVirtualModelByID(h.DB, virtualModelID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Virtual model not found"})
return
}
c.JSON(http.StatusOK, virtualModel)
}
// UpdateVirtualModelHandler 处理 PUT /api/virtual-models/:id 请求
func (h *APIHandler) UpdateVirtualModelHandler(c *gin.Context) {
id := c.Param("id")
var virtualModelID uint
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
return
}
var vm models.VirtualModel
if err := c.ShouldBindJSON(&vm); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
vm.ID = virtualModelID
if err := db.UpdateVirtualModel(h.DB, &vm); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update virtual model"})
return
}
c.JSON(http.StatusOK, vm)
}
// DeleteVirtualModelHandler 处理 DELETE /api/virtual-models/:id 请求
func (h *APIHandler) DeleteVirtualModelHandler(c *gin.Context) {
id := c.Param("id")
var virtualModelID uint
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
return
}
if err := db.DeleteVirtualModel(h.DB, virtualModelID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete virtual model"})
return
}
c.JSON(http.StatusNoContent, nil)
}