Files
AIRouter/backend/api/handlers.go
2025-11-08 17:47:47 +08:00

466 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"ai-gateway/internal/db"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"gorm.io/gorm"
)
// APIHandler holds the database handle and serves the API requests.
// Its handler methods read from and write to DB.
type APIHandler struct {
	DB *gorm.DB
}
// ModelListResponse is the response envelope modeled on the OpenAI
// /v1/models API: an object tag plus the list of model entries.
type ModelListResponse struct {
	Object string      `json:"object"`
	Data   []ModelData `json:"data"`
}
// ModelData describes a single model entry inside a ModelListResponse:
// its identifier, object tag, creation time as an int64 and owner label.
type ModelData struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}
// ChatCompletionRequest is the JSON body accepted by the chat completion
// endpoint. Only the model name, the message list and the optional stream
// flag are declared here, so other fields a client sends are not captured
// when binding into this struct.
type ChatCompletionRequest struct {
	Model    string                  `json:"model"`
	Messages []ChatCompletionMessage `json:"messages"`
	Stream   bool                    `json:"stream,omitempty"`
}
// ChatCompletionMessage is one chat message: the role of its author and
// the message text, both plain strings.
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// ResponsesRequest is the JSON body accepted by the /v1/responses endpoint.
// It declares the same three fields as ChatCompletionRequest: model name,
// message list and optional stream flag.
type ResponsesRequest struct {
	Model    string                  `json:"model"`
	Messages []ChatCompletionMessage `json:"messages"`
	Stream   bool                    `json:"stream,omitempty"`
}
// ListModels handles GET /models. It loads every virtual model from the
// database and returns them in the OpenAI-style list envelope; a query
// failure is answered with HTTP 500 and an "internal_error" payload.
func (h *APIHandler) ListModels(c *gin.Context) {
	var rows []models.VirtualModel
	if result := h.DB.Find(&rows); result.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to retrieve models",
				"type":    "internal_error",
			},
		})
		return
	}

	// Start from a non-nil, pre-sized slice so an empty table still
	// serializes as [] rather than null.
	entries := make([]ModelData, 0, len(rows))
	for _, row := range rows {
		entries = append(entries, ModelData{
			ID:      row.Name,
			Object:  "model",
			Created: row.CreatedAt.Unix(),
			OwnedBy: "ai-gateway",
		})
	}

	c.JSON(http.StatusOK, ModelListResponse{
		Object: "list",
		Data:   entries,
	})
}
// ChatCompletions handles POST /v1/chat/completions. It binds the JSON body
// into a ChatCompletionRequest and relays it to the selected backend's
// /v1/chat/completions endpoint via proxyCompletion. A body that cannot be
// bound is answered with HTTP 400.
func (h *APIHandler) ChatCompletions(c *gin.Context) {
	// Stamp the start before parsing so the logged request time covers binding too.
	requestTimestamp := time.Now()

	var req ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		writeProxyError(c, http.StatusBadRequest, "Invalid request format", "invalid_request_error")
		return
	}

	h.proxyCompletion(c, requestTimestamp, req.Model, req.Messages, req, "/v1/chat/completions")
}

// ResponsesCompletions handles POST /v1/responses. It binds the JSON body
// into a ResponsesRequest and relays it to the selected backend's
// /v1/responses endpoint via proxyCompletion. A body that cannot be bound is
// answered with HTTP 400.
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
	// Stamp the start before parsing so the logged request time covers binding too.
	requestTimestamp := time.Now()

	var req ResponsesRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		writeProxyError(c, http.StatusBadRequest, "Invalid request format", "invalid_request_error")
		return
	}

	h.proxyCompletion(c, requestTimestamp, req.Model, req.Messages, req, "/v1/responses")
}

