From 38889be0723cc3e874100fe38534a6da615d0dcc Mon Sep 17 00:00:00 2001 From: nanako <469449812@qq.com> Date: Sun, 9 Nov 2025 00:22:23 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84API=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E6=96=B0=E5=A2=9Eprovider=E5=92=8Cv?= =?UTF-8?q?irtual=20model=E5=A4=84=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=AF=B7=E6=B1=82=E4=BD=93=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=E5=92=8C=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/api/handlers.go | 623 -------------------------- backend/api/openai_handlers.go | 426 ++++++++++++++++++ backend/api/provider_handlers.go | 101 +++++ backend/api/virtual_model_handlers.go | 96 ++++ 4 files changed, 623 insertions(+), 623 deletions(-) create mode 100644 backend/api/openai_handlers.go create mode 100644 backend/api/provider_handlers.go create mode 100644 backend/api/virtual_model_handlers.go diff --git a/backend/api/handlers.go b/backend/api/handlers.go index 6412baa..8682dc5 100644 --- a/backend/api/handlers.go +++ b/backend/api/handlers.go @@ -2,18 +2,11 @@ package api import ( "ai-gateway/internal/billing" - "ai-gateway/internal/db" - "ai-gateway/internal/logger" "ai-gateway/internal/models" - "ai-gateway/internal/router" - "bufio" "bytes" "encoding/json" "fmt" - "io" - "log" "net/http" - "strings" "time" "github.com/gin-gonic/gin" @@ -88,328 +81,6 @@ type UpdateVirtualModelRequest struct { BackendModels []BackendModelAssociation `json:"backend_models"` } -// ListModels 处理 GET /models 请求 -func (h *APIHandler) ListModels(c *gin.Context) { - var virtualModels []models.VirtualModel - - // 查询所有虚拟模型 - if err := h.DB.Find(&virtualModels).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to retrieve models", - "type": "internal_error", - }, - }) - return - } - - // 格式化为OpenAI API响应格式 - modelData := make([]ModelData, len(virtualModels)) - for i, vm := range virtualModels { - modelData[i] = ModelData{ - ID: vm.Name, - Object: "model", - Created: vm.CreatedAt.Unix(), - OwnedBy: "ai-gateway", - } - } - - response := ModelListResponse{ - Object: "list", - Data: modelData, - } - - c.JSON(http.StatusOK, response) -} - -// ChatCompletions 处理 POST /v1/chat/completions 请求 -func (h *APIHandler) ChatCompletions(c *gin.Context) { - // 记录请求开始时间 - requestTimestamp := time.Now() - - // 读取请求体 - bodyBytes, _ := io.ReadAll(c.Request.Body) - - // 解析为 map,只解析一次 - var jsonBody map[string]interface{} - if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { - log.Printf("Failed to unmarshal JSON: %v", err) - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "Invalid request format", - "type": "invalid_request_error", - }, - }) - return - } - - // 从 map 中提取需要的字段 - modelName, _ := jsonBody["model"].(string) - isStream, _ := jsonBody["stream"].(bool) - - // 提取 messages 并转换为计算 token 所需的格式 - var messages []billing.ChatCompletionMessage - if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok { - messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw)) - for _, msgRaw := range messagesRaw { - if msgMap, ok := msgRaw.(map[string]interface{}); ok { - role, _ := msgMap["role"].(string) - content, _ := msgMap["content"].(string) - messages = append(messages, billing.ChatCompletionMessage{ - Role: role, - Content: content, - }) - } - } - } - - // 计算请求token数 - requestTokenCount := billing.CalculateMessagesTokensSimple(messages) - - // 选择后端模型 - backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "invalid_request_error", - }, - }) - return - } - - // 修改 model 字段为后端模型名称 - jsonBody["model"] = backendModel.Name - - // Marshal 请求体用于转发 - requestBody, err := json.Marshal(jsonBody) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to process request", - "type": "internal_error", - }, - }) - return - } - - // 构建后端API URL - backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions" - - // 转发请求到后端 - resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent")) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to connect to backend service", - "type": "service_unavailable", - }, - }) - return - } - defer resp.Body.Close() - - // 记录响应时间 - responseTimestamp := time.Now() - - // 获取API Key ID - apiKeyID := getAPIKeyID(c) - - // 处理流式响应 - if isStream { - handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp, - apiKeyID, modelName, backendModel, requestTokenCount, - string(requestBody), h.DB) - return - } - - // 处理非流式响应 - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to read backend response", - "type": "internal_error", - }, - }) - return - } - - log.Printf("Backend Response Status: %s", resp.Status) - log.Printf("Backend Response Body: %s", string(responseBody)) - - // 计算响应token数 - responseTokenCount := extractResponseTokenCount(responseBody) - - // 计算费用 - costCalculator := billing.NewCostCalculator() - cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) - - // 创建并记录日志 - logEntry := createRequestLog(apiKeyID, modelName, backendModel, - requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount, - cost, string(requestBody), string(responseBody)) - logger.LogRequest(h.DB, logEntry) - - // 复制响应头并返回响应 - copyResponseHeaders(c, resp) - c.Status(resp.StatusCode) - c.Writer.Write(responseBody) -} - -// ResponsesCompletions 处理 POST /v1/responses 请求 -func (h *APIHandler) ResponsesCompletions(c *gin.Context) { - // 记录请求开始时间 - requestTimestamp := time.Now() - - // 读取请求体 - bodyBytes, _ := io.ReadAll(c.Request.Body) - - // 解析为 map,只解析一次 - var jsonBody map[string]interface{} - if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "Invalid request format", - "type": "invalid_request_error", - }, - }) - return - } - - // 从 map 中提取需要的字段 - modelName, _ := jsonBody["model"].(string) - - // 提取 input 并转换为计算 token 所需的格式 - var messages []billing.ChatCompletionMessage - if inputRaw, ok := jsonBody["input"].([]interface{}); ok { - messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw)) - for _, msgRaw := range inputRaw { - if msgMap, ok := msgRaw.(map[string]interface{}); ok { - role, _ := msgMap["role"].(string) - content, _ := msgMap["content"].(string) - messages = append(messages, billing.ChatCompletionMessage{ - Role: role, - Content: content, - }) - } - } - } - - // 计算请求token数 - requestTokenCount := billing.CalculateMessagesTokensSimple(messages) - - // 选择后端模型 - backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "invalid_request_error", - }, - }) - return - } - - // 修改 model 字段为后端模型名称 - jsonBody["model"] = backendModel.Name - - // Marshal 请求体用于转发 - requestBody, err := json.Marshal(jsonBody) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to process request", - "type": "internal_error", - }, - }) - return - } - - // 构建后端API URL - backendURL := backendModel.Provider.BaseURL + "/v1/responses" - - // 转发请求到后端 - resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent")) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to connect to backend service", - "type": "service_unavailable", - }, - }) - return - } - defer resp.Body.Close() - - // 记录响应时间 - responseTimestamp := time.Now() - - // 读取响应体 - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to read backend response", - "type": "internal_error", - }, - }) - return - } - - // 计算响应token数 - responseTokenCount := extractResponseTokenCount(responseBody) - - // 计算费用 - costCalculator := billing.NewCostCalculator() - cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) - - // 获取API Key ID - apiKeyID := getAPIKeyID(c) - - // 创建并记录日志 - logEntry := createRequestLog(apiKeyID, modelName, backendModel, - requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount, - cost, string(requestBody), string(responseBody)) - logger.LogRequest(h.DB, logEntry) - - // 复制响应头并返回响应 - copyResponseHeaders(c, resp) - c.Status(resp.StatusCode) - c.Writer.Write(responseBody) -} - -// convertToTikTokenMessages 将ChatCompletionMessage转换为billing包的消息格式 -func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatCompletionMessage { - result := make([]billing.ChatCompletionMessage, len(messages)) - for i, msg := range messages { - result[i] = billing.ChatCompletionMessage{ - Role: msg.Role, - Content: msg.Content, - } - } - return result -} - -// prepareBackendRequest 准备后端请求体,替换model字段 -func prepareBackendRequest(bodyBytes []byte, backendModelName string) ([]byte, error) { - var jsonBody map[string]interface{} - if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { - return nil, fmt.Errorf("failed to unmarshal request body: %w", err) - } - - // 修改model字段为后端模型名称 - jsonBody["model"] = backendModelName - - // 重新marshal - requestBody, err := json.Marshal(jsonBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request body: %w", err) - } - - return requestBody, nil -} - // forwardRequest 转发请求到后端API func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) { // 创建HTTP请求 @@ -501,297 +172,3 @@ func copyResponseHeaders(c *gin.Context, resp *http.Response) { } } } - -// handleStreamingResponse 处理流式响应 -func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time, - apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int, - requestBody string, database *gorm.DB) { - - // 设置流式响应\u5934 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Transfer-Encoding", "chunked") - - // 复制其他响应头 - for key, values := range resp.Header { - if key != "Content-Length" && key != "Transfer-Encoding" { - for _, value := range values { - c.Header(key, value) - } - } - } - - c.Status(resp.StatusCode) - - // 用于累积完整响应内容 - var fullContent strings.Builder - var responseBody strings.Builder - - // 创建一个 scanner 来逐行读取流 - scanner := bufio.NewScanner(resp.Body) - flusher, ok := c.Writer.(http.Flusher) - - for scanner.Scan() { - line := scanner.Text() - - // 将原始行写入响应体缓冲区 - responseBody.WriteString(line) - responseBody.WriteString("\n") - - // SSE 格式:data: {...} - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") - - // 检查是否是结束标记 - if data == "[DONE]" { - _, err := c.Writer.Write([]byte(line + "\n\n")) - if err != nil { - return - } - if ok { - flusher.Flush() - } - break - } - - // 解析 JSON 数据以提取内容 - var chunk map[string]interface{} - if err := json.Unmarshal([]byte(data), &chunk); err == nil { - if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 { - if choice, ok := choices[0].(map[string]interface{}); ok { - if delta, ok := choice["delta"].(map[string]interface{}); ok { - if content, ok := delta["content"].(string); ok { - fullContent.WriteString(content) - } - } - } - } - } - } - - // 转发数据到客户端 - _, err := c.Writer.Write([]byte(line + "\n")) - if err != nil { - return - } - - // 如果是空行(SSE 消息分隔符),刷新 - if line == "" { - if ok { - flusher.Flush() - } - } - } - - // 确保发送最后的数据 - if ok { - flusher.Flush() - } - - // 扫描可能出现的错误 - if err := scanner.Err(); err != nil { - log.Printf("Error reading stream: %v", err) - } - - // 计算响应 token 数 - responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String()) - - // 计算费用 - costCalculator := billing.NewCostCalculator() - cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) - - // 创建日志记录 - logEntry := &models.RequestLog{ - APIKeyID: apiKeyID, - VirtualModelName: virtualModelName, - BackendModelName: backendModel.Name, - RequestTimestamp: requestTimestamp, - ResponseTimestamp: time.Now(), // 使用实际结束时间 - RequestTokens: requestTokenCount, - ResponseTokens: responseTokenCount, - Cost: cost, - RequestBody: requestBody, - ResponseBody: responseBody.String(), - } - - // 异步记录日志 - logger.LogRequest(database, logEntry) -} - -// GetProvidersHandler 处理 GET /api/providers 请求 -func (h *APIHandler) GetProvidersHandler(c *gin.Context) { - providers, err := db.GetProviders(h.DB) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to retrieve providers", - "type": "internal_error", - }, - }) - return - } - - c.JSON(http.StatusOK, providers) -} - -// GetProviderHandler 处理 GET /api/providers/:id 请求 -func (h *APIHandler) GetProviderHandler(c *gin.Context) { - id := c.Param("id") - var providerID uint - if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) - return - } - - provider, err := db.GetProviderByID(h.DB, providerID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"}) - return - } - - c.JSON(http.StatusOK, provider) -} - -// CreateProviderHandler 处理 POST /api/providers 请求 -func (h *APIHandler) CreateProviderHandler(c *gin.Context) { - var provider models.Provider - if err := c.ShouldBindJSON(&provider); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if err := db.CreateProvider(h.DB, &provider); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"}) - return - } - - c.JSON(http.StatusCreated, provider) -} - -// UpdateProviderHandler 处理 PUT /api/providers/:id 请求 -func (h *APIHandler) UpdateProviderHandler(c *gin.Context) { - id := c.Param("id") - var providerID uint - if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) - return - } - - var provider models.Provider - if err := c.ShouldBindJSON(&provider); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - provider.ID = providerID - - if err := db.UpdateProvider(h.DB, &provider); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"}) - return - } - - c.JSON(http.StatusOK, provider) -} - -// DeleteProviderHandler 处理 DELETE /api/providers/:id 请求 -func (h *APIHandler) DeleteProviderHandler(c *gin.Context) { - id := c.Param("id") - var providerID uint - if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) - return - } - - if err := db.DeleteProvider(h.DB, providerID); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"}) - return - } - - c.JSON(http.StatusNoContent, nil) -} - -// CreateVirtualModelHandler 处理 POST /api/virtual-models 请求 -func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) { - var vm models.VirtualModel - if err := c.ShouldBindJSON(&vm); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if err := db.CreateVirtualModel(h.DB, &vm); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create virtual model"}) - return - } - - c.JSON(http.StatusCreated, vm) -} - -// GetVirtualModelsHandler 处理 GET /api/virtual-models 请求 -func (h *APIHandler) GetVirtualModelsHandler(c *gin.Context) { - virtualModels, err := db.GetVirtualModels(h.DB) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve virtual models"}) - return - } - - c.JSON(http.StatusOK, virtualModels) -} - -// GetVirtualModelHandler 处理 GET /api/virtual-models/:id 请求 -func (h *APIHandler) GetVirtualModelHandler(c *gin.Context) { - id := c.Param("id") - var virtualModelID uint - if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"}) - return - } - - virtualModel, err := db.GetVirtualModelByID(h.DB, virtualModelID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Virtual model not found"}) - return - } - - c.JSON(http.StatusOK, virtualModel) -} - -// UpdateVirtualModelHandler 处理 PUT /api/virtual-models/:id 请求 -func (h *APIHandler) UpdateVirtualModelHandler(c *gin.Context) { - id := c.Param("id") - var virtualModelID uint - if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"}) - return - } - - var vm models.VirtualModel - if err := c.ShouldBindJSON(&vm); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - vm.ID = virtualModelID - - if err := db.UpdateVirtualModel(h.DB, &vm); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update virtual model"}) - return - } - - c.JSON(http.StatusOK, vm) -} - -// DeleteVirtualModelHandler 处理 DELETE /api/virtual-models/:id 请求 -func (h *APIHandler) DeleteVirtualModelHandler(c *gin.Context) { - id := c.Param("id") - var virtualModelID uint - if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"}) - return - } - - if err := db.DeleteVirtualModel(h.DB, virtualModelID); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete virtual model"}) - return - } - - c.JSON(http.StatusNoContent, nil) -} diff --git a/backend/api/openai_handlers.go b/backend/api/openai_handlers.go new file mode 100644 index 0000000..d80094a --- /dev/null +++ b/backend/api/openai_handlers.go @@ -0,0 +1,426 @@ +package api + +import ( + "ai-gateway/internal/billing" + "ai-gateway/internal/logger" + "ai-gateway/internal/models" + "ai-gateway/internal/router" + "bufio" + "encoding/json" + "io" + "log" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +// ListModels 处理 GET /models 请求 +func (h *APIHandler) ListModels(c *gin.Context) { + var virtualModels []models.VirtualModel + + // 查询所有虚拟模型 + if err := h.DB.Find(&virtualModels).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to retrieve models", + "type": "internal_error", + }, + }) + return + } + + // 格式化为OpenAI API响应格式 + modelData := make([]ModelData, len(virtualModels)) + for i, vm := range virtualModels { + modelData[i] = ModelData{ + ID: vm.Name, + Object: "model", + Created: vm.CreatedAt.Unix(), + OwnedBy: "ai-gateway", + } + } + + response := ModelListResponse{ + Object: "list", + Data: modelData, + } + + c.JSON(http.StatusOK, response) +} + +// ChatCompletions 处理 POST /v1/chat/completions 请求 +func (h *APIHandler) ChatCompletions(c *gin.Context) { + // 记录请求开始时间 + requestTimestamp := time.Now() + + // 读取请求体 + bodyBytes, _ := io.ReadAll(c.Request.Body) + + // 解析为 map,只解析一次 + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { + log.Printf("Failed to unmarshal JSON: %v", err) + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "Invalid request format", + "type": "invalid_request_error", + }, + }) + return + } + + // 从 map 中提取需要的字段 + modelName, _ := jsonBody["model"].(string) + isStream, _ := jsonBody["stream"].(bool) + + // 提取 messages 并转换为计算 token 所需的格式 + var messages []billing.ChatCompletionMessage + if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok { + messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw)) + for _, msgRaw := range messagesRaw { + if msgMap, ok := msgRaw.(map[string]interface{}); ok { + role, _ := msgMap["role"].(string) + content, _ := msgMap["content"].(string) + messages = append(messages, billing.ChatCompletionMessage{ + Role: role, + Content: content, + }) + } + } + } + + // 计算请求token数 + requestTokenCount := billing.CalculateMessagesTokensSimple(messages) + + // 选择后端模型 + backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "invalid_request_error", + }, + }) + return + } + + // 修改 model 字段为后端模型名称 + jsonBody["model"] = backendModel.Name + + // Marshal 请求体用于转发 + requestBody, err := json.Marshal(jsonBody) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to process request", + "type": "internal_error", + }, + }) + return + } + + // 构建后端API URL + backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions" + + // 转发请求到后端 + resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent")) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "message": "Failed to connect to backend service", + "type": "service_unavailable", + }, + }) + return + } + defer resp.Body.Close() + + // 记录响应时间 + responseTimestamp := time.Now() + + // 获取API Key ID + apiKeyID := getAPIKeyID(c) + + // 处理流式响应 + if isStream { + handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp, + apiKeyID, modelName, backendModel, requestTokenCount, + string(requestBody), h.DB) + return + } + + // 处理非流式响应 + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to read backend response", + "type": "internal_error", + }, + }) + return + } + + log.Printf("Backend Response Status: %s", resp.Status) + log.Printf("Backend Response Body: %s", string(responseBody)) + + // 计算响应token数 + responseTokenCount := extractResponseTokenCount(responseBody) + + // 计算费用 + costCalculator := billing.NewCostCalculator() + cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) + + // 创建并记录日志 + logEntry := createRequestLog(apiKeyID, modelName, backendModel, + requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount, + cost, string(requestBody), string(responseBody)) + logger.LogRequest(h.DB, logEntry) + + // 复制响应头并返回响应 + copyResponseHeaders(c, resp) + c.Status(resp.StatusCode) + c.Writer.Write(responseBody) +} + +// ResponsesCompletions 处理 POST /v1/responses 请求 +func (h *APIHandler) ResponsesCompletions(c *gin.Context) { + // 记录请求开始时间 + requestTimestamp := time.Now() + + // 读取请求体 + bodyBytes, _ := io.ReadAll(c.Request.Body) + + // 解析为 map,只解析一次 + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "Invalid request format", + "type": "invalid_request_error", + }, + }) + return + } + + // 从 map 中提取需要的字段 + modelName, _ := jsonBody["model"].(string) + + // 提取 input 并转换为计算 token 所需的格式 + var messages []billing.ChatCompletionMessage + if inputRaw, ok := jsonBody["input"].([]interface{}); ok { + messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw)) + for _, msgRaw := range inputRaw { + if msgMap, ok := msgRaw.(map[string]interface{}); ok { + role, _ := msgMap["role"].(string) + content, _ := msgMap["content"].(string) + messages = append(messages, billing.ChatCompletionMessage{ + Role: role, + Content: content, + }) + } + } + } + + // 计算请求token数 + requestTokenCount := billing.CalculateMessagesTokensSimple(messages) + + // 选择后端模型 + backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "invalid_request_error", + }, + }) + return + } + + // 修改 model 字段为后端模型名称 + jsonBody["model"] = backendModel.Name + + // Marshal 请求体用于转发 + requestBody, err := json.Marshal(jsonBody) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to process request", + "type": "internal_error", + }, + }) + return + } + + // 构建后端API URL + backendURL := backendModel.Provider.BaseURL + "/v1/responses" + + // 转发请求到后端 + resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent")) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "message": "Failed to connect to backend service", + "type": "service_unavailable", + }, + }) + return + } + defer resp.Body.Close() + + // 记录响应时间 + responseTimestamp := time.Now() + + // 读取响应体 + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to read backend response", + "type": "internal_error", + }, + }) + return + } + + // 计算响应token数 + responseTokenCount := extractResponseTokenCount(responseBody) + + // 计算费用 + costCalculator := billing.NewCostCalculator() + cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) + + // 获取API Key ID + apiKeyID := getAPIKeyID(c) + + // 创建并记录日志 + logEntry := createRequestLog(apiKeyID, modelName, backendModel, + requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount, + cost, string(requestBody), string(responseBody)) + logger.LogRequest(h.DB, logEntry) + + // 复制响应头并返回响应 + copyResponseHeaders(c, resp) + c.Status(resp.StatusCode) + c.Writer.Write(responseBody) +} + +// handleStreamingResponse 处理流式响应 +func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time, + apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int, + requestBody string, database *gorm.DB) { + + // 设置流式响应头 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Transfer-Encoding", "chunked") + + // 复制其他响应头 + for key, values := range resp.Header { + if key != "Content-Length" && key != "Transfer-Encoding" { + for _, value := range values { + c.Header(key, value) + } + } + } + + c.Status(resp.StatusCode) + + // 用于累积完整响应内容 + var fullContent strings.Builder + var responseBody strings.Builder + + // 创建一个 scanner 来逐行读取流 + scanner := bufio.NewScanner(resp.Body) + flusher, ok := c.Writer.(http.Flusher) + + for scanner.Scan() { + line := scanner.Text() + + // 将原始行写入响应体缓冲区 + responseBody.WriteString(line) + responseBody.WriteString("\n") + + // SSE 格式:data: {...} + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + + // 检查是否是结束标记 + if data == "[DONE]" { + _, err := c.Writer.Write([]byte(line + "\n\n")) + if err != nil { + return + } + if ok { + flusher.Flush() + } + break + } + + // 解析 JSON 数据以提取内容 + var chunk map[string]interface{} + if err := json.Unmarshal([]byte(data), &chunk); err == nil { + if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 { + if choice, ok := choices[0].(map[string]interface{}); ok { + if delta, ok := choice["delta"].(map[string]interface{}); ok { + if content, ok := delta["content"].(string); ok { + fullContent.WriteString(content) + } + } + } + } + } + } + + // 转发数据到客户端 + _, err := c.Writer.Write([]byte(line + "\n")) + if err != nil { + return + } + + // 如果是空行(SSE 消息分隔符),刷新 + if line == "" { + if ok { + flusher.Flush() + } + } + } + + // 确保发送最后的数据 + if ok { + flusher.Flush() + } + + // 扫描可能出现的错误 + if err := scanner.Err(); err != nil { + log.Printf("Error reading stream: %v", err) + } + + // 计算响应 token 数 + responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String()) + + // 计算费用 + costCalculator := billing.NewCostCalculator() + cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount) + + // 创建日志记录 + logEntry := &models.RequestLog{ + APIKeyID: apiKeyID, + VirtualModelName: virtualModelName, + BackendModelName: backendModel.Name, + RequestTimestamp: requestTimestamp, + ResponseTimestamp: time.Now(), // 使用实际结束时间 + RequestTokens: requestTokenCount, + ResponseTokens: responseTokenCount, + Cost: cost, + RequestBody: requestBody, + ResponseBody: responseBody.String(), + } + + // 异步记录日志 + logger.LogRequest(database, logEntry) +} diff --git a/backend/api/provider_handlers.go b/backend/api/provider_handlers.go new file mode 100644 index 0000000..c906fe3 --- /dev/null +++ b/backend/api/provider_handlers.go @@ -0,0 +1,101 @@ +package api + +import ( + "ai-gateway/internal/db" + "ai-gateway/internal/models" + "fmt" + "net/http" + + "github.com/gin-gonic/gin" +) + +// GetProvidersHandler 处理 GET /api/providers 请求 +func (h *APIHandler) GetProvidersHandler(c *gin.Context) { + providers, err := db.GetProviders(h.DB) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to retrieve providers", + "type": "internal_error", + }, + }) + return + } + + c.JSON(http.StatusOK, providers) +} + +// GetProviderHandler 处理 GET /api/providers/:id 请求 +func (h *APIHandler) GetProviderHandler(c *gin.Context) { + id := c.Param("id") + var providerID uint + if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) + return + } + + provider, err := db.GetProviderByID(h.DB, providerID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"}) + return + } + + c.JSON(http.StatusOK, provider) +} + +// CreateProviderHandler 处理 POST /api/providers 请求 +func (h *APIHandler) CreateProviderHandler(c *gin.Context) { + var provider models.Provider + if err := c.ShouldBindJSON(&provider); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := db.CreateProvider(h.DB, &provider); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"}) + return + } + + c.JSON(http.StatusCreated, provider) +} + +// UpdateProviderHandler 处理 PUT /api/providers/:id 请求 +func (h *APIHandler) UpdateProviderHandler(c *gin.Context) { + id := c.Param("id") + var providerID uint + if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) + return + } + + var provider models.Provider + if err := c.ShouldBindJSON(&provider); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + provider.ID = providerID + + if err := db.UpdateProvider(h.DB, &provider); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"}) + return + } + + c.JSON(http.StatusOK, provider) +} + +// DeleteProviderHandler 处理 DELETE /api/providers/:id 请求 +func (h *APIHandler) DeleteProviderHandler(c *gin.Context) { + id := c.Param("id") + var providerID uint + if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) + return + } + + if err := db.DeleteProvider(h.DB, providerID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"}) + return + } + + c.JSON(http.StatusNoContent, nil) +} diff --git a/backend/api/virtual_model_handlers.go b/backend/api/virtual_model_handlers.go new file mode 100644 index 0000000..94a9a92 --- /dev/null +++ b/backend/api/virtual_model_handlers.go @@ -0,0 +1,96 @@ +package api + +import ( + "ai-gateway/internal/db" + "ai-gateway/internal/models" + "fmt" + "net/http" + + "github.com/gin-gonic/gin" +) + +// CreateVirtualModelHandler 处理 POST /api/virtual-models 请求 +func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) { + var vm models.VirtualModel + if err := c.ShouldBindJSON(&vm); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := db.CreateVirtualModel(h.DB, &vm); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create virtual model"}) + return + } + + c.JSON(http.StatusCreated, vm) +} + +// GetVirtualModelsHandler 处理 GET /api/virtual-models 请求 +func (h *APIHandler) GetVirtualModelsHandler(c *gin.Context) { + virtualModels, err := db.GetVirtualModels(h.DB) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve virtual models"}) + return + } + + c.JSON(http.StatusOK, virtualModels) +} + +// GetVirtualModelHandler 处理 GET /api/virtual-models/:id 请求 +func (h *APIHandler) GetVirtualModelHandler(c *gin.Context) { + id := c.Param("id") + var virtualModelID uint + if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"}) + return + } + + virtualModel, err := db.GetVirtualModelByID(h.DB, virtualModelID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Virtual model not found"}) + return + } + + c.JSON(http.StatusOK, virtualModel) +} + +// UpdateVirtualModelHandler 处理 PUT /api/virtual-models/:id 请求 +func (h *APIHandler) UpdateVirtualModelHandler(c *gin.Context) { + id := c.Param("id") + var virtualModelID uint + if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"}) + return + } + + var vm models.VirtualModel + if err := c.ShouldBindJSON(&vm); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + vm.ID = virtualModelID + + if err := db.UpdateVirtualModel(h.DB, &vm); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update virtual model"}) + return + } + + c.JSON(http.StatusOK, vm) +} + +// DeleteVirtualModelHandler 处理 DELETE /api/virtual-models/:id 请求 +func (h *APIHandler) DeleteVirtualModelHandler(c *gin.Context) { + id := c.Param("id") + var virtualModelID uint + if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"}) + return + } + + if err := db.DeleteVirtualModel(h.DB, virtualModelID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete virtual model"}) + return + } + + c.JSON(http.StatusNoContent, nil) +}