466 lines
11 KiB
Go
466 lines
11 KiB
Go
package api
|
||
|
||
import (
|
||
"ai-gateway/internal/db"
|
||
"ai-gateway/internal/logger"
|
||
"ai-gateway/internal/models"
|
||
"ai-gateway/internal/router"
|
||
"bytes"
|
||
"encoding/json"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/pkoukk/tiktoken-go"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// APIHandler 持有数据库连接并处理API请求
|
||
type APIHandler struct {
|
||
DB *gorm.DB
|
||
}
|
||
|
||
// ModelListResponse 符合OpenAI /v1/models API响应格式
|
||
type ModelListResponse struct {
|
||
Object string `json:"object"`
|
||
Data []ModelData `json:"data"`
|
||
}
|
||
|
||
// ModelData describes a single model inside a ModelListResponse.
type ModelData struct {
	// ID is the model name exposed to clients.
	ID string `json:"id"`
	// Object is the entry type marker; ListModels sets it to "model".
	Object string `json:"object"`
	// Created is the creation time as Unix seconds.
	Created int64 `json:"created"`
	// OwnedBy names the owner reported to clients.
	OwnedBy string `json:"owned_by"`
}
|
||
|
||
// ChatCompletionRequest 聊天补全请求结构
|
||
type ChatCompletionRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []ChatCompletionMessage `json:"messages"`
|
||
Stream bool `json:"stream,omitempty"`
|
||
}
|
||
|
||
// ChatCompletionMessage is one message of a conversation.
type ChatCompletionMessage struct {
	// Role is the author role string sent by the client.
	Role string `json:"role"`
	// Content is the message text. It is a plain string, so a JSON body whose
	// "content" is not a string fails to decode into this type.
	Content string `json:"content"`
}
|
||
|
||
// ResponsesRequest /v1/responses 端点请求结构
|
||
type ResponsesRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []ChatCompletionMessage `json:"messages"`
|
||
Stream bool `json:"stream,omitempty"`
|
||
}
|
||
|
||
// ListModels 处理 GET /models 请求
|
||
func (h *APIHandler) ListModels(c *gin.Context) {
|
||
var virtualModels []models.VirtualModel
|
||
|
||
// 查询所有虚拟模型
|
||
if err := h.DB.Find(&virtualModels).Error; err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to retrieve models",
|
||
"type": "internal_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 格式化为OpenAI API响应格式
|
||
modelData := make([]ModelData, len(virtualModels))
|
||
for i, vm := range virtualModels {
|
||
modelData[i] = ModelData{
|
||
ID: vm.Name,
|
||
Object: "model",
|
||
Created: vm.CreatedAt.Unix(),
|
||
OwnedBy: "ai-gateway",
|
||
}
|
||
}
|
||
|
||
response := ModelListResponse{
|
||
Object: "list",
|
||
Data: modelData,
|
||
}
|
||
|
||
c.JSON(http.StatusOK, response)
|
||
}
|
||
|
||
// ChatCompletions 处理 POST /v1/chat/completions 请求
|
||
func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||
// 记录请求开始时间
|
||
requestTimestamp := time.Now()
|
||
|
||
var req ChatCompletionRequest
|
||
|
||
// 解析请求体
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": gin.H{
|
||
"message": "Invalid request format",
|
||
"type": "invalid_request_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 使用tiktoken精确计算请求token数
|
||
requestTokenCount := calculateTokenCount(req.Messages)
|
||
|
||
// 选择后端模型
|
||
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
|
||
if err != nil {
|
||
c.JSON(http.StatusNotFound, gin.H{
|
||
"error": gin.H{
|
||
"message": err.Error(),
|
||
"type": "invalid_request_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 准备转发请求
|
||
requestBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to process request",
|
||
"type": "internal_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 构建后端API URL
|
||
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
|
||
|
||
// 创建HTTP请求
|
||
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to create backend request",
|
||
"type": "internal_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 设置请求头
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
|
||
|
||
// 复制原始请求的其他相关头部
|
||
if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
|
||
httpReq.Header.Set("User-Agent", userAgent)
|
||
}
|
||
|
||
// 执行请求
|
||
client := &http.Client{
|
||
Timeout: 120 * time.Second,
|
||
}
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadGateway, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to connect to backend service",
|
||
"type": "service_unavailable",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 记录响应时间
|
||
responseTimestamp := time.Now()
|
||
|
||
// 读取响应体
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to read backend response",
|
||
"type": "internal_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 计算响应token数
|
||
responseTokenCount := 0
|
||
var responseData map[string]interface{}
|
||
if err := json.Unmarshal(responseBody, &responseData); err == nil {
|
||
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
|
||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||
if message, ok := choice["message"].(map[string]interface{}); ok {
|
||
if content, ok := message["content"].(string); ok {
|
||
responseTokenCount = calculateTokenCountFromText(content)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 计算费用
|
||
var cost float64
|
||
switch backendModel.BillingMethod {
|
||
case models.BillingMethodToken:
|
||
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
|
||
case models.BillingMethodRequest:
|
||
cost = backendModel.FixedPrice
|
||
}
|
||
|
||
// 从上下文获取API密钥
|
||
apiKeyValue, exists := c.Get("apiKey")
|
||
var apiKeyID uint
|
||
if exists {
|
||
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
||
apiKeyID = apiKey.ID
|
||
}
|
||
}
|
||
|
||
// 创建日志记录
|
||
logEntry := &models.RequestLog{
|
||
APIKeyID: apiKeyID,
|
||
VirtualModelName: req.Model,
|
||
BackendModelName: backendModel.Name,
|
||
RequestTimestamp: requestTimestamp,
|
||
ResponseTimestamp: responseTimestamp,
|
||
RequestTokens: requestTokenCount,
|
||
ResponseTokens: responseTokenCount,
|
||
Cost: cost,
|
||
RequestBody: string(requestBody),
|
||
ResponseBody: string(responseBody),
|
||
}
|
||
|
||
// 异步记录日志
|
||
logger.LogRequest(h.DB, logEntry)
|
||
|
||
// 复制响应头
|
||
for key, values := range resp.Header {
|
||
for _, value := range values {
|
||
c.Header(key, value)
|
||
}
|
||
}
|
||
|
||
// 设置响应状态码并返回响应体
|
||
c.Status(resp.StatusCode)
|
||
c.Writer.Write(responseBody)
|
||
}
|
||
|
||
// ResponsesCompletions 处理 POST /v1/responses 请求
|
||
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||
// 记录请求开始时间
|
||
requestTimestamp := time.Now()
|
||
|
||
var req ResponsesRequest
|
||
|
||
// 解析请求\u4f53
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": gin.H{
|
||
"message": "Invalid request format",
|
||
"type": "invalid_request_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 使用tiktoken精确计算请求token数
|
||
requestTokenCount := calculateTokenCount(req.Messages)
|
||
|
||
// 选择后端模型
|
||
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
|
||
if err != nil {
|
||
c.JSON(http.StatusNotFound, gin.H{
|
||
"error": gin.H{
|
||
"message": err.Error(),
|
||
"type": "invalid_request_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 准备转发请求
|
||
requestBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to process request",
|
||
"type": "internal_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 构建后端API URL
|
||
backendURL := backendModel.Provider.BaseURL + "/v1/responses"
|
||
|
||
// 创建HTTP请求
|
||
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to create backend request",
|
||
"type": "internal_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 设置请求头
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
|
||
|
||
// 复制原始请求的其他相关头部
|
||
if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
|
||
httpReq.Header.Set("User-Agent", userAgent)
|
||
}
|
||
|
||
// 执行请求
|
||
client := &http.Client{
|
||
Timeout: 120 * time.Second,
|
||
}
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadGateway, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to connect to backend service",
|
||
"type": "service_unavailable",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 记录响应时间
|
||
responseTimestamp := time.Now()
|
||
|
||
// 读取响应体
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to read backend response",
|
||
"type": "internal_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// 计算响应token数
|
||
responseTokenCount := 0
|
||
var responseData map[string]interface{}
|
||
if err := json.Unmarshal(responseBody, &responseData); err == nil {
|
||
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
|
||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||
if message, ok := choice["message"].(map[string]interface{}); ok {
|
||
if content, ok := message["content"].(string); ok {
|
||
responseTokenCount = calculateTokenCountFromText(content)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 计算费用
|
||
var cost float64
|
||
switch backendModel.BillingMethod {
|
||
case models.BillingMethodToken:
|
||
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
|
||
case models.BillingMethodRequest:
|
||
cost = backendModel.FixedPrice
|
||
}
|
||
|
||
// 从上下文获取API密钥
|
||
apiKeyValue, exists := c.Get("apiKey")
|
||
var apiKeyID uint
|
||
if exists {
|
||
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
||
apiKeyID = apiKey.ID
|
||
}
|
||
}
|
||
|
||
// 创建日志记录
|
||
logEntry := &models.RequestLog{
|
||
APIKeyID: apiKeyID,
|
||
VirtualModelName: req.Model,
|
||
BackendModelName: backendModel.Name,
|
||
RequestTimestamp: requestTimestamp,
|
||
ResponseTimestamp: responseTimestamp,
|
||
RequestTokens: requestTokenCount,
|
||
ResponseTokens: responseTokenCount,
|
||
Cost: cost,
|
||
RequestBody: string(requestBody),
|
||
ResponseBody: string(responseBody),
|
||
}
|
||
|
||
// 异步记录日志
|
||
logger.LogRequest(h.DB, logEntry)
|
||
|
||
// 复制响应体
|
||
for key, values := range resp.Header {
|
||
for _, value := range values {
|
||
c.Header(key, value)
|
||
}
|
||
}
|
||
|
||
// 设置响应状态码并返回响应体
|
||
c.Status(resp.StatusCode)
|
||
c.Writer.Write(responseBody)
|
||
}
|
||
|
||
// calculateTokenCount 计算消息列表的token总数
|
||
func calculateTokenCount(messages []ChatCompletionMessage) int {
|
||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||
if err != nil {
|
||
log.Printf("Failed to get tiktoken encoding: %v", err)
|
||
return 0
|
||
}
|
||
|
||
totalTokens := 0
|
||
for _, msg := range messages {
|
||
// 每条消息的基础开销(role + 分隔符等)
|
||
totalTokens += 4
|
||
// role的token数
|
||
totalTokens += len(encoding.Encode(msg.Role, nil, nil))
|
||
// content的token数
|
||
totalTokens += len(encoding.Encode(msg.Content, nil, nil))
|
||
}
|
||
// 对话的基础开销
|
||
totalTokens += 2
|
||
|
||
return totalTokens
|
||
}
|
||
|
||
// calculateTokenCountFromText 从文本计算token数
|
||
func calculateTokenCountFromText(text string) int {
|
||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||
if err != nil {
|
||
log.Printf("Failed to get tiktoken encoding: %v", err)
|
||
return 0
|
||
}
|
||
|
||
return len(encoding.Encode(text, nil, nil))
|
||
}
|
||
|
||
// GetProvidersHandler 处理 GET /api/providers 请求
|
||
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
|
||
providers, err := db.GetProviders(h.DB)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": gin.H{
|
||
"message": "Failed to retrieve providers",
|
||
"type": "internal_error",
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, providers)
|
||
}
|