Files
AIRouter/backend/api/handlers.go

798 lines
22 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/db"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// APIHandler holds the database connection and serves the gateway's HTTP API.
// All request handlers in this file are methods on this type.
type APIHandler struct {
	// DB is the GORM handle used for model lookups, provider/virtual-model CRUD
	// and request logging.
	DB *gorm.DB
}
// ModelListResponse mirrors the OpenAI /v1/models list response format.
// ListModels fills it with Object "list" and one ModelData per virtual model.
type ModelListResponse struct {
	Object string      `json:"object"`
	Data   []ModelData `json:"data"`
}
// ModelData describes a single model entry in ModelListResponse.
type ModelData struct {
	ID      string `json:"id"`       // model name exposed to clients
	Object  string `json:"object"`   // always "model" in ListModels
	Created int64  `json:"created"`  // creation time as Unix seconds
	OwnedBy string `json:"owned_by"` // owner label; ListModels sets "ai-gateway"
}
// ChatCompletionRequest is the typed shape of a chat-completion request body.
// NOTE(review): ChatCompletions in this file decodes the body into a generic
// map instead of this type (so unknown fields are forwarded verbatim); confirm
// whether this type is still used elsewhere in the package.
type ChatCompletionRequest struct {
	Model    string                  `json:"model"`
	Messages []ChatCompletionMessage `json:"messages"`
	Stream   bool                    `json:"stream,omitempty"`
}
// ChatCompletionMessage is a single chat message: a role and its text content.
// Content is modelled as a plain string only.
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// BackendChatCompletionRequest is a request shape intended for backend models.
// It carries only the common fields, to avoid sending parameters a backend
// may not support.
// NOTE(review): ChatCompletions currently forwards the client's full JSON body
// (with only "model" rewritten) rather than this struct; confirm whether it is
// used anywhere else.
type BackendChatCompletionRequest struct {
	Model    string                  `json:"model"`
	Messages []ChatCompletionMessage `json:"messages"`
	Stream   bool                    `json:"stream,omitempty"`
}
// ResponsesRequest is the typed shape of a /v1/responses request body.
// NOTE(review): the upstream Responses API also accepts a plain string for
// "input"; this struct only models the message-array form. Confirm before
// decoding client bodies into it.
type ResponsesRequest struct {
	Model  string                  `json:"model"`
	Input  []ChatCompletionMessage `json:"input"`
	Stream bool                    `json:"stream,omitempty"`
}
// BackendModelAssociation links a virtual model to one backend model together
// with its routing configuration.
type BackendModelAssociation struct {
	BackendModelID uint `json:"backend_model_id" binding:"required"`
	// NOTE(review): gin's `binding:"required"` fails on the zero value, so a
	// Priority of 0 is rejected by validation. Confirm that is intended.
	Priority      int     `json:"priority" binding:"required"`
	CostThreshold float64 `json:"cost_threshold"`
}
// CreateVirtualModelRequest is a request body for creating a virtual model.
// NOTE(review): CreateVirtualModelHandler in this file binds directly into
// models.VirtualModel instead; confirm whether this type is used elsewhere.
type CreateVirtualModelRequest struct {
	Name          string                    `json:"name" binding:"required"`
	Description   string                    `json:"description"`
	BackendModels []BackendModelAssociation `json:"backend_models"`
}
// UpdateVirtualModelRequest is a request body for updating a virtual model.
// NOTE(review): UpdateVirtualModelHandler in this file binds directly into
// models.VirtualModel instead; confirm whether this type is used elsewhere.
type UpdateVirtualModelRequest struct {
	Name          string                    `json:"name" binding:"required"`
	Description   string                    `json:"description"`
	BackendModels []BackendModelAssociation `json:"backend_models"`
}
// ListModels handles GET /models. It loads every virtual model and returns
// them in the OpenAI "list models" format, using the virtual model name as the
// model id and its creation time as the "created" Unix timestamp.
func (h *APIHandler) ListModels(c *gin.Context) {
	var vms []models.VirtualModel
	if err := h.DB.Find(&vms).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to retrieve models",
				"type":    "internal_error",
			},
		})
		return
	}

	// Non-nil even when empty so the JSON "data" field is [] rather than null.
	entries := make([]ModelData, 0, len(vms))
	for _, vm := range vms {
		entries = append(entries, ModelData{
			ID:      vm.Name,
			Object:  "model",
			Created: vm.CreatedAt.Unix(),
			OwnedBy: "ai-gateway",
		})
	}

	c.JSON(http.StatusOK, ModelListResponse{Object: "list", Data: entries})
}
// ChatCompletions handles POST /v1/chat/completions.
//
// It reads the client's JSON body and estimates the prompt token count from
// "messages" (string contents only). It asks the router for a backend model for
// the requested virtual model and rewrites "model" to that backend model's
// name. The otherwise untouched body is then forwarded to
// <provider BaseURL>/v1/chat/completions.
//
// Streaming requests ("stream": true) are relayed by handleStreamingResponse.
// Non-streaming responses are buffered, billed, logged, and then written back
// with the backend's status code and headers.
//
// Error responses use the OpenAI error envelope:
//   - 400 for an unreadable or non-object body.
//   - 404 when no backend model can be selected.
//   - 502 when the backend is unreachable.
//   - 500 for internal failures.
func (h *APIHandler) ChatCompletions(c *gin.Context) {
	requestTimestamp := time.Now()

	// A failed read used to be ignored, and an empty/partial body was parsed.
	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		log.Printf("Failed to read request body: %v", err)
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Failed to read request body",
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// Parse once into a generic map so unknown fields are forwarded verbatim.
	// A literal `null` body decodes without error into a nil map; writing the
	// "model" key into it below would panic, so it is rejected as well.
	var jsonBody map[string]interface{}
	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil || jsonBody == nil {
		if err != nil {
			log.Printf("Failed to unmarshal JSON: %v", err)
		}
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Invalid request format",
				"type":    "invalid_request_error",
			},
		})
		return
	}

	modelName, _ := jsonBody["model"].(string)
	isStream, _ := jsonBody["stream"].(bool)

	// Convert "messages" into the form the billing token counter expects.
	// Non-string contents count as "".
	var messages []billing.ChatCompletionMessage
	if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
		messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
		for _, msgRaw := range messagesRaw {
			if msgMap, ok := msgRaw.(map[string]interface{}); ok {
				role, _ := msgMap["role"].(string)
				content, _ := msgMap["content"].(string)
				messages = append(messages, billing.ChatCompletionMessage{
					Role:    role,
					Content: content,
				})
			}
		}
	}
	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)

	backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error": gin.H{
				"message": err.Error(),
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// The backend only knows its own model name, not the virtual one.
	jsonBody["model"] = backendModel.Name
	requestBody, err := json.Marshal(jsonBody)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to process request",
				"type":    "internal_error",
			},
		})
		return
	}

	backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
	resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
	if err != nil {
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"message": "Failed to connect to backend service",
				"type":    "service_unavailable",
			},
		})
		return
	}
	defer resp.Body.Close()

	responseTimestamp := time.Now()
	apiKeyID := getAPIKeyID(c)

	if isStream {
		handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
			apiKeyID, modelName, backendModel, requestTokenCount,
			string(requestBody), h.DB)
		return
	}

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to read backend response",
				"type":    "internal_error",
			},
		})
		return
	}
	log.Printf("Backend Response Status: %s", resp.Status)
	log.Printf("Backend Response Body: %s", string(responseBody))

	responseTokenCount := extractResponseTokenCount(responseBody)
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	logEntry := createRequestLog(apiKeyID, modelName, backendModel,
		requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
		cost, string(requestBody), string(responseBody))
	logger.LogRequest(h.DB, logEntry)

	copyResponseHeaders(c, resp)
	c.Status(resp.StatusCode)
	c.Writer.Write(responseBody)
}
// ResponsesCompletions handles POST /v1/responses.
//
// It follows the same pipeline as ChatCompletions for the Responses API:
//   - Parse the JSON body once and estimate prompt tokens from "input".
//   - Select a backend model for the requested virtual model and rewrite
//     "model" to the backend model's name.
//   - Forward the body to <provider BaseURL>/v1/responses.
//
// The backend response is always read in full, billed, logged, and then
// written back with the backend's status code and headers.
//
// NOTE(review): a "stream": true request is forwarded as-is, but the backend
// SSE stream is still buffered here and delivered in one piece.
// extractResponseTokenCount expects a single JSON document, so such bodies
// are likely billed with 0 response tokens.
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
	requestTimestamp := time.Now()

	// A failed read used to be ignored, and an empty/partial body was parsed.
	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Failed to read request body",
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// A literal `null` body decodes without error into a nil map; writing the
	// "model" key into it below would panic, so it is rejected as well.
	var jsonBody map[string]interface{}
	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil || jsonBody == nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Invalid request format",
				"type":    "invalid_request_error",
			},
		})
		return
	}

	modelName, _ := jsonBody["model"].(string)
	requestTokenCount := billing.CalculateMessagesTokensSimple(responsesInputMessages(jsonBody["input"]))

	backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error": gin.H{
				"message": err.Error(),
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// The backend only knows its own model name, not the virtual one.
	jsonBody["model"] = backendModel.Name
	requestBody, err := json.Marshal(jsonBody)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to process request",
				"type":    "internal_error",
			},
		})
		return
	}

	backendURL := backendModel.Provider.BaseURL + "/v1/responses"
	resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
	if err != nil {
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"message": "Failed to connect to backend service",
				"type":    "service_unavailable",
			},
		})
		return
	}
	defer resp.Body.Close()

	responseTimestamp := time.Now()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to read backend response",
				"type":    "internal_error",
			},
		})
		return
	}

	responseTokenCount := extractResponseTokenCount(responseBody)
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	apiKeyID := getAPIKeyID(c)
	logEntry := createRequestLog(apiKeyID, modelName, backendModel,
		requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
		cost, string(requestBody), string(responseBody))
	logger.LogRequest(h.DB, logEntry)

	copyResponseHeaders(c, resp)
	c.Status(resp.StatusCode)
	c.Writer.Write(responseBody)
}

