Files
AIRouter/backend/api/handlers.go
2025-11-17 11:42:17 +08:00

700 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
	"bytes"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"

	"ai-gateway/internal/billing"
	"ai-gateway/internal/middleware"
	"ai-gateway/internal/models"
	"ai-gateway/internal/scheduler"
)
// APIHandler bundles the dependencies shared by the HTTP handlers in this
// file: the database handle, the optional scheduled log cleaner, and the
// password that guards the web UI login.
type APIHandler struct {
	// DB is the GORM handle every handler reads from and writes to.
	DB *gorm.DB
	// LogCleaner may be nil; the log-cleaner handlers check for that and
	// report the cleaner as unavailable instead of dereferencing it.
	LogCleaner *scheduler.LogCleaner
	// WebUIPassword is the password LoginHandler checks submissions against.
	WebUIPassword string
}
// RequestLogListItem is the trimmed request-log row returned by the list
// endpoint. It deliberately has no RequestBody/ResponseBody fields, because
// those can be large; the detail endpoint returns them for a single log.
// StatusCode and ErrorMessage are dropped from the JSON when zero/empty.
type RequestLogListItem struct {
	ID                uint      `json:"id"`
	APIKeyID          uint      `json:"api_key_id"`
	ProviderName      string    `json:"provider_name"`
	VirtualModelName  string    `json:"virtual_model_name"`
	BackendModelName  string    `json:"backend_model_name"`
	RequestTimestamp  time.Time `json:"request_timestamp"`
	ResponseTimestamp time.Time `json:"response_timestamp"`
	StatusCode        int       `json:"status_code,omitempty"`
	ErrorMessage      string    `json:"error_message,omitempty"`
	RequestTokens     int       `json:"request_tokens"`
	ResponseTokens    int       `json:"response_tokens"`
	Cost              float64   `json:"cost"`
	CreatedAt         time.Time `json:"created_at"`
}
// HealthCheckHandler is a liveness probe meant to be served without
// authentication. It always answers 200 with
// {"status": "ok", "timestamp": <current Unix time in seconds>}.
func (h *APIHandler) HealthCheckHandler(c *gin.Context) {
	payload := gin.H{"status": "ok"}
	payload["timestamp"] = time.Now().Unix()
	c.JSON(http.StatusOK, payload)
}
type (
	// ModelListResponse mirrors the response body of OpenAI's /v1/models API.
	ModelListResponse struct {
		Object string      `json:"object"`
		Data   []ModelData `json:"data"`
	}

	// ModelData describes a single model entry inside ModelListResponse.Data.
	ModelData struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}
)
type (
	// ChatCompletionRequest is the chat-completion request body accepted from
	// clients. Stream is omitted from the JSON when false.
	ChatCompletionRequest struct {
		Model    string                  `json:"model"`
		Messages []ChatCompletionMessage `json:"messages"`
		Stream   bool                    `json:"stream,omitempty"`
	}

	// ChatCompletionMessage is one role/content chat message. Content is a
	// plain string here.
	ChatCompletionMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// BackendChatCompletionRequest is the body actually sent to a backend
	// model. It carries only the common fields so that backends are not sent
	// parameters they may not support.
	BackendChatCompletionRequest struct {
		Model    string                  `json:"model"`
		Messages []ChatCompletionMessage `json:"messages"`
		Stream   bool                    `json:"stream,omitempty"`
	}

	// ResponsesRequest is the request body for the /v1/responses endpoint.
	ResponsesRequest struct {
		Model  string                  `json:"model"`
		Input  []ChatCompletionMessage `json:"input"`
		Stream bool                    `json:"stream,omitempty"`
	}
)
type (
	// BackendModelAssociation links a virtual model to one backend model,
	// with a priority and a cost threshold. With gin's default validator,
	// binding:"required" rejects the zero value, so BackendModelID and
	// Priority must both be non-zero for the request to bind.
	BackendModelAssociation struct {
		BackendModelID uint    `json:"backend_model_id" binding:"required"`
		Priority       int     `json:"priority" binding:"required"`
		CostThreshold  float64 `json:"cost_threshold"`
	}

	// CreateVirtualModelRequest is the request body for creating a virtual model.
	CreateVirtualModelRequest struct {
		Name          string                    `json:"name" binding:"required"`
		Description   string                    `json:"description"`
		BackendModels []BackendModelAssociation `json:"backend_models"`
	}

	// UpdateVirtualModelRequest is the request body for updating a virtual
	// model. It has the same shape as CreateVirtualModelRequest.
	UpdateVirtualModelRequest struct {
		Name          string                    `json:"name" binding:"required"`
		Description   string                    `json:"description"`
		BackendModels []BackendModelAssociation `json:"backend_models"`
	}
)
// forwardHTTPClient is shared by every forwardRequest call. http.Client is
// safe for concurrent use and is meant to be reused rather than built per
// request. Its Timeout bounds the whole exchange (connect, redirects, headers
// and reading the response body), so 10 minutes leaves room for long
// generations and streamed responses.
var forwardHTTPClient = &http.Client{
	Timeout: 10 * time.Minute,
}

