// Files
// AIRouter/backend/api/handlers.go
//
// 175 lines
// 5.0 KiB
// Go

package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/models"
"bytes"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// APIHandler holds the database connection and serves the API requests.
type APIHandler struct {
	// DB is the GORM database handle.
	DB *gorm.DB
}
// ModelListResponse is the response body in the format of the OpenAI
// /v1/models API.
type ModelListResponse struct {
	// Object is serialized under the "object" key.
	Object string `json:"object"`
	// Data is the list of model entries, serialized under the "data" key.
	Data []ModelData `json:"data"`
}
// ModelData describes a single model entry.
type ModelData struct {
	// ID is serialized under the "id" key.
	ID string `json:"id"`
	// Object is serialized under the "object" key.
	Object string `json:"object"`
	// Created is serialized under the "created" key.
	// NOTE(review): presumably a Unix timestamp in seconds; confirm with the
	// code that fills it in.
	Created int64 `json:"created"`
	// OwnedBy is serialized under the "owned_by" key.
	OwnedBy string `json:"owned_by"`
}
// ChatCompletionRequest is the body of a chat completion request.
type ChatCompletionRequest struct {
	// Model is serialized under the "model" key.
	Model string `json:"model"`
	// Messages is the list of chat messages, serialized under "messages".
	Messages []ChatCompletionMessage `json:"messages"`
	// Stream is serialized under "stream" and left out of the JSON output
	// when false.
	Stream bool `json:"stream,omitempty"`
}
// ChatCompletionMessage is a single chat message.
type ChatCompletionMessage struct {
	// Role is serialized under the "role" key.
	Role string `json:"role"`
	// Content is the message text, serialized under the "content" key.
	Content string `json:"content"`
}
// BackendChatCompletionRequest is the request body that is actually sent to
// the backend model. It carries only the common fields, so that parameters a
// backend does not support are never sent.
type BackendChatCompletionRequest struct {
	// Model is serialized under the "model" key.
	Model string `json:"model"`
	// Messages is the list of chat messages, serialized under "messages".
	Messages []ChatCompletionMessage `json:"messages"`
	// Stream is serialized under "stream" and left out of the JSON output
	// when false.
	Stream bool `json:"stream,omitempty"`
}
// ResponsesRequest is the request body of the /v1/responses endpoint.
type ResponsesRequest struct {
	// Model is serialized under the "model" key.
	Model string `json:"model"`
	// Input is the list of messages, serialized under the "input" key.
	Input []ChatCompletionMessage `json:"input"`
	// Stream is serialized under "stream" and left out of the JSON output
	// when false.
	Stream bool `json:"stream,omitempty"`
}
// BackendModelAssociation represents one backend model association together
// with its configuration.
type BackendModelAssociation struct {
	// BackendModelID is bound from "backend_model_id" and marked required.
	BackendModelID uint `json:"backend_model_id" binding:"required"`
	// Priority is bound from "priority" and marked required.
	// NOTE(review): the validator's "required" rule rejects the zero value,
	// so a priority of 0 presumably fails binding; confirm this is intended.
	Priority int `json:"priority" binding:"required"`
	// CostThreshold is bound from "cost_threshold"; it is optional.
	CostThreshold float64 `json:"cost_threshold"`
}
// CreateVirtualModelRequest is the request body for creating a virtual model.
type CreateVirtualModelRequest struct {
	// Name is bound from "name" and marked required.
	Name string `json:"name" binding:"required"`
	// Description is bound from "description"; it is optional.
	Description string `json:"description"`
	// BackendModels lists the backend model associations, bound from
	// "backend_models".
	BackendModels []BackendModelAssociation `json:"backend_models"`
}
// UpdateVirtualModelRequest is the request body for updating a virtual model.
type UpdateVirtualModelRequest struct {
	// Name is bound from "name" and marked required.
	Name string `json:"name" binding:"required"`
	// Description is bound from "description"; it is optional.
	Description string `json:"description"`
	// BackendModels lists the backend model associations, bound from
	// "backend_models".
	BackendModels []BackendModelAssociation `json:"backend_models"`
}
// forwardRequest POSTs requestBody to backendURL and hands back the backend's
// response untouched.
//
// The request is sent with Content-Type "application/json" and apiKey as a
// bearer token. A User-Agent header is set explicitly only when userAgent is
// non-empty. The HTTP client used for the call has a 10 minute timeout.
//
// It returns a wrapped error when the request cannot be built or the call
// fails; in both cases the response is nil. On success the response body has
// not been read or closed here, so that is left to the caller.
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
	req, err := http.NewRequest(http.MethodPost, backendURL, bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create backend request: %w", err)
	}

	headers := req.Header
	headers.Set("Content-Type", "application/json")
	headers.Set("Authorization", "Bearer "+apiKey)
	if userAgent != "" {
		headers.Set("User-Agent", userAgent)
	}

	backendClient := http.Client{Timeout: 10 * time.Minute}
	backendResp, err := backendClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to backend service: %w", err)
	}
	return backendResp, nil
}
// extractResponseTokenCount derives a token count from a chat completion
// response body.
//
// It looks up choices[0].message.content in the JSON document and returns
// billing.CalculateTextTokensSimple of that text. It returns 0 when the body
// is not valid JSON, when "choices" is missing or empty, or when any step of
// that path does not have the expected JSON type. Only the first choice is
// considered.
func extractResponseTokenCount(responseBody []byte) int {
	var payload map[string]interface{}
	if json.Unmarshal(responseBody, &payload) != nil {
		return 0
	}

	choiceList, ok := payload["choices"].([]interface{})
	if !ok || len(choiceList) == 0 {
		return 0
	}

	firstChoice, ok := choiceList[0].(map[string]interface{})
	if !ok {
		return 0
	}

	msg, ok := firstChoice["message"].(map[string]interface{})
	if !ok {
		return 0
	}

	text, ok := msg["content"].(string)
	if !ok {
		return 0
	}

	return billing.CalculateTextTokensSimple(text)
}
// createRequestLog assembles a request log record from the given values and
// returns a pointer to it. Nothing is persisted here.
//
// The backend model name is read from backendModel.Name, so backendModel
// must not be nil.
func createRequestLog(apiKeyID uint, virtualModelName string, backendModel *models.BackendModel,
	requestTimestamp, responseTimestamp time.Time, requestTokenCount, responseTokenCount int,
	cost float64, requestBody, responseBody string) *models.RequestLog {
	entry := new(models.RequestLog)

	// Who made the call and which models were involved.
	entry.APIKeyID = apiKeyID
	entry.VirtualModelName = virtualModelName
	entry.BackendModelName = backendModel.Name

	// Timing of the exchange.
	entry.RequestTimestamp = requestTimestamp
	entry.ResponseTimestamp = responseTimestamp

	// Usage and cost.
	entry.RequestTokens = requestTokenCount
	entry.ResponseTokens = responseTokenCount
	entry.Cost = cost

	// Payloads.
	entry.RequestBody = requestBody
	entry.ResponseBody = responseBody

	return entry
}
// getAPIKeyID reads the API key stored in the gin context under "apiKey" and
// returns its ID.
//
// It returns 0 when the key is absent or when the stored value is not a
// models.APIKey value. A *models.APIKey pointer does not match the type
// assertion and therefore also yields 0.
func getAPIKeyID(c *gin.Context) uint {
	if stored, found := c.Get("apiKey"); found {
		if key, isKey := stored.(models.APIKey); isKey {
			return key.ID
		}
	}
	return 0
}
// copyResponseHeaders copies every header of the backend response resp onto
// the response writer of the gin context c.
//
// For each header name, whatever the writer already holds is replaced by the
// backend's complete list of values. The values are written with Header.Add
// instead of c.Header: c.Header has Set semantics (and deletes the header on
// an empty value), so calling it once per value kept only the last value of a
// multi-valued header such as Set-Cookie.
func copyResponseHeaders(c *gin.Context, resp *http.Response) {
	dst := c.Writer.Header()
	for key, values := range resp.Header {
		// Clear first so the result is exactly the backend's value list
		// rather than a merge with headers set earlier in the handler chain.
		dst.Del(key)
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}