// responsesInputMessages converts a /v1/responses "input" value into messages
// for token estimation.
//
// Accepted forms:
//   - A plain string, treated as a single user message.
//   - An array of message objects, whose "content" goes through
//     responsesContentText.
//
// Array items that are not objects are skipped; any other input yields nil.
func responsesInputMessages(raw interface{}) []billing.ChatCompletionMessage {
	switch input := raw.(type) {
	case string:
		return []billing.ChatCompletionMessage{{Role: "user", Content: input}}
	case []interface{}:
		messages := make([]billing.ChatCompletionMessage, 0, len(input))
		for _, item := range input {
			msgMap, ok := item.(map[string]interface{})
			if !ok {
				continue
			}
			role, _ := msgMap["role"].(string)
			messages = append(messages, billing.ChatCompletionMessage{
				Role:    role,
				Content: responsesContentText(msgMap["content"]),
			})
		}
		return messages
	default:
		return nil
	}
}

// responsesContentText flattens a Responses API message content value.
// A string is returned unchanged. For an array of parts, the string "text"
// fields are joined with newlines. Anything else yields "".
func responsesContentText(content interface{}) string {
	switch v := content.(type) {
	case string:
		return v
	case []interface{}:
		var sb strings.Builder
		for _, rawPart := range v {
			part, ok := rawPart.(map[string]interface{})
			if !ok {
				continue
			}
			if text, ok := part["text"].(string); ok && text != "" {
				if sb.Len() > 0 {
					sb.WriteString("\n")
				}
				sb.WriteString(text)
			}
		}
		return sb.String()
	default:
		return ""
	}
}
// convertToTikTokenMessages maps handler-level ChatCompletionMessage values
// onto billing.ChatCompletionMessage, in the same order, so they can be passed
// to the billing token counters. The result is always non-nil and has the same
// length as messages.
func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatCompletionMessage {
	converted := make([]billing.ChatCompletionMessage, 0, len(messages))
	for _, m := range messages {
		converted = append(converted, billing.ChatCompletionMessage{Role: m.Role, Content: m.Content})
	}
	return converted
}
// prepareBackendRequest rewrites the "model" field of a JSON request body to
// backendModelName and re-encodes it, keeping every other field.
//
// It returns an error when bodyBytes is not a JSON object, or when re-encoding
// fails. This includes the literal `null`: json.Unmarshal turns it into a nil
// map without error, and the assignment below would then panic.
//
// json.Marshal emits map keys in sorted order, so the original field order is
// not preserved.
func prepareBackendRequest(bodyBytes []byte, backendModelName string) ([]byte, error) {
	var jsonBody map[string]interface{}
	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
		return nil, fmt.Errorf("failed to unmarshal request body: %w", err)
	}
	if jsonBody == nil {
		return nil, fmt.Errorf("failed to unmarshal request body: body is not a JSON object")
	}

	jsonBody["model"] = backendModelName

	requestBody, err := json.Marshal(jsonBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request body: %w", err)
	}
	return requestBody, nil
}
// backendHTTPClient is shared by every forwarded request instead of building a
// new client per call. Its Timeout bounds the whole exchange, including
// reading a (possibly streamed) response body.
var backendHTTPClient = &http.Client{
	Timeout: 10 * time.Minute,
}

