Files
AIRouter/backend/api/handlers.go

780 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/db"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// APIHandler holds the database connection and serves the API requests
// implemented by the methods defined on it in this file.
type APIHandler struct {
// DB is the GORM handle every handler method reads and writes through.
DB *gorm.DB
}
// ModelListResponse is the JSON envelope returned by ListModels. Its shape
// follows the OpenAI /v1/models list response.
type ModelListResponse struct {
// Object is the envelope's type discriminator; ListModels always sets "list".
Object string `json:"object"`
// Data holds one entry per model.
Data []ModelData `json:"data"`
}
// ModelData describes a single model entry inside ModelListResponse.Data.
type ModelData struct {
// ID is the model identifier; ListModels fills it with the virtual model's name.
ID string `json:"id"`
// Object is the entry's type discriminator; ListModels always sets "model".
Object string `json:"object"`
// Created is a Unix timestamp in seconds (ListModels uses CreatedAt.Unix()).
Created int64 `json:"created"`
// OwnedBy names the owner; ListModels always sets "ai-gateway".
OwnedBy string `json:"owned_by"`
}
// ChatCompletionRequest captures the fields of a chat-completions request
// that the gateway itself inspects. ChatCompletions forwards the client's
// full JSON body, so fields not declared here still reach the backend.
type ChatCompletionRequest struct {
// Model is the virtual model name used to select a backend model.
Model string `json:"model"`
// Messages is the conversation; it is used to estimate the request token count.
Messages []ChatCompletionMessage `json:"messages"`
// Stream selects the streaming response path in ChatCompletions when true.
Stream bool `json:"stream,omitempty"`
}
// ChatCompletionMessage is a single chat message: who said it and what was said.
//
// NOTE(review): Content is declared as a plain string, so a request whose
// "content" is not a JSON string (for example an array of content parts)
// will fail JSON binding — confirm that rejecting such requests is intended.
type ChatCompletionMessage struct {
// Role is the speaker of the message.
Role string `json:"role"`
// Content is the message text.
Content string `json:"content"`
}
// BackendChatCompletionRequest is the request structure intended to be sent
// to the backend model. It contains only the common fields, to avoid sending
// parameters a backend does not support.
//
// NOTE(review): no handler in the visible part of this file references this
// type; ChatCompletions forwards the client's JSON body instead — confirm
// whether this type is still needed.
type BackendChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
}
// ResponsesRequest is the request body accepted by the /v1/responses endpoint.
// ResponsesCompletions binds the incoming JSON to this struct and marshals a
// value of this type again for the backend, so fields not declared here are
// not forwarded.
type ResponsesRequest struct {
// Model is the virtual model name used to select a backend model.
Model string `json:"model"`
// Messages is the conversation; it is used to estimate the request token count.
Messages []ChatCompletionMessage `json:"messages"`
// Stream is passed through to the backend as part of the marshalled request.
Stream bool `json:"stream,omitempty"`
}
// BackendModelAssociation represents one backend model attached to a virtual
// model, together with its configuration.
//
// NOTE(review): with gin's default validator, `binding:"required"` presumably
// rejects zero values, which would make a Priority of 0 impossible to submit —
// confirm that this is intended.
type BackendModelAssociation struct {
// BackendModelID identifies the associated backend model.
BackendModelID uint `json:"backend_model_id" binding:"required"`
// Priority is the association's priority value (ordering semantics are not visible here).
Priority int `json:"priority" binding:"required"`
// CostThreshold is an optional cost limit (its unit and use are not visible here).
CostThreshold float64 `json:"cost_threshold"`
}
// CreateVirtualModelRequest is the request structure for creating a virtual model.
//
// NOTE(review): CreateVirtualModelHandler in this file binds
// models.VirtualModel directly and does not reference this type — confirm
// whether it is used elsewhere.
type CreateVirtualModelRequest struct {
// Name is the virtual model's name; it must be present.
Name string `json:"name" binding:"required"`
// Description is optional free text.
Description string `json:"description"`
// BackendModels lists the backend models to associate with the virtual model.
BackendModels []BackendModelAssociation `json:"backend_models"`
}
// UpdateVirtualModelRequest is the request structure for updating a virtual model.
//
// NOTE(review): UpdateVirtualModelHandler in this file binds
// models.VirtualModel directly and does not reference this type — confirm
// whether it is used elsewhere.
type UpdateVirtualModelRequest struct {
// Name is the virtual model's name; it must be present.
Name string `json:"name" binding:"required"`
// Description is optional free text.
Description string `json:"description"`
// BackendModels lists the backend models to associate with the virtual model.
BackendModels []BackendModelAssociation `json:"backend_models"`
}
// ListModels handles GET /models.
//
// It loads every virtual model from the database and returns them in the
// OpenAI-style list envelope (ModelListResponse). A database failure is
// reported as HTTP 500 with an "internal_error" error object.
func (h *APIHandler) ListModels(c *gin.Context) {
	var rows []models.VirtualModel
	if result := h.DB.Find(&rows); result.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to retrieve models",
				"type":    "internal_error",
			},
		})
		return
	}

	// Project each virtual model onto the public model entry shape.
	entries := make([]ModelData, 0, len(rows))
	for idx := range rows {
		entries = append(entries, ModelData{
			ID:      rows[idx].Name,
			Object:  "model",
			Created: rows[idx].CreatedAt.Unix(),
			OwnedBy: "ai-gateway",
		})
	}

	c.JSON(http.StatusOK, ModelListResponse{
		Object: "list",
		Data:   entries,
	})
}
// ChatCompletions handles POST /v1/chat/completions.
//
// It binds the request to read the model name, messages and stream flag,
// asks the router for a backend model, rewrites the "model" field of the
// client's JSON body to the backend model's name and posts the result to
// <provider base URL>/v1/chat/completions. Streaming requests are handed to
// handleStreamingResponse. Otherwise the backend response is read in full, a
// RequestLog with token counts and cost is recorded, and the backend's
// status, headers and body are relayed to the client.
//
// Errors use the {"error": {"message", "type"}} envelope: 400 for an
// unreadable or unbindable body, 404 when no backend model can be selected,
// 502 when the backend cannot be reached, and 500 otherwise.
func (h *APIHandler) ChatCompletions(c *gin.Context) {
	// Taken first so the log entry covers the whole gateway round trip.
	requestTimestamp := time.Now()

	// Read the raw body up front: it is bound into req below and later
	// re-decoded as a generic map so that fields unknown to
	// ChatCompletionRequest are still forwarded to the backend.
	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		// A truncated body must not be parsed and forwarded as if complete.
		log.Printf("Failed to read request body: %v", err)
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Invalid request format",
				"type":    "invalid_request_error",
			},
		})
		return
	}
	// Put the bytes back so ShouldBindJSON can consume the body again.
	c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

	var req ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		log.Printf("Failed to bind JSON: %v", err)
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Invalid request format",
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// The request token count feeds both backend selection and billing.
	messages := convertToTikTokenMessages(req.Messages)
	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)

	// Resolve the virtual model name to a concrete backend model.
	backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error": gin.H{
				"message": err.Error(),
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// Rebuild the outgoing body from the raw JSON, replacing only "model".
	var jsonBody map[string]interface{}
	if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to process request",
				"type":    "internal_error",
			},
		})
		return
	}
	jsonBody["model"] = backendModel.Name
	requestBody, err := json.Marshal(jsonBody)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to process request",
				"type":    "internal_error",
			},
		})
		return
	}

	backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
	httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to create backend request",
				"type":    "internal_error",
			},
		})
		return
	}

	// Authenticate with the provider's key, never the caller's.
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
	// The caller's User-Agent is the only inbound header passed through.
	if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
		httpReq.Header.Set("User-Agent", userAgent)
	}

	client := &http.Client{
		Timeout: 120 * time.Second,
	}
	resp, err := client.Do(httpReq)
	if err != nil {
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"message": "Failed to connect to backend service",
				"type":    "service_unavailable",
			},
		})
		return
	}
	defer resp.Body.Close()

	// Time at which the backend's response headers arrived.
	responseTimestamp := time.Now()

	// "apiKey" is expected to have been stored in the context upstream
	// (presumably by auth middleware, not visible here). The ID stays 0 when
	// the value is missing or is not a models.APIKey.
	apiKeyValue, exists := c.Get("apiKey")
	var apiKeyID uint
	if exists {
		if apiKey, ok := apiKeyValue.(models.APIKey); ok {
			apiKeyID = apiKey.ID
		}
	}

	// Streaming responses are relayed and logged by the dedicated helper.
	if req.Stream {
		handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
			apiKeyID, req.Model, backendModel, requestTokenCount,
			string(requestBody), h.DB)
		return
	}

	// Non-streaming: buffer the whole backend response.
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to read backend response",
				"type":    "internal_error",
			},
		})
		return
	}

	log.Printf("Backend Response Status: %s", resp.Status)
	log.Printf("Backend Response Body: %s", string(responseBody))

	// Count response tokens from choices[0].message.content. The count
	// stays 0 when the body does not have that shape (e.g. an error body).
	responseTokenCount := 0
	var responseData map[string]interface{}
	if err := json.Unmarshal(responseBody, &responseData); err == nil {
		if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
			if choice, ok := choices[0].(map[string]interface{}); ok {
				if message, ok := choice["message"].(map[string]interface{}); ok {
					if content, ok := message["content"].(string); ok {
						responseTokenCount = billing.CalculateTextTokensSimple(content)
					}
				}
			}
		}
	}

	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	logEntry := &models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  req.Model,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		ResponseTimestamp: responseTimestamp,
		RequestTokens:     requestTokenCount,
		ResponseTokens:    responseTokenCount,
		Cost:              cost,
		RequestBody:       string(requestBody),
		ResponseBody:      string(responseBody),
	}
	// Hand the entry to the request logger (described by the original
	// author as asynchronous).
	logger.LogRequest(h.DB, logEntry)

	// Relay the backend headers. Each header replaces whatever was set
	// earlier, but all of its values are kept: c.Header uses Set semantics
	// and would leave only the last value of a multi-valued header.
	outHeader := c.Writer.Header()
	for key, values := range resp.Header {
		outHeader.Del(key)
		for _, value := range values {
			outHeader.Add(key, value)
		}
	}

	// Relay the backend status and body unchanged.
	c.Status(resp.StatusCode)
	if _, err := c.Writer.Write(responseBody); err != nil {
		log.Printf("Failed to write response to client: %v", err)
	}
}
// ResponsesCompletions handles POST /v1/responses.
//
// It binds the body to ResponsesRequest, asks the router for a backend
// model, and posts the request — with its model field replaced by the
// backend model's name — to <provider base URL>/v1/responses. The backend
// response is read in full, a RequestLog with token counts and cost is
// recorded, and the backend's status, headers and body are relayed.
//
// Errors use the {"error": {"message", "type"}} envelope: 400 for an
// unbindable body, 404 when no backend model can be selected, 502 when the
// backend cannot be reached, and 500 otherwise.
//
// NOTE(review): the response is always buffered, even when the request sets
// stream=true — confirm whether streaming should be supported here.
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
	// Taken first so the log entry covers the whole gateway round trip.
	requestTimestamp := time.Now()

	var req ResponsesRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": "Invalid request format",
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// The request token count feeds both backend selection and billing.
	messages := convertToTikTokenMessages(req.Messages)
	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)

	// Resolve the virtual model name to a concrete backend model.
	backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error": gin.H{
				"message": err.Error(),
				"type":    "invalid_request_error",
			},
		})
		return
	}

	// The backend only knows its own model name, so forward a copy of the
	// request that carries backendModel.Name (as ChatCompletions does).
	// req itself is left untouched because req.Model is logged below as the
	// virtual model name.
	forwardReq := req
	forwardReq.Model = backendModel.Name
	requestBody, err := json.Marshal(forwardReq)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to process request",
				"type":    "internal_error",
			},
		})
		return
	}

	backendURL := backendModel.Provider.BaseURL + "/v1/responses"
	httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to create backend request",
				"type":    "internal_error",
			},
		})
		return
	}

	// Authenticate with the provider's key, never the caller's.
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
	// The caller's User-Agent is the only inbound header passed through.
	if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
		httpReq.Header.Set("User-Agent", userAgent)
	}

	client := &http.Client{
		Timeout: 120 * time.Second,
	}
	resp, err := client.Do(httpReq)
	if err != nil {
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"message": "Failed to connect to backend service",
				"type":    "service_unavailable",
			},
		})
		return
	}
	defer resp.Body.Close()

	// Time at which the backend's response headers arrived.
	responseTimestamp := time.Now()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to read backend response",
				"type":    "internal_error",
			},
		})
		return
	}

	// Count response tokens from choices[0].message.content. The count
	// stays 0 when the body does not have that shape.
	responseTokenCount := 0
	var responseData map[string]interface{}
	if err := json.Unmarshal(responseBody, &responseData); err == nil {
		if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
			if choice, ok := choices[0].(map[string]interface{}); ok {
				if message, ok := choice["message"].(map[string]interface{}); ok {
					if content, ok := message["content"].(string); ok {
						responseTokenCount = billing.CalculateTextTokensSimple(content)
					}
				}
			}
		}
	}

	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	// "apiKey" is expected to have been stored in the context upstream
	// (presumably by auth middleware, not visible here). The ID stays 0 when
	// the value is missing or is not a models.APIKey.
	apiKeyValue, exists := c.Get("apiKey")
	var apiKeyID uint
	if exists {
		if apiKey, ok := apiKeyValue.(models.APIKey); ok {
			apiKeyID = apiKey.ID
		}
	}

	logEntry := &models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  req.Model,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		ResponseTimestamp: responseTimestamp,
		RequestTokens:     requestTokenCount,
		ResponseTokens:    responseTokenCount,
		Cost:              cost,
		RequestBody:       string(requestBody),
		ResponseBody:      string(responseBody),
	}
	// Hand the entry to the request logger (described by the original
	// author as asynchronous).
	logger.LogRequest(h.DB, logEntry)

	// Relay the backend headers. Each header replaces whatever was set
	// earlier, but all of its values are kept: c.Header uses Set semantics
	// and would leave only the last value of a multi-valued header.
	outHeader := c.Writer.Header()
	for key, values := range resp.Header {
		outHeader.Del(key)
		for _, value := range values {
			outHeader.Add(key, value)
		}
	}

	// Relay the backend status and body unchanged.
	c.Status(resp.StatusCode)
	if _, err := c.Writer.Write(responseBody); err != nil {
		log.Printf("Failed to write response to client: %v", err)
	}
}
// convertToTikTokenMessages copies the handler-level chat messages into the
// billing package's message type, preserving order, role and content. The
// result is always a non-nil slice with one element per input message.
func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatCompletionMessage {
	converted := make([]billing.ChatCompletionMessage, 0, len(messages))
	for _, m := range messages {
		converted = append(converted, billing.ChatCompletionMessage{
			Role:    m.Role,
			Content: m.Content,
		})
	}
	return converted
}
// handleStreamingResponse relays a backend SSE stream to the client line by
// line and, once the stream ends, records a RequestLog with token counts and
// cost.
//
// Parameters:
//   - c: gin context whose writer receives the relayed stream.
//   - resp: backend response; its body is read here but closed by the caller.
//   - requestTimestamp: when the gateway started handling the request.
//   - responseTimestamp: accepted for signature compatibility but not used;
//     the log entry records the time the stream finished instead.
//   - apiKeyID, virtualModelName, backendModel, requestTokenCount,
//     requestBody: copied into the log entry and used for cost calculation.
//   - database: handle passed to logger.LogRequest.
//
// The log entry is written even when the client disconnects mid-stream, so
// whatever the backend had already produced is still accounted for.
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
	apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
	requestBody string, database *gorm.DB) {
	// Upper bound for a single SSE line. bufio.Scanner's default limit is
	// 64 KiB, which one large chunk can exceed and thereby abort the stream.
	const maxSSELineBytes = 4 * 1024 * 1024

	// Default streaming headers; backend headers copied below may override them.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	// Relay the backend headers except the framing ones. Each header replaces
	// whatever was set earlier, but all of its values are kept: c.Header uses
	// Set semantics and would leave only the last value.
	outHeader := c.Writer.Header()
	for key, values := range resp.Header {
		if key == "Content-Length" || key == "Transfer-Encoding" {
			continue
		}
		outHeader.Del(key)
		for _, value := range values {
			outHeader.Add(key, value)
		}
	}
	c.Status(resp.StatusCode)

	// fullContent accumulates the generated text for token counting;
	// responseBody keeps the raw stream for the log entry.
	var fullContent strings.Builder
	var responseBody strings.Builder

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxSSELineBytes)
	flusher, canFlush := c.Writer.(http.Flusher)

	for scanner.Scan() {
		line := scanner.Text()
		responseBody.WriteString(line)
		responseBody.WriteString("\n")

		// SSE data lines look like "data: {...}"; the space after the colon
		// is optional, so accept both forms.
		if strings.HasPrefix(line, "data:") {
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))

			// End-of-stream marker: forward it with its terminating blank
			// line and stop reading.
			if data == "[DONE]" {
				if _, err := c.Writer.Write([]byte(line + "\n\n")); err != nil {
					log.Printf("Error writing stream to client: %v", err)
				}
				break
			}

			fullContent.WriteString(streamDeltaContent(data))
		}

		// Forward the line as received. On a client write failure stop
		// relaying, but fall through to the accounting below.
		if _, err := c.Writer.Write([]byte(line + "\n")); err != nil {
			log.Printf("Error writing stream to client: %v", err)
			break
		}

		// A blank line terminates an SSE event; push it out right away.
		if line == "" && canFlush {
			flusher.Flush()
		}
	}

	// Make sure any remaining buffered data is sent.
	if canFlush {
		flusher.Flush()
	}

	if err := scanner.Err(); err != nil {
		log.Printf("Error reading stream: %v", err)
	}

	responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
	costCalculator := billing.NewCostCalculator()
	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)

	logEntry := &models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  virtualModelName,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		ResponseTimestamp: time.Now(), // actual end of the stream
		RequestTokens:     requestTokenCount,
		ResponseTokens:    responseTokenCount,
		Cost:              cost,
		RequestBody:       requestBody,
		ResponseBody:      responseBody.String(),
	}
	// Hand the entry to the request logger (described by the original
	// author as asynchronous).
	logger.LogRequest(database, logEntry)
}