// forwardRequest POSTs requestBody as JSON to backendURL. It authenticates with
// apiKey as a Bearer token and overrides the User-Agent when userAgent is
// non-empty.
//
// Any HTTP status from the backend is returned as-is in resp; err reports only
// request-construction or transport failures. On success the caller owns
// resp and must close resp.Body.
//
// NOTE(review): the outbound request is not tied to the inbound request's
// context, so it keeps running (up to the client timeout) if the caller
// disconnects. Consider a context-aware variant.
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
	httpReq, err := http.NewRequest(http.MethodPost, backendURL, bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create backend request: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)
	if userAgent != "" {
		httpReq.Header.Set("User-Agent", userAgent)
	}

	resp, err := forwardHTTPClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to backend service: %w", err)
	}
	return resp, nil
}
// extractResponseTokenCount estimates the completion token count of a
// chat-completion response body. It reads choices[0].message.content and,
// when that value is a JSON string, returns
// billing.CalculateTextTokensSimple(content). Every other shape yields 0:
// invalid JSON, missing or empty choices, or non-string content.
// It does not consult any usage field the backend may have reported.
func extractResponseTokenCount(responseBody []byte) int {
	var payload map[string]interface{}
	if json.Unmarshal(responseBody, &payload) != nil {
		return 0
	}

	choices, _ := payload["choices"].([]interface{})
	if len(choices) == 0 {
		return 0
	}

	// Failed assertions produce nil maps, and indexing a nil map is safe, so
	// any missing level simply makes the final string assertion fail.
	firstChoice, _ := choices[0].(map[string]interface{})
	message, _ := firstChoice["message"].(map[string]interface{})
	content, isText := message["content"].(string)
	if !isText {
		return 0
	}
	return billing.CalculateTextTokensSimple(content)
}
// createRequestLog builds (but does not persist) a RequestLog row for one
// proxied request. The backend model and provider names are copied from
// backendModel, which is dereferenced and therefore must not be nil.
func createRequestLog(apiKeyID uint, virtualModelName string, backendModel *models.BackendModel,
	requestTimestamp, responseTimestamp time.Time, requestTokenCount, responseTokenCount int,
	cost float64, requestBody, responseBody string) *models.RequestLog {
	entry := new(models.RequestLog)
	entry.APIKeyID = apiKeyID
	entry.VirtualModelName = virtualModelName
	entry.BackendModelName = backendModel.Name
	entry.ProviderName = backendModel.Provider.Name
	entry.RequestTimestamp = requestTimestamp
	entry.ResponseTimestamp = responseTimestamp
	entry.RequestTokens = requestTokenCount
	entry.ResponseTokens = responseTokenCount
	entry.Cost = cost
	entry.RequestBody = requestBody
	entry.ResponseBody = responseBody
	return entry
}
// getAPIKeyID returns the ID of the API key stored in the gin context under
// "apiKey" (presumably by the auth middleware), or 0 when nothing usable is
// stored there.
//
// Both a models.APIKey value and a *models.APIKey pointer are accepted, so the
// lookup still works if the key is ever stored by pointer. Previously a pointer
// silently yielded 0. A nil pointer also yields 0.
func getAPIKeyID(c *gin.Context) uint {
	value, exists := c.Get("apiKey")
	if !exists {
		return 0
	}
	switch apiKey := value.(type) {
	case models.APIKey:
		return apiKey.ID
	case *models.APIKey:
		if apiKey != nil {
			return apiKey.ID
		}
	}
	return 0
}
// copyResponseHeaders copies the backend response's headers onto the gin
// response.
//
// gin's c.Header calls Header().Set, so looping over a multi-valued header
// (Set-Cookie, Vary, ...) with it kept only the last value. Now each backend
// key replaces any value already set for that key, with all of the backend's
// values, so repeated values survive. Empty values are skipped, matching
// c.Header, which treats "" as "delete".
func copyResponseHeaders(c *gin.Context, resp *http.Response) {
	dst := c.Writer.Header()
	for key, values := range resp.Header {
		dst.Del(key)
		for _, value := range values {
			if value != "" {
				dst.Add(key, value)
			}
		}
	}
}
// GetRequestLogsHandler returns a page of request logs, newest first, in the
// trimmed RequestLogListItem form (no request/response bodies).
//
// Query parameters:
//   - page: 1-based page number; defaults to 1, values below 1 become 1.
//   - page_size: defaults to 20; values outside 1..100 fall back to 20.
//   - start_time, end_time: RFC3339 bounds (inclusive) on request_timestamp;
//     values that fail to parse are ignored.
//   - virtual_model, backend_model: exact-match name filters.
//
// The JSON body is {logs, total, page, page_size, total_pages}; 500 is
// returned if counting or fetching fails.
func (h *APIHandler) GetRequestLogsHandler(c *gin.Context) {
	page := 1
	pageSize := 20
	// A value that is not a number leaves the default in place, so scan
	// errors are deliberately ignored; the range checks below clamp the rest.
	if raw := c.Query("page"); raw != "" {
		_, _ = fmt.Sscanf(raw, "%d", &page)
	}
	if raw := c.Query("page_size"); raw != "" {
		_, _ = fmt.Sscanf(raw, "%d", &pageSize)
	}
	if page < 1 {
		page = 1
	}
	if pageSize < 1 || pageSize > 100 {
		pageSize = 20
	}

	query := h.DB.Model(&models.RequestLog{})
	if raw := c.Query("start_time"); raw != "" {
		if t, err := time.Parse(time.RFC3339, raw); err == nil {
			query = query.Where("request_timestamp >= ?", t)
		}
	}
	if raw := c.Query("end_time"); raw != "" {
		if t, err := time.Parse(time.RFC3339, raw); err == nil {
			query = query.Where("request_timestamp <= ?", t)
		}
	}
	if name := c.Query("virtual_model"); name != "" {
		query = query.Where("virtual_model_name = ?", name)
	}
	if name := c.Query("backend_model"); name != "" {
		query = query.Where("backend_model_name = ?", name)
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count logs"})
		return
	}

	// The list never shows RequestBody/ResponseBody, which can be large, so
	// leave them out of the SELECT instead of loading and discarding them for
	// every row on the page.
	var logs []models.RequestLog
	offset := (page - 1) * pageSize
	err := query.Omit("RequestBody", "ResponseBody").
		Order("request_timestamp DESC").
		Limit(pageSize).
		Offset(offset).
		Find(&logs).Error
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch logs"})
		return
	}

	logItems := make([]RequestLogListItem, len(logs))
	for i := range logs {
		entry := &logs[i]
		logItems[i] = RequestLogListItem{
			ID:                entry.ID,
			APIKeyID:          entry.APIKeyID,
			ProviderName:      entry.ProviderName,
			VirtualModelName:  entry.VirtualModelName,
			BackendModelName:  entry.BackendModelName,
			RequestTimestamp:  entry.RequestTimestamp,
			ResponseTimestamp: entry.ResponseTimestamp,
			RequestTokens:     entry.RequestTokens,
			ResponseTokens:    entry.ResponseTokens,
			Cost:              entry.Cost,
			CreatedAt:         entry.CreatedAt,
		}
	}

	c.JSON(http.StatusOK, gin.H{
		"logs":        logItems,
		"total":       total,
		"page":        page,
		"page_size":   pageSize,
		"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
	})
}
// GetRequestLogDetailHandler returns one complete request log, including
// RequestBody and ResponseBody, looked up by the numeric :id path parameter.
//
// The ID is parsed into an unsigned integer before it reaches GORM. Passing the
// raw path string to First let GORM treat any non-numeric value (for example
// "1 OR 1=1") as an inline SQL condition, which is SQL injection.
//
// Responses: 400 for a missing, non-numeric or zero ID; 404 when no row
// matches; 500 on other database errors; otherwise 200 with the full log.
func (h *APIHandler) GetRequestLogDetailHandler(c *gin.Context) {
	if c.Param("id") == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Log ID is required"})
		return
	}

	// gin parses the uint with strconv, so anything but a decimal number fails
	// here; "required" additionally rejects 0.
	var uri struct {
		ID uint `uri:"id" binding:"required"`
	}
	if err := c.ShouldBindUri(&uri); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid log ID"})
		return
	}

	var entry models.RequestLog
	if err := h.DB.First(&entry, uri.ID).Error; err != nil {
		if err == gorm.ErrRecordNotFound {
			c.JSON(http.StatusNotFound, gin.H{"error": "Log not found"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch log detail"})
		return
	}

	c.JSON(http.StatusOK, entry)
}
// GetRequestLogStatsHandler aggregates the request logs whose request_timestamp
// lies in [start_time, end_time], both inclusive (SQL BETWEEN). The bounds are
// RFC3339 query parameters. A missing or unparsable start_time defaults to 7
// days ago, and end_time defaults to now.
//
// The response carries the request count, total cost, summed request and
// response tokens, and the top 10 virtual and backend models by request count.
// If any aggregate query fails, the handler answers 500; previously the errors
// were ignored and partial zeros were returned with 200.
func (h *APIHandler) GetRequestLogStatsHandler(c *gin.Context) {
	// time.Parse rejects "", so one check covers both "absent" and "invalid".
	parseOr := func(raw string, fallback time.Time) time.Time {
		if t, err := time.Parse(time.RFC3339, raw); err == nil {
			return t
		}
		return fallback
	}
	now := time.Now()
	startTime := parseOr(c.Query("start_time"), now.AddDate(0, 0, -7))
	endTime := parseOr(c.Query("end_time"), now)

	// Each aggregate starts from a fresh chain so the SELECT/GROUP clauses of
	// one query never leak into the next.
	inRange := func() *gorm.DB {
		return h.DB.Model(&models.RequestLog{}).
			Where("request_timestamp BETWEEN ? AND ?", startTime, endTime)
	}
	fail := func() {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to compute log statistics"})
	}

	var totalRequests int64
	if err := inRange().Count(&totalRequests).Error; err != nil {
		fail()
		return
	}

	var totalCost float64
	if err := inRange().Select("COALESCE(SUM(cost), 0)").Scan(&totalCost).Error; err != nil {
		fail()
		return
	}

	var totalTokens struct {
		RequestTokens  int64
		ResponseTokens int64
	}
	if err := inRange().
		Select("COALESCE(SUM(request_tokens), 0) as request_tokens, COALESCE(SUM(response_tokens), 0) as response_tokens").
		Scan(&totalTokens).Error; err != nil {
		fail()
		return
	}

	// ModelStats is one row of a per-model breakdown.
	type ModelStats struct {
		ModelName string  `json:"model_name"`
		Count     int64   `json:"count"`
		TotalCost float64 `json:"total_cost"`
	}
	// topModels groups by column and keeps the 10 busiest models. column is
	// always one of the constants below, never user input, so concatenating it
	// into the SELECT is safe.
	topModels := func(column string, dest *[]ModelStats) error {
		return inRange().
			Select(column + " as model_name, COUNT(*) as count, COALESCE(SUM(cost), 0) as total_cost").
			Group(column).
			Order("count DESC").
			Limit(10).
			Scan(dest).Error
	}

	var virtualModelStats []ModelStats
	if err := topModels("virtual_model_name", &virtualModelStats); err != nil {
		fail()
		return
	}
	var backendModelStats []ModelStats
	if err := topModels("backend_model_name", &backendModelStats); err != nil {
		fail()
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"start_time":            startTime,
		"end_time":              endTime,
		"total_requests":        totalRequests,
		"total_cost":            totalCost,
		"total_request_tokens":  totalTokens.RequestTokens,
		"total_response_tokens": totalTokens.ResponseTokens,
		"virtual_model_stats":   virtualModelStats,
		"backend_model_stats":   backendModelStats,
	})
}
// ClearRequestLogsHandler hard-deletes request logs and then runs VACUUM to
// reclaim database space.
//
// Query parameter older_than (days):
//   - omitted: delete logs whose created_at is more than 30 days old
//   - N > 0:   delete logs whose created_at is more than N days old
//   - 0:       delete every log
//   - negative or non-numeric: 400
//
// Deletion is Unscoped, so soft-deleted rows are removed as well. A VACUUM
// failure does not fail the request; it is reported through
// vacuum_executed=false and vacuum_error. The before/after counts in the
// response are informational and gathered best-effort: a failed count reports 0.
//
// NOTE(review): VACUUM is SQLite/PostgreSQL syntax; on other databases it will
// always fail and show up in vacuum_error.
func (h *APIHandler) ClearRequestLogsHandler(c *gin.Context) {
	olderThan := c.Query("older_than")

	// Validate before any query, so a bad parameter costs no database work.
	days := 30
	if olderThan != "" {
		if _, err := fmt.Sscanf(olderThan, "%d", &days); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid older_than parameter, must be a number representing days"})
			return
		}
		if days < 0 {
			c.JSON(http.StatusBadRequest, gin.H{"error": "older_than must be 0 or greater"})
			return
		}
	}
	// Only an explicit older_than=0 can produce 0, because the default is 30.
	deleteAll := days == 0

	// totalCountBefore includes soft-deleted rows; activeCountBefore does not.
	var totalCountBefore, activeCountBefore int64
	h.DB.Unscoped().Model(&models.RequestLog{}).Count(&totalCountBefore)
	h.DB.Model(&models.RequestLog{}).Count(&activeCountBefore)

	del := h.DB.Unscoped()
	var cutoff time.Time
	if deleteAll {
		// GORM refuses condition-less deletes (ErrMissingWhereClause); the
		// tautology explicitly opts in to clearing the whole table.
		del = del.Where("1 = 1")
	} else {
		cutoff = time.Now().AddDate(0, 0, -days)
		del = del.Where("created_at < ?", cutoff)
	}
	result := del.Delete(&models.RequestLog{})
	if result.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to clear logs"})
		return
	}

	// The rows are already gone at this point, so a VACUUM failure is only
	// reported, never turned into an error response.
	vacuumError := ""
	if err := h.DB.Exec("VACUUM").Error; err != nil {
		vacuumError = fmt.Sprintf("VACUUM failed: %v", err)
		fmt.Printf("Warning: %s\n", vacuumError)
	}

	var totalCountAfter int64
	h.DB.Unscoped().Model(&models.RequestLog{}).Count(&totalCountAfter)

	message := "Logs cleared successfully with hard delete and vacuum"
	if deleteAll {
		message = "All logs cleared successfully with hard delete and vacuum"
	}
	response := gin.H{
		"message":             message,
		"deleted_count":       result.RowsAffected,
		"older_than_days":     olderThan,
		"total_count_before":  totalCountBefore,
		"active_count_before": activeCountBefore,
		"total_count_after":   totalCountAfter,
		"hard_delete":         true,
		"vacuum_executed":     vacuumError == "",
	}
	if !deleteAll {
		response["cutoff_time"] = cutoff
	}
	if vacuumError != "" {
		response["vacuum_error"] = vacuumError
	}
	c.JSON(http.StatusOK, response)
}
// GetLogCleanerStatusHandler reports the scheduled log cleaner's configuration
// and when it will run next. If no cleaner is attached it still answers 200,
// with enabled=false and an explanatory message rather than an error.
func (h *APIHandler) GetLogCleanerStatusHandler(c *gin.Context) {
	cleaner := h.LogCleaner
	if cleaner == nil {
		c.JSON(http.StatusOK, gin.H{
			"enabled": false,
			"message": "Log cleaner is not initialized",
		})
		return
	}

	cfg := cleaner.GetConfig()
	next := cleaner.GetNextExecuteTime()
	status := gin.H{
		"enabled":           cfg.Enabled,
		"execute_time":      cfg.ExecuteTime,
		"retention_days":    cfg.RetentionDays,
		"check_interval":    cfg.CheckInterval,
		"next_execute_time": next.Format("2006-01-02 15:04:05"),
		"time_until_next":   time.Until(next).String(),
	}
	c.JSON(http.StatusOK, status)
}
// ForceLogCleanupHandler triggers a cleanup on demand through
// LogCleaner.ForceCleanup and returns the resulting report. Times are
// formatted as "2006-01-02 15:04:05" and durations with Duration.String.
// It answers 503 when no cleaner is attached.
func (h *APIHandler) ForceLogCleanupHandler(c *gin.Context) {
	const displayLayout = "2006-01-02 15:04:05"

	cleaner := h.LogCleaner
	if cleaner == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Log cleaner is not initialized"})
		return
	}

	r := cleaner.ForceCleanup()
	c.JSON(http.StatusOK, gin.H{
		"message":             "Manual log cleanup completed",
		"execute_time":        r.ExecuteTime.Format(displayLayout),
		"duration":            r.Duration.String(),
		"deleted_count":       r.DeletedCount,
		"total_count_before":  r.TotalCountBefore,
		"active_count_before": r.ActiveCountBefore,
		"total_count_after":   r.TotalCountAfter,
		"cutoff_time":         r.CutoffTime.Format(displayLayout),
		"vacuum_duration":     r.VacuumDuration.String(),
		"vacuum_error":        r.VacuumError,
		"success":             r.Success,
	})
}
// ============ API key management handlers ============