// forwardRequest POSTs requestBody to backendURL as JSON and authenticates
// with "Authorization: Bearer <apiKey>". When userAgent is non-empty, the
// client's User-Agent is passed through.
//
// The caller owns the returned response and must close its body. An error is
// returned if the request cannot be built (e.g. a malformed URL) or the
// backend cannot be reached. Non-2xx statuses are NOT treated as errors.
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
	httpReq, err := http.NewRequest(http.MethodPost, backendURL, bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create backend request: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)
	if userAgent != "" {
		httpReq.Header.Set("User-Agent", userAgent)
	}

	resp, err := backendHTTPClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to backend service: %w", err)
	}
	return resp, nil
}
// extractResponseTokenCount estimates the completion token count of a
// non-streaming backend response body. The text is counted with
// billing.CalculateTextTokensSimple.
//
// Two response shapes are recognised, tried in this order:
//  1. Chat Completions: choices[0].message.content (string).
//  2. Responses API: the string "text" of every content part inside the
//     items of the top-level "output" array.
//
// Bodies that are not JSON, or match neither shape (e.g. an SSE transcript or
// an error object), yield 0.
func extractResponseTokenCount(responseBody []byte) int {
	var responseData map[string]interface{}
	if err := json.Unmarshal(responseBody, &responseData); err != nil {
		return 0
	}
	if content, ok := chatCompletionContent(responseData); ok {
		return billing.CalculateTextTokensSimple(content)
	}
	if text, ok := responsesOutputText(responseData); ok {
		return billing.CalculateTextTokensSimple(text)
	}
	return 0
}