// streamDeltaContent returns choices[0].delta.content from one SSE JSON
// payload, or "" when the payload is not JSON or lacks that field.
func streamDeltaContent(data string) string {
	var chunk map[string]interface{}
	if err := json.Unmarshal([]byte(data), &chunk); err != nil {
		return ""
	}
	choices, ok := chunk["choices"].([]interface{})
	if !ok || len(choices) == 0 {
		return ""
	}
	choice, ok := choices[0].(map[string]interface{})
	if !ok {
		return ""
	}
	delta, ok := choice["delta"].(map[string]interface{})
	if !ok {
		return ""
	}
	content, _ := delta["content"].(string)
	return content
}
// GetProvidersHandler handles GET /api/providers.
//
// It returns every provider from db.GetProviders as JSON with HTTP 200, or
// an "internal_error" error object with HTTP 500 when the lookup fails.
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
	list, lookupErr := db.GetProviders(h.DB)
	if lookupErr == nil {
		c.JSON(http.StatusOK, list)
		return
	}

	c.JSON(http.StatusInternalServerError, gin.H{
		"error": gin.H{
			"message": "Failed to retrieve providers",
			"type":    "internal_error",
		},
	})
}
// GetProviderHandler handles GET /api/providers/:id.
//
// It responds with 400 when the id path parameter is not a plain decimal
// number, 404 when db.GetProviderByID returns an error, and 200 with the
// provider otherwise.
func (h *APIHandler) GetProviderHandler(c *gin.Context) {
	id := c.Param("id")
	var providerID uint
	// Sscanf stops at the first non-digit and still reports success
	// ("12abc" parses as 12), so also require a digits-only parameter.
	if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
		return
	}

	provider, err := db.GetProviderByID(h.DB, providerID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
		return
	}

	c.JSON(http.StatusOK, provider)
}
// CreateProviderHandler handles POST /api/providers.
//
// The JSON body is bound to a models.Provider and stored through
// db.CreateProvider. It responds with 400 (carrying the binding error text)
// for an invalid body, 500 when the insert fails, and 201 with the provider
// on success.
func (h *APIHandler) CreateProviderHandler(c *gin.Context) {
	payload := models.Provider{}

	bindErr := c.ShouldBindJSON(&payload)
	if bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	saveErr := db.CreateProvider(h.DB, &payload)
	if saveErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"})
		return
	}

	c.JSON(http.StatusCreated, payload)
}
// UpdateProviderHandler handles PUT /api/providers/:id.
//
// The JSON body is bound to a models.Provider, its ID is forced to the path
// parameter, and the result is stored through db.UpdateProvider. It responds
// with 400 for a non-numeric id or an invalid body, 500 when the update
// fails, and 200 with the provider on success.
func (h *APIHandler) UpdateProviderHandler(c *gin.Context) {
	id := c.Param("id")
	var providerID uint
	// Sscanf stops at the first non-digit and still reports success
	// ("12abc" parses as 12), so also require a digits-only parameter.
	if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
		return
	}

	var provider models.Provider
	if err := c.ShouldBindJSON(&provider); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// The path parameter wins over any ID supplied in the body.
	provider.ID = providerID
	if err := db.UpdateProvider(h.DB, &provider); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
		return
	}

	c.JSON(http.StatusOK, provider)
}
// DeleteProviderHandler handles DELETE /api/providers/:id.
//
// It responds with 400 when the id path parameter is not a plain decimal
// number, 500 when db.DeleteProvider fails, and 204 on success.
func (h *APIHandler) DeleteProviderHandler(c *gin.Context) {
	id := c.Param("id")
	var providerID uint
	// Sscanf stops at the first non-digit and still reports success
	// ("12abc" parses as 12). On a destructive route that would delete the
	// wrong record, so also require a digits-only parameter.
	if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
		return
	}

	if err := db.DeleteProvider(h.DB, providerID); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"})
		return
	}

	c.JSON(http.StatusNoContent, nil)
}
// CreateVirtualModelHandler handles POST /api/virtual-models.
//
// The JSON body is bound to a models.VirtualModel and stored through
// db.CreateVirtualModel. It responds with 400 (carrying the binding error
// text) for an invalid body, 500 when the insert fails, and 201 with the
// virtual model on success.
func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) {
	payload := models.VirtualModel{}

	bindErr := c.ShouldBindJSON(&payload)
	if bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	saveErr := db.CreateVirtualModel(h.DB, &payload)
	if saveErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create virtual model"})
		return
	}

	c.JSON(http.StatusCreated, payload)
}
// GetVirtualModelsHandler handles GET /api/virtual-models.
//
// It returns every virtual model from db.GetVirtualModels as JSON with
// HTTP 200, or an error message with HTTP 500 when the lookup fails.
func (h *APIHandler) GetVirtualModelsHandler(c *gin.Context) {
	list, lookupErr := db.GetVirtualModels(h.DB)
	if lookupErr == nil {
		c.JSON(http.StatusOK, list)
		return
	}

	c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve virtual models"})
}
// GetVirtualModelHandler handles GET /api/virtual-models/:id.
//
// It responds with 400 when the id path parameter is not a plain decimal
// number, 404 when db.GetVirtualModelByID returns an error, and 200 with the
// virtual model otherwise.
func (h *APIHandler) GetVirtualModelHandler(c *gin.Context) {
	id := c.Param("id")
	var virtualModelID uint
	// Sscanf stops at the first non-digit and still reports success
	// ("12abc" parses as 12), so also require a digits-only parameter.
	if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
		return
	}

	virtualModel, err := db.GetVirtualModelByID(h.DB, virtualModelID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Virtual model not found"})
		return
	}

	c.JSON(http.StatusOK, virtualModel)
}
// UpdateVirtualModelHandler handles PUT /api/virtual-models/:id.
//
// The JSON body is bound to a models.VirtualModel, its ID is forced to the
// path parameter, and the result is stored through db.UpdateVirtualModel. It
// responds with 400 for a non-numeric id or an invalid body, 500 when the
// update fails, and 200 with the virtual model on success.
func (h *APIHandler) UpdateVirtualModelHandler(c *gin.Context) {
	id := c.Param("id")
	var virtualModelID uint
	// Sscanf stops at the first non-digit and still reports success
	// ("12abc" parses as 12), so also require a digits-only parameter.
	if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
		return
	}

	var vm models.VirtualModel
	if err := c.ShouldBindJSON(&vm); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// The path parameter wins over any ID supplied in the body.
	vm.ID = virtualModelID
	if err := db.UpdateVirtualModel(h.DB, &vm); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update virtual model"})
		return
	}

	c.JSON(http.StatusOK, vm)
}
// DeleteVirtualModelHandler handles DELETE /api/virtual-models/:id.
//
// It responds with 400 when the id path parameter is not a plain decimal
// number, 500 when db.DeleteVirtualModel fails, and 204 on success.
func (h *APIHandler) DeleteVirtualModelHandler(c *gin.Context) {
	id := c.Param("id")
	var virtualModelID uint
	// Sscanf stops at the first non-digit and still reports success
	// ("12abc" parses as 12). On a destructive route that would delete the
	// wrong record, so also require a digits-only parameter.
	if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil || strings.Trim(id, "0123456789") != "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
		return
	}

	if err := db.DeleteVirtualModel(h.DB, virtualModelID); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete virtual model"})
		return
	}

	c.JSON(http.StatusNoContent, nil)
}