// APIKeyListResponse is the JSON shape used when API keys are listed and when
// a newly created key is echoed back. It carries the full, unmasked key value.
type APIKeyListResponse struct {
	ID        uint      `json:"id"`
	Key       string    `json:"key"`
	CreatedAt time.Time `json:"created_at"`
}
// GetAPIKeysHandler lists every API key, newest first, as
// {"api_keys": [...], "total": n}. The full key values are included.
// An empty table yields an empty JSON array, not null. Responds 500 if the
// query fails.
func (h *APIHandler) GetAPIKeysHandler(c *gin.Context) {
	var stored []models.APIKey
	if err := h.DB.Order("created_at DESC").Find(&stored).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch API keys"})
		return
	}

	items := make([]APIKeyListResponse, 0, len(stored))
	for _, k := range stored {
		items = append(items, APIKeyListResponse{ID: k.ID, Key: k.Key, CreatedAt: k.CreatedAt})
	}

	c.JSON(http.StatusOK, gin.H{"api_keys": items, "total": len(items)})
}
// CreateAPIKeyRequest is the JSON body accepted by CreateAPIKeyHandler. With
// gin's default validator, binding:"required" rejects a missing or empty key.
type CreateAPIKeyRequest struct {
	Key string `json:"key" binding:"required"`
}
// CreateAPIKeyHandler stores a caller-supplied API key.
//
// Responses: 400 for a malformed body or empty key, 409 if the key already
// exists, 500 on database errors, and 201 with the stored key otherwise.
//
// Only gorm.ErrRecordNotFound from the existence check means "free to
// create". Previously any lookup error was treated that way and fell through to
// the insert.
//
// NOTE(review): two concurrent requests can both pass the existence check; only
// a unique index on the key column actually prevents duplicates. Confirm one
// exists.
func (h *APIHandler) CreateAPIKeyHandler(c *gin.Context) {
	var req CreateAPIKeyRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
		return
	}
	// Defensive: binding:"required" should already have rejected "".
	if req.Key == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "API key cannot be empty"})
		return
	}

	// A struct condition lets GORM quote the column name. A hand-written
	// "key = ?" depends on `key` being accepted as a bare identifier, and it is
	// a reserved word in some SQL dialects (e.g. MySQL).
	var existingKey models.APIKey
	err := h.DB.Where(&models.APIKey{Key: req.Key}).First(&existingKey).Error
	switch {
	case err == nil:
		c.JSON(http.StatusConflict, gin.H{"error": "API key already exists"})
		return
	case err != gorm.ErrRecordNotFound:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check existing API keys"})
		return
	}

	newAPIKey := models.APIKey{Key: req.Key}
	if err := h.DB.Create(&newAPIKey).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create API key"})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "API key created successfully",
		"api_key": APIKeyListResponse{
			ID:        newAPIKey.ID,
			Key:       newAPIKey.Key,
			CreatedAt: newAPIKey.CreatedAt,
		},
	})
}
// DeleteAPIKeyHandler deletes the API key identified by the numeric :id path
// parameter.
//
// The ID is parsed into an unsigned integer before it reaches GORM. Handing the
// raw path string to First let GORM interpret any non-numeric value as an
// inline SQL condition, which is SQL injection into the lookup.
//
// Responses: 400 for a missing, non-numeric or zero ID; 404 when the key does
// not exist; 500 on database errors; 200 with the deleted ID otherwise. Delete
// uses GORM's default scope, so if models.APIKey has a DeletedAt field this is
// a soft delete.
func (h *APIHandler) DeleteAPIKeyHandler(c *gin.Context) {
	if c.Param("id") == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "API key ID is required"})
		return
	}

	// gin parses the uint with strconv, so anything but a decimal number fails
	// here; "required" additionally rejects 0.
	var uri struct {
		ID uint `uri:"id" binding:"required"`
	}
	if err := c.ShouldBindUri(&uri); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid API key ID"})
		return
	}

	var apiKey models.APIKey
	if err := h.DB.First(&apiKey, uri.ID).Error; err != nil {
		if err == gorm.ErrRecordNotFound {
			c.JSON(http.StatusNotFound, gin.H{"error": "API key not found"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find API key"})
		return
	}

	if err := h.DB.Delete(&apiKey).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete API key"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "API key deleted successfully",
		"id":      apiKey.ID,
	})
}
// LoginRequest is the JSON body accepted by LoginHandler. Password has no
// binding tag, so an absent password binds as the empty string.
type LoginRequest struct {
	Password string `json:"password"`
}
// LoginHandler exchanges the web UI password for a JWT produced by
// middleware.GenerateJWT.
//
// Responses: 400 for a malformed body, 401 for a wrong password, 500 if token
// generation fails, otherwise 200 with {"token": ...}.
//
// The password check uses crypto/subtle.ConstantTimeCompare, so response timing
// does not reveal how many leading bytes matched; a plain != can stop at the
// first differing byte. ConstantTimeCompare still returns immediately on a
// length mismatch, so the password's length is not hidden.
//
// NOTE(review): when WebUIPassword is empty, an empty submitted password is
// accepted, exactly as before. Confirm that is intended.
func (h *APIHandler) LoginHandler(c *gin.Context) {
	var req LoginRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
		return
	}

	if subtle.ConstantTimeCompare([]byte(req.Password), []byte(h.WebUIPassword)) != 1 {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid password"})
		return
	}

	token, err := middleware.GenerateJWT()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"token": token})
}