// chatCompletionContent returns choices[0].message.content from a decoded
// Chat Completions response, and whether it was present as a string.
func chatCompletionContent(data map[string]interface{}) (string, bool) {
	choices, ok := data["choices"].([]interface{})
	if !ok || len(choices) == 0 {
		return "", false
	}
	choice, ok := choices[0].(map[string]interface{})
	if !ok {
		return "", false
	}
	message, ok := choice["message"].(map[string]interface{})
	if !ok {
		return "", false
	}
	content, ok := message["content"].(string)
	return content, ok
}

// responsesOutputText concatenates the string "text" fields of all content
// parts in a decoded Responses API body's "output" items. It reports whether
// any text field was found.
func responsesOutputText(data map[string]interface{}) (string, bool) {
	output, ok := data["output"].([]interface{})
	if !ok {
		return "", false
	}
	var sb strings.Builder
	found := false
	for _, rawItem := range output {
		item, ok := rawItem.(map[string]interface{})
		if !ok {
			continue
		}
		parts, ok := item["content"].([]interface{})
		if !ok {
			continue
		}
		for _, rawPart := range parts {
			part, ok := rawPart.(map[string]interface{})
			if !ok {
				continue
			}
			if text, ok := part["text"].(string); ok {
				sb.WriteString(text)
				found = true
			}
		}
	}
	return sb.String(), found
}
// createRequestLog builds the RequestLog row for one proxied request.
// virtualModelName is the name the client asked for. The backend model name
// is taken from backendModel.Name. The remaining arguments are stored
// unchanged.
func createRequestLog(apiKeyID uint, virtualModelName string, backendModel *models.BackendModel,
	requestTimestamp, responseTimestamp time.Time, requestTokenCount, responseTokenCount int,
	cost float64, requestBody, responseBody string) *models.RequestLog {
	entry := models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  virtualModelName,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		RequestTokens:     requestTokenCount,
		RequestBody:       requestBody,
		ResponseTimestamp: responseTimestamp,
		ResponseTokens:    responseTokenCount,
		ResponseBody:      responseBody,
		Cost:              cost,
	}
	return &entry
}
// getAPIKeyID returns the ID of the API key stored under the "apiKey" gin
// context key. The key may be stored as a models.APIKey value or a non-nil
// *models.APIKey.
//
// It returns 0 when the key is missing or holds any other type. Previously,
// only the value form was recognised, so a pointer was silently logged as 0.
func getAPIKeyID(c *gin.Context) uint {
	value, exists := c.Get("apiKey")
	if !exists {
		return 0
	}
	switch key := value.(type) {
	case models.APIKey:
		return key.ID
	case *models.APIKey:
		if key != nil {
			return key.ID
		}
	}
	return 0
}
// copyResponseHeaders copies the backend response headers onto the client
// response. Each header present in resp replaces any value already set for
// that key, and all of its non-empty values are kept.
//
// The previous c.Header loop used Header().Set, so only the last value of a
// repeated header (e.g. multiple Set-Cookie lines) survived.
func copyResponseHeaders(c *gin.Context, resp *http.Response) {
	dst := c.Writer.Header()
	for key, values := range resp.Header {
		dst.Del(key)
		for _, value := range values {
			// Empty values are skipped, matching c.Header, which deletes on "".
			if value != "" {
				dst.Add(key, value)
			}
		}
	}
}
// handleStreamingResponse relays a Server-Sent Events (SSE) backend response to
// the client line by line, then bills and logs the request.
//
// resp is the backend response; its body is read here but closed by the
// caller. requestTimestamp, apiKeyID, virtualModelName, backendModel,
// requestTokenCount and requestBody are stored in the RequestLog, which is
// passed with database to logger.LogRequest. responseTimestamp is kept for
// signature compatibility; the log records the time the relay actually
// finished.
//
// While relaying, two things are accumulated:
//   - Every raw line, stored as the logged response body.
//   - The choices[0].delta.content strings of each "data: " JSON chunk. Their
//     token count is the billed response token count.
//
// Relaying stops at "data: [DONE]", at end of stream, on a read error, or when
// a write to the client fails. In every case the (possibly partial) usage is
// still billed and logged. Previously, a client write error returned before
// billing.
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
	apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
	requestBody string, database *gorm.DB) {
	// maxStreamLineSize caps the length of a single SSE line. bufio.Scanner's
	// default limit is 64 KiB. A longer line made Scan stop with
	// bufio.ErrTooLong, which silently truncated the relayed stream.
	const maxStreamLineSize = 16 << 20

	// Streaming response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")
	// Copy the remaining backend headers, except the body-framing ones.
	for key, values := range resp.Header {
		if key != "Content-Length" && key != "Transfer-Encoding" {
			for _, value := range values {
				c.Header(key, value)
			}
		}
	}
	c.Status(resp.StatusCode)

	var fullContent strings.Builder  // concatenated delta text, used for token counting
	var responseBody strings.Builder // raw stream as received, stored in the log

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxStreamLineSize)
	flusher, canFlush := c.Writer.(http.Flusher)

	for scanner.Scan() {
		line := scanner.Text()
		responseBody.WriteString(line)
		responseBody.WriteString("\n")

		// SSE payload lines look like `data: {...}`.
		if strings.HasPrefix(line, "data: ") {
			data := strings.TrimPrefix(line, "data: ")
			if data == "[DONE]" {
				// Forward the terminator followed by the blank event separator, then stop.
				if _, err := c.Writer.Write([]byte(line + "\n\n")); err != nil {
					log.Printf("Error writing stream to client: %v", err)
				} else if canFlush {
					flusher.Flush()
				}
				break
			}
			var chunk map[string]interface{}
			if err := json.Unmarshal([]byte(data), &chunk); err == nil {
				if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
					if choice, ok := choices[0].(map[string]interface{}); ok {
						if delta, ok := choice["delta"].(map[string]interface{}); ok {
							if content, ok := delta["content"].(string); ok {
								fullContent.WriteString(content)
							}
						}
					}
				}
			}
		}

		if _, err := c.Writer.Write([]byte(line + "\n")); err != nil {
			// The client is gone: stop relaying but still record usage below.
			log.Printf("Error writing stream to client: %v", err)
			break
		}
		// A blank line terminates an SSE event; push it to the client now.
		if line == "" && canFlush {
			flusher.Flush()
		}
	}
	if canFlush {
		flusher.Flush()
	}
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading stream: %v", err)
	}

	responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	logEntry := &models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  virtualModelName,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		ResponseTimestamp: time.Now(), // actual end of the relay
		RequestTokens:     requestTokenCount,
		ResponseTokens:    responseTokenCount,
		Cost:              cost,
		RequestBody:       requestBody,
		ResponseBody:      responseBody.String(),
	}
	logger.LogRequest(database, logEntry)
}
// GetProvidersHandler handles GET /api/providers. It responds with every
// provider returned by db.GetProviders as JSON. On a lookup failure it
// responds 500 with the OpenAI-style error envelope.
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
	providers, err := db.GetProviders(h.DB)
	if err == nil {
		c.JSON(http.StatusOK, providers)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{
		"error": gin.H{
			"message": "Failed to retrieve providers",
			"type":    "internal_error",
		},
	})
}
// GetProviderHandler handles GET /api/providers/:id.
//
// :id must be a non-empty string of decimal digits that fits in a uint.
// Otherwise the handler responds 400. fmt.Sscanf("%d") alone stops at the
// first non-digit, so "12abc" used to be accepted as 12. Any lookup error is
// reported as 404.
func (h *APIHandler) GetProviderHandler(c *gin.Context) {
	id := c.Param("id")
	var providerID uint
	if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
		return
	}
	provider, err := db.GetProviderByID(h.DB, providerID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
		return
	}
	c.JSON(http.StatusOK, provider)
}
// CreateProviderHandler handles POST /api/providers. It binds the JSON body
// into a models.Provider and stores it with db.CreateProvider. It responds
// with that provider value and 201 Created.
//   - A bind/validation failure yields 400 with the binding error text.
//   - A storage failure yields 500.
func (h *APIHandler) CreateProviderHandler(c *gin.Context) {
	provider := models.Provider{}
	if bindErr := c.ShouldBindJSON(&provider); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}
	if dbErr := db.CreateProvider(h.DB, &provider); dbErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"})
		return
	}
	c.JSON(http.StatusCreated, provider)
}
// UpdateProviderHandler handles PUT /api/providers/:id.
//
// :id must be a non-empty string of decimal digits that fits in a uint.
// Otherwise the handler responds 400. fmt.Sscanf("%d") alone accepted
// prefixes such as "12abc" as 12.
//
// The JSON body is bound into a models.Provider, its ID is forced to :id, and
// it is saved with db.UpdateProvider:
//   - A bind failure yields 400.
//   - A storage failure yields 500.
//   - On success, the provider is returned with 200.
func (h *APIHandler) UpdateProviderHandler(c *gin.Context) {
	id := c.Param("id")
	var providerID uint
	if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
		return
	}
	var provider models.Provider
	if err := c.ShouldBindJSON(&provider); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// The path parameter is authoritative over any ID in the body.
	provider.ID = providerID
	if err := db.UpdateProvider(h.DB, &provider); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
		return
	}
	c.JSON(http.StatusOK, provider)
}
// DeleteProviderHandler handles DELETE /api/providers/:id.
//
// :id must be a non-empty string of decimal digits that fits in a uint.
// Otherwise the handler responds 400. fmt.Sscanf("%d") alone accepted
// "12abc" and deleted provider 12.
//
// A storage failure yields 500; success yields 204 No Content.
func (h *APIHandler) DeleteProviderHandler(c *gin.Context) {
	id := c.Param("id")
	var providerID uint
	if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
		return
	}
	if err := db.DeleteProvider(h.DB, providerID); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"})
		return
	}
	c.JSON(http.StatusNoContent, nil)
}
// CreateVirtualModelHandler handles POST /api/virtual-models. It binds the
// JSON body into a models.VirtualModel and stores it with
// db.CreateVirtualModel. It responds with that value and 201 Created.
//   - A bind/validation failure yields 400 with the binding error text.
//   - A storage failure yields 500.
func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) {
	vm := models.VirtualModel{}
	if bindErr := c.ShouldBindJSON(&vm); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}
	if dbErr := db.CreateVirtualModel(h.DB, &vm); dbErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create virtual model"})
		return
	}
	c.JSON(http.StatusCreated, vm)
}
// GetVirtualModelsHandler handles GET /api/virtual-models. It responds with
// the result of db.GetVirtualModels as JSON, or 500 if the lookup fails.
func (h *APIHandler) GetVirtualModelsHandler(c *gin.Context) {
	virtualModels, err := db.GetVirtualModels(h.DB)
	if err == nil {
		c.JSON(http.StatusOK, virtualModels)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve virtual models"})
}
// GetVirtualModelHandler handles GET /api/virtual-models/:id.
//
// :id must be a non-empty string of decimal digits that fits in a uint.
// Otherwise the handler responds 400. fmt.Sscanf("%d") alone accepted
// prefixes such as "12abc" as 12. Any lookup error is reported as 404.
func (h *APIHandler) GetVirtualModelHandler(c *gin.Context) {
	id := c.Param("id")
	var virtualModelID uint
	if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
		return
	}
	virtualModel, err := db.GetVirtualModelByID(h.DB, virtualModelID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Virtual model not found"})
		return
	}
	c.JSON(http.StatusOK, virtualModel)
}
// UpdateVirtualModelHandler handles PUT /api/virtual-models/:id.
//
// :id must be a non-empty string of decimal digits that fits in a uint.
// Otherwise the handler responds 400. fmt.Sscanf("%d") alone accepted
// prefixes such as "12abc" as 12.
//
// The JSON body is bound into a models.VirtualModel, its ID is forced to :id,
// and it is saved with db.UpdateVirtualModel:
//   - A bind failure yields 400.
//   - A storage failure yields 500.
//   - On success, the model is returned with 200.
func (h *APIHandler) UpdateVirtualModelHandler(c *gin.Context) {
	id := c.Param("id")
	var virtualModelID uint
	if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
		return
	}
	var vm models.VirtualModel
	if err := c.ShouldBindJSON(&vm); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// The path parameter is authoritative over any ID in the body.
	vm.ID = virtualModelID
	if err := db.UpdateVirtualModel(h.DB, &vm); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update virtual model"})
		return
	}
	c.JSON(http.StatusOK, vm)
}
// DeleteVirtualModelHandler handles DELETE /api/virtual-models/:id.
//
// :id must be a non-empty string of decimal digits that fits in a uint.
// Otherwise the handler responds 400. fmt.Sscanf("%d") alone accepted
// "12abc" and deleted virtual model 12.
//
// A storage failure yields 500; success yields 204 No Content.
func (h *APIHandler) DeleteVirtualModelHandler(c *gin.Context) {
	id := c.Param("id")
	var virtualModelID uint
	if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
		return
	}
	if err := db.DeleteVirtualModel(h.DB, virtualModelID); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete virtual model"})
		return
	}
	c.JSON(http.StatusNoContent, nil)
}