175 lines
5.0 KiB
Go
175 lines
5.0 KiB
Go
package api
|
|
|
|
import (
|
|
"ai-gateway/internal/billing"
|
|
"ai-gateway/internal/models"
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// APIHandler 持有数据库连接并处理API请求
|
|
type APIHandler struct {
|
|
DB *gorm.DB
|
|
}
|
|
|
|
// Response payloads for the OpenAI-compatible /v1/models listing.
type (
	// ModelListResponse is the envelope of a /v1/models response in the OpenAI
	// format: an "object" discriminator followed by the list of models.
	//
	// Note: a nil Data slice is encoded as JSON null rather than []; callers
	// that need an empty array should initialise it explicitly.
	ModelListResponse struct {
		Object string      `json:"object"`
		Data   []ModelData `json:"data"`
	}

	// ModelData describes one model entry inside a ModelListResponse.
	// NOTE(review): Created is presumably a Unix timestamp in seconds (as in
	// the OpenAI format) — confirm with the code that fills it.
	ModelData struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}
)
|
|
|
|
// Request payloads received from clients and forwarded to backend models.
type (
	// ChatCompletionRequest is the body a client sends for a chat completion.
	//
	// NOTE(review): Content in ChatCompletionMessage is a plain string, so
	// OpenAI-style multi-part content (a JSON array) would fail to decode —
	// confirm that only text content is meant to be supported.
	ChatCompletionRequest struct {
		Model    string                  `json:"model"`
		Messages []ChatCompletionMessage `json:"messages"`
		Stream   bool                    `json:"stream,omitempty"`
	}

	// ChatCompletionMessage is a single conversation turn: a role and its text.
	ChatCompletionMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// BackendChatCompletionRequest is the body actually sent to a backend
	// model. It deliberately carries only the common fields so that backends
	// are never sent parameters they may not support.
	BackendChatCompletionRequest struct {
		Model    string                  `json:"model"`
		Messages []ChatCompletionMessage `json:"messages"`
		Stream   bool                    `json:"stream,omitempty"`
	}

	// ResponsesRequest is the body accepted by the /v1/responses endpoint.
	// Its Input reuses the chat message shape.
	ResponsesRequest struct {
		Model  string                  `json:"model"`
		Input  []ChatCompletionMessage `json:"input"`
		Stream bool                    `json:"stream,omitempty"`
	}
)
|
|
|
|
// Admin payloads for creating and updating virtual models and their
// backend-model routing.
type (
	// BackendModelAssociation links a virtual model to one backend model,
	// together with its priority and cost threshold.
	//
	// NOTE(review): gin's validator likely treats `required` on a numeric field
	// as "non-zero", so a Priority of 0 would be rejected whenever this struct
	// is validated — confirm 0 is never a legitimate priority.
	BackendModelAssociation struct {
		BackendModelID uint    `json:"backend_model_id" binding:"required"`
		Priority       int     `json:"priority" binding:"required"`
		CostThreshold  float64 `json:"cost_threshold"`
	}

	// CreateVirtualModelRequest is the body used to create a virtual model.
	//
	// NOTE(review): BackendModels carries no `dive` binding tag, so the
	// validator may not check the `required` tags on each element — verify.
	CreateVirtualModelRequest struct {
		Name          string                    `json:"name" binding:"required"`
		Description   string                    `json:"description"`
		BackendModels []BackendModelAssociation `json:"backend_models"`
	}

	// UpdateVirtualModelRequest is the body used to update a virtual model.
	// It has the same fields and binding rules as CreateVirtualModelRequest.
	UpdateVirtualModelRequest struct {
		Name          string                    `json:"name" binding:"required"`
		Description   string                    `json:"description"`
		BackendModels []BackendModelAssociation `json:"backend_models"`
	}
)
|
|
|
|
// backendHTTPClient is shared by every forwardRequest call, so the gateway
// reuses one configured client instead of building a new one per request.
// An *http.Client is safe for concurrent use by multiple goroutines.
//
// The timeout bounds the whole exchange, including reading the response
// body. Responses (e.g. streamed ones) that take longer than 10 minutes are
// therefore cut off.
var backendHTTPClient = &http.Client{
	Timeout: 10 * time.Minute,
}

// forwardRequest POSTs requestBody as JSON to backendURL. It authenticates
// with apiKey as a Bearer token.
//
// userAgent, when non-empty, is sent as the User-Agent header; otherwise the
// Go default is used. A non-2xx status is not treated as an error here, so
// callers must inspect resp.StatusCode. On success the caller owns resp and
// must close resp.Body.
//
// NOTE(review): the request carries no context, so it is not cancelled when
// the downstream client disconnects; it only ends at the client timeout.
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
	httpReq, err := http.NewRequest(http.MethodPost, backendURL, bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create backend request: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)
	if userAgent != "" {
		httpReq.Header.Set("User-Agent", userAgent)
	}

	resp, err := backendHTTPClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to backend service: %w", err)
	}
	return resp, nil
}
|
|
|
|
// extractResponseTokenCount 从响应体中提取token数量
|
|
func extractResponseTokenCount(responseBody []byte) int {
|
|
var responseData map[string]interface{}
|
|
if err := json.Unmarshal(responseBody, &responseData); err != nil {
|
|
return 0
|
|
}
|
|
|
|
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
|
|
if choice, ok := choices[0].(map[string]interface{}); ok {
|
|
if message, ok := choice["message"].(map[string]interface{}); ok {
|
|
if content, ok := message["content"].(string); ok {
|
|
return billing.CalculateTextTokensSimple(content)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
// createRequestLog 创建请求日志记录
|
|
func createRequestLog(apiKeyID uint, virtualModelName string, backendModel *models.BackendModel,
|
|
requestTimestamp, responseTimestamp time.Time, requestTokenCount, responseTokenCount int,
|
|
cost float64, requestBody, responseBody string) *models.RequestLog {
|
|
|
|
return &models.RequestLog{
|
|
APIKeyID: apiKeyID,
|
|
VirtualModelName: virtualModelName,
|
|
BackendModelName: backendModel.Name,
|
|
RequestTimestamp: requestTimestamp,
|
|
ResponseTimestamp: responseTimestamp,
|
|
RequestTokens: requestTokenCount,
|
|
ResponseTokens: responseTokenCount,
|
|
Cost: cost,
|
|
RequestBody: requestBody,
|
|
ResponseBody: responseBody,
|
|
}
|
|
}
|
|
|
|
// getAPIKeyID 从gin上下文中提取API Key ID
|
|
func getAPIKeyID(c *gin.Context) uint {
|
|
apiKeyValue, exists := c.Get("apiKey")
|
|
if !exists {
|
|
return 0
|
|
}
|
|
|
|
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
|
return apiKey.ID
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
// copyResponseHeaders 复制响应头到gin context
|
|
func copyResponseHeaders(c *gin.Context, resp *http.Response) {
|
|
for key, values := range resp.Header {
|
|
for _, value := range values {
|
|
c.Header(key, value)
|
|
}
|
|
}
|
|
}
|