// proxyCompletion is the shared body of the completion-style handlers.
//
// It counts the prompt tokens of messages, asks the router for a backend
// model for the virtual model name, re-marshals payload and POSTs it to the
// provider's BaseURL + endpointPath, records a RequestLog entry, and finally
// relays the backend's status code, headers and body to the client.
//
// Because payload is the bound request struct, only the fields declared on
// that struct reach the backend.
//
// Error responses: 404 when no backend model can be selected, 500 when the
// request cannot be marshaled or built or the backend body cannot be read,
// and 502 when the backend call itself fails.
func (h *APIHandler) proxyCompletion(
	c *gin.Context,
	requestTimestamp time.Time,
	model string,
	messages []ChatCompletionMessage,
	payload interface{},
	endpointPath string,
) {
	requestTokenCount := calculateTokenCount(messages)

	backendModel, err := router.SelectBackendModel(h.DB, model, requestTokenCount)
	if err != nil {
		writeProxyError(c, http.StatusNotFound, err.Error(), "invalid_request_error")
		return
	}

	requestBody, err := json.Marshal(payload)
	if err != nil {
		writeProxyError(c, http.StatusInternalServerError, "Failed to process request", "internal_error")
		return
	}

	backendURL := backendModel.Provider.BaseURL + endpointPath

	// Tie the backend call to the inbound request's context so that a client
	// disconnect cancels the upstream request instead of letting it run for
	// up to the full client timeout.
	httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, backendURL, bytes.NewReader(requestBody))
	if err != nil {
		writeProxyError(c, http.StatusInternalServerError, "Failed to create backend request", "internal_error")
		return
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
	// Pass the caller's User-Agent through when one was sent.
	if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
		httpReq.Header.Set("User-Agent", userAgent)
	}

	client := &http.Client{
		Timeout: 120 * time.Second,
	}
	resp, err := client.Do(httpReq)
	if err != nil {
		writeProxyError(c, http.StatusBadGateway, "Failed to connect to backend service", "service_unavailable")
		return
	}
	defer resp.Body.Close()

	// Taken once the response headers have arrived, before the body is read.
	responseTimestamp := time.Now()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		writeProxyError(c, http.StatusInternalServerError, "Failed to read backend response", "internal_error")
		return
	}

	// NOTE(review): the completion text is looked up under
	// choices[0].message.content for both endpoints; confirm that the
	// /v1/responses backend really answers in that shape, otherwise its
	// response tokens are always recorded as 0.
	responseTokenCount := 0
	if content, ok := firstChoiceContent(responseBody); ok {
		responseTokenCount = calculateTokenCountFromText(content)
	}

	// Any billing method other than the two below leaves the cost at 0.
	// NOTE(review): the cost is computed regardless of the backend status
	// code — confirm that failed backend calls are meant to be billed.
	var cost float64
	switch backendModel.BillingMethod {
	case models.BillingMethodToken:
		cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
	case models.BillingMethodRequest:
		cost = backendModel.FixedPrice
	}

	// The "apiKey" context value is presumably stored by an authentication
	// middleware (TODO confirm); when it is absent or has another type the
	// log entry keeps key ID 0.
	var apiKeyID uint
	if apiKeyValue, exists := c.Get("apiKey"); exists {
		if apiKey, ok := apiKeyValue.(models.APIKey); ok {
			apiKeyID = apiKey.ID
		}
	}

	logger.LogRequest(h.DB, &models.RequestLog{
		APIKeyID:          apiKeyID,
		VirtualModelName:  model,
		BackendModelName:  backendModel.Name,
		RequestTimestamp:  requestTimestamp,
		ResponseTimestamp: responseTimestamp,
		RequestTokens:     requestTokenCount,
		ResponseTokens:    responseTokenCount,
		Cost:              cost,
		RequestBody:       string(requestBody),
		ResponseBody:      string(responseBody),
	})

	// Relay the backend headers. Each header replaces whatever was set
	// earlier under the same name, but all of its values are kept (a plain
	// Set per value would keep only the last one, e.g. for Set-Cookie).
	// Hop-by-hop headers describe the backend connection only and are not
	// forwarded.
	header := c.Writer.Header()
	for key, values := range resp.Header {
		if isHopByHopHeader(key) {
			continue
		}
		header.Del(key)
		for _, value := range values {
			header.Add(key, value)
		}
	}

	c.Status(resp.StatusCode)
	if _, err := c.Writer.Write(responseBody); err != nil {
		// The status line is already committed, so the failure can only be logged.
		log.Printf("Failed to write backend response to client: %v", err)
	}
}

// writeProxyError sends the error envelope used by the completion handlers:
// {"error": {"message": ..., "type": ...}} with the given HTTP status.
func writeProxyError(c *gin.Context, status int, message, errType string) {
	c.JSON(status, gin.H{
		"error": gin.H{
			"message": message,
			"type":    errType,
		},
	})
}

// firstChoiceContent extracts choices[0].message.content from a JSON response
// body. The boolean is false when the body is not a JSON object or when any
// step of that path is missing or has an unexpected type.
func firstChoiceContent(body []byte) (string, bool) {
	var payload map[string]interface{}
	if err := json.Unmarshal(body, &payload); err != nil {
		return "", false
	}
	choices, ok := payload["choices"].([]interface{})
	if !ok || len(choices) == 0 {
		return "", false
	}
	choice, ok := choices[0].(map[string]interface{})
	if !ok {
		return "", false
	}
	message, ok := choice["message"].(map[string]interface{})
	if !ok {
		return "", false
	}
	content, ok := message["content"].(string)
	return content, ok
}

// isHopByHopHeader reports whether a canonically-cased header name is one of
// the connection-specific headers that a proxy must not forward.
func isHopByHopHeader(name string) bool {
	switch name {
	case "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization",
		"Proxy-Connection", "Te", "Trailer", "Transfer-Encoding", "Upgrade":
		return true
	}
	return false
}
// calculateTokenCount returns the total token count of a message list using
// the cl100k_base encoding. Every message contributes a fixed overhead of 4
// tokens plus the encoded lengths of its role and content, and the
// conversation as a whole adds another 2. If the encoding cannot be loaded
// the failure is logged and 0 is returned.
func calculateTokenCount(messages []ChatCompletionMessage) int {
	const (
		perMessageOverhead   = 4 // fixed per-message cost (role, separators, ...)
		conversationOverhead = 2 // fixed cost added once per conversation
	)

	enc, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		log.Printf("Failed to get tiktoken encoding: %v", err)
		return 0
	}

	count := conversationOverhead
	for i := range messages {
		roleTokens := enc.Encode(messages[i].Role, nil, nil)
		contentTokens := enc.Encode(messages[i].Content, nil, nil)
		count += perMessageOverhead + len(roleTokens) + len(contentTokens)
	}
	return count
}
// calculateTokenCountFromText returns the number of cl100k_base tokens in
// text. If the encoding cannot be loaded the failure is logged and 0 is
// returned.
func calculateTokenCountFromText(text string) int {
	enc, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		log.Printf("Failed to get tiktoken encoding: %v", err)
		return 0
	}

	tokens := enc.Encode(text, nil, nil)
	return len(tokens)
}
// GetProvidersHandler handles GET /api/providers. It returns whatever
// db.GetProviders yields as JSON with HTTP 200, or HTTP 500 with an
// "internal_error" payload when the lookup fails.
//
// NOTE(review): providers are serialized as-is; if the provider model
// exposes its API key in JSON, this endpoint would return it — confirm the
// model's JSON tags.
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
	providers, err := db.GetProviders(h.DB)
	if err != nil {
		failure := gin.H{
			"message": "Failed to retrieve providers",
			"type":    "internal_error",
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": failure})
		return
	}

	c.JSON(http.StatusOK, providers)
}