675 lines
20 KiB
Go
675 lines
20 KiB
Go
package api
|
||
|
||
import (
|
||
"ai-gateway/internal/billing"
|
||
"ai-gateway/internal/models"
|
||
"ai-gateway/internal/scheduler"
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// APIHandler 持有数据库连接并处理API请求
|
||
type APIHandler struct {
|
||
DB *gorm.DB
|
||
LogCleaner *scheduler.LogCleaner
|
||
}
|
||
|
||
// RequestLogListItem is the compact list representation of a request log.
// It deliberately carries no request or response body.
type RequestLogListItem struct {
	// Identity of the log row and of the API key that made the request.
	ID       uint `json:"id"`
	APIKeyID uint `json:"api_key_id"`

	// Routing information: provider, client-facing model and backend model.
	ProviderName     string `json:"provider_name"`
	VirtualModelName string `json:"virtual_model_name"`
	BackendModelName string `json:"backend_model_name"`

	// Timing of the proxied call.
	RequestTimestamp  time.Time `json:"request_timestamp"`
	ResponseTimestamp time.Time `json:"response_timestamp"`

	// Outcome fields; both are dropped from the JSON output when zero/empty.
	StatusCode   int    `json:"status_code,omitempty"`
	ErrorMessage string `json:"error_message,omitempty"`

	// Usage and billing figures.
	RequestTokens  int     `json:"request_tokens"`
	ResponseTokens int     `json:"response_tokens"`
	Cost           float64 `json:"cost"`

	// CreatedAt is the creation time of the log row.
	CreatedAt time.Time `json:"created_at"`
}
|
||
|
||
// HealthCheckHandler 健康检查端点(无需认证)
|
||
func (h *APIHandler) HealthCheckHandler(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"status": "ok",
|
||
"timestamp": time.Now().Unix(),
|
||
})
|
||
}
|
||
|
||
// ModelListResponse 符合OpenAI /v1/models API响应格式
|
||
type ModelListResponse struct {
|
||
Object string `json:"object"`
|
||
Data []ModelData `json:"data"`
|
||
}
|
||
|
||
// ModelData describes a single model inside a ModelListResponse.
type ModelData struct {
	// ID is the model identifier.
	ID string `json:"id"`
	// Object is the entry's object-type marker.
	Object string `json:"object"`
	// Created is a numeric creation timestamp.
	Created int64 `json:"created"`
	// OwnedBy names the owner of the model.
	OwnedBy string `json:"owned_by"`
}
|
||
|
||
// ChatCompletionRequest 聊天补全请求结构
|
||
type ChatCompletionRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []ChatCompletionMessage `json:"messages"`
|
||
Stream bool `json:"stream,omitempty"`
|
||
}
|
||
|
||
// ChatCompletionMessage is one message of a chat conversation.
type ChatCompletionMessage struct {
	// Role identifies the author of the message.
	Role string `json:"role"`
	// Content is the message text. It is a plain string, so a JSON payload
	// whose content is not a string will not decode into this type.
	Content string `json:"content"`
}
|
||
|
||
// BackendChatCompletionRequest 是实际发送到后端模型的请求结构
|
||
// 它只包含通用字段,以避免发送不被支持的参数
|
||
type BackendChatCompletionRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []ChatCompletionMessage `json:"messages"`
|
||
Stream bool `json:"stream,omitempty"`
|
||
}
|
||
|
||
// ResponsesRequest /v1/responses 端点请求结构
|
||
type ResponsesRequest struct {
|
||
Model string `json:"model"`
|
||
Input []ChatCompletionMessage `json:"input"`
|
||
Stream bool `json:"stream,omitempty"`
|
||
}
|
||
|
||
// BackendModelAssociation links a backend model to a virtual model together
// with the per-association settings.
type BackendModelAssociation struct {
	// BackendModelID references the backend model; required on input.
	BackendModelID uint `json:"backend_model_id" binding:"required"`
	// Priority of this backend within the virtual model; required on input.
	// NOTE(review): with a "required" binding a zero value is presumably
	// rejected, so priority 0 may be impossible to submit - confirm intent.
	Priority int `json:"priority" binding:"required"`
	// CostThreshold is an optional cost limit for this association.
	CostThreshold float64 `json:"cost_threshold"`
}
|
||
|
||
// CreateVirtualModelRequest 创建虚拟模型的请求结构
|
||
type CreateVirtualModelRequest struct {
|
||
Name string `json:"name" binding:"required"`
|
||
Description string `json:"description"`
|
||
BackendModels []BackendModelAssociation `json:"backend_models"`
|
||
}
|
||
|
||
// UpdateVirtualModelRequest 更新虚拟模型的请求结构
|
||
type UpdateVirtualModelRequest struct {
|
||
Name string `json:"name" binding:"required"`
|
||
Description string `json:"description"`
|
||
BackendModels []BackendModelAssociation `json:"backend_models"`
|
||
}
|
||
|
||
// forwardRequest POSTs requestBody as JSON to backendURL and returns the raw
// backend response.
//
// The request carries "Content-Type: application/json" and an
// "Authorization: Bearer <apiKey>" header. A User-Agent header is set only
// when userAgent is non-empty. The whole exchange is bounded by a ten-minute
// client timeout.
//
// On success the caller owns the returned response and is responsible for
// closing its Body. Errors from building or sending the request are wrapped
// and returned with a nil response.
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
	outbound, err := http.NewRequest(http.MethodPost, backendURL, bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create backend request: %w", err)
	}

	headers := outbound.Header
	headers.Set("Content-Type", "application/json")
	headers.Set("Authorization", "Bearer "+apiKey)
	if userAgent != "" {
		headers.Set("User-Agent", userAgent)
	}

	httpClient := http.Client{Timeout: 10 * time.Minute}
	backendResp, err := httpClient.Do(outbound)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to backend service: %w", err)
	}
	return backendResp, nil
}
|
||
|
||
// extractResponseTokenCount 从响应体中提取token数量
|
||
func extractResponseTokenCount(responseBody []byte) int {
|
||
var responseData map[string]interface{}
|
||
if err := json.Unmarshal(responseBody, &responseData); err != nil {
|
||
return 0
|
||
}
|
||
|
||
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
|
||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||
if message, ok := choice["message"].(map[string]interface{}); ok {
|
||
if content, ok := message["content"].(string); ok {
|
||
return billing.CalculateTextTokensSimple(content)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return 0
|
||
}
|
||
|
||
// createRequestLog 创建请求日志记录
|
||
func createRequestLog(apiKeyID uint, virtualModelName string, backendModel *models.BackendModel,
|
||
requestTimestamp, responseTimestamp time.Time, requestTokenCount, responseTokenCount int,
|
||
cost float64, requestBody, responseBody string) *models.RequestLog {
|
||
|
||
return &models.RequestLog{
|
||
APIKeyID: apiKeyID,
|
||
VirtualModelName: virtualModelName,
|
||
BackendModelName: backendModel.Name,
|
||
ProviderName: backendModel.Provider.Name,
|
||
RequestTimestamp: requestTimestamp,
|
||
ResponseTimestamp: responseTimestamp,
|
||
RequestTokens: requestTokenCount,
|
||
ResponseTokens: responseTokenCount,
|
||
Cost: cost,
|
||
RequestBody: requestBody,
|
||
ResponseBody: responseBody,
|
||
}
|
||
}
|
||
|
||
// getAPIKeyID 从gin上下文中提取API Key ID
|
||
func getAPIKeyID(c *gin.Context) uint {
|
||
apiKeyValue, exists := c.Get("apiKey")
|
||
if !exists {
|
||
return 0
|
||
}
|
||
|
||
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
|
||
return apiKey.ID
|
||
}
|
||
|
||
return 0
|
||
}
|
||
|
||
// copyResponseHeaders 复制响应头到gin context
|
||
func copyResponseHeaders(c *gin.Context, resp *http.Response) {
|
||
for key, values := range resp.Header {
|
||
for _, value := range values {
|
||
c.Header(key, value)
|
||
}
|
||
}
|
||
}
|
||
|
||
// GetRequestLogsHandler 获取请求日志列表(精简版,不包含RequestBody和ResponseBody)
|
||
func (h *APIHandler) GetRequestLogsHandler(c *gin.Context) {
|
||
// 获取分页参数
|
||
page := 1
|
||
pageSize := 20
|
||
if pageParam := c.Query("page"); pageParam != "" {
|
||
fmt.Sscanf(pageParam, "%d", &page)
|
||
}
|
||
if pageSizeParam := c.Query("page_size"); pageSizeParam != "" {
|
||
fmt.Sscanf(pageSizeParam, "%d", &pageSize)
|
||
}
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
if pageSize < 1 || pageSize > 100 {
|
||
pageSize = 20
|
||
}
|
||
|
||
// 构建查询
|
||
query := h.DB.Model(&models.RequestLog{})
|
||
|
||
// 按时间范围过滤
|
||
if startTime := c.Query("start_time"); startTime != "" {
|
||
if t, err := time.Parse(time.RFC3339, startTime); err == nil {
|
||
query = query.Where("request_timestamp >= ?", t)
|
||
}
|
||
}
|
||
if endTime := c.Query("end_time"); endTime != "" {
|
||
if t, err := time.Parse(time.RFC3339, endTime); err == nil {
|
||
query = query.Where("request_timestamp <= ?", t)
|
||
}
|
||
}
|
||
|
||
// 按虚拟模型名称过滤
|
||
if virtualModel := c.Query("virtual_model"); virtualModel != "" {
|
||
query = query.Where("virtual_model_name = ?", virtualModel)
|
||
}
|
||
|
||
// 按后端模型名称过滤
|
||
if backendModel := c.Query("backend_model"); backendModel != "" {
|
||
query = query.Where("backend_model_name = ?", backendModel)
|
||
}
|
||
|
||
// 获取总数
|
||
var total int64
|
||
if err := query.Count(&total).Error; err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count logs"})
|
||
return
|
||
}
|
||
|
||
// 获取日志列表(完整数据)
|
||
var logs []models.RequestLog
|
||
offset := (page - 1) * pageSize
|
||
if err := query.Order("request_timestamp DESC").Limit(pageSize).Offset(offset).Find(&logs).Error; err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch logs"})
|
||
return
|
||
}
|
||
|
||
// 转换为精简DTO
|
||
logItems := make([]RequestLogListItem, len(logs))
|
||
for i, log := range logs {
|
||
logItems[i] = RequestLogListItem{
|
||
ID: log.ID,
|
||
APIKeyID: log.APIKeyID,
|
||
ProviderName: log.ProviderName,
|
||
VirtualModelName: log.VirtualModelName,
|
||
BackendModelName: log.BackendModelName,
|
||
RequestTimestamp: log.RequestTimestamp,
|
||
ResponseTimestamp: log.ResponseTimestamp,
|
||
RequestTokens: log.RequestTokens,
|
||
ResponseTokens: log.ResponseTokens,
|
||
Cost: log.Cost,
|
||
CreatedAt: log.CreatedAt,
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"logs": logItems,
|
||
"total": total,
|
||
"page": page,
|
||
"page_size": pageSize,
|
||
"total_pages": (total + int64(pageSize) - 1) / int64(pageSize),
|
||
})
|
||
}
|
||
|
||
// GetRequestLogDetailHandler 获取单个请求日志的完整详情(包含RequestBody和ResponseBody)
|
||
func (h *APIHandler) GetRequestLogDetailHandler(c *gin.Context) {
|
||
// 获取日志ID
|
||
logID := c.Param("id")
|
||
if logID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Log ID is required"})
|
||
return
|
||
}
|
||
|
||
// 查询日志
|
||
var log models.RequestLog
|
||
if err := h.DB.First(&log, logID).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "Log not found"})
|
||
return
|
||
}
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch log detail"})
|
||
return
|
||
}
|
||
|
||
// 返回完整的日志信息
|
||
c.JSON(http.StatusOK, log)
|
||
}
|
||
|
||
// GetRequestLogStatsHandler 获取请求日志统计信息
|
||
func (h *APIHandler) GetRequestLogStatsHandler(c *gin.Context) {
|
||
// 获取时间范围参数
|
||
var startTime, endTime time.Time
|
||
var err error
|
||
|
||
if startTimeParam := c.Query("start_time"); startTimeParam != "" {
|
||
startTime, err = time.Parse(time.RFC3339, startTimeParam)
|
||
if err != nil {
|
||
// 默认最近7天
|
||
startTime = time.Now().AddDate(0, 0, -7)
|
||
}
|
||
} else {
|
||
startTime = time.Now().AddDate(0, 0, -7)
|
||
}
|
||
|
||
if endTimeParam := c.Query("end_time"); endTimeParam != "" {
|
||
endTime, err = time.Parse(time.RFC3339, endTimeParam)
|
||
if err != nil {
|
||
endTime = time.Now()
|
||
}
|
||
} else {
|
||
endTime = time.Now()
|
||
}
|
||
|
||
// 统计总请求数
|
||
var totalRequests int64
|
||
h.DB.Model(&models.RequestLog{}).
|
||
Where("request_timestamp BETWEEN ? AND ?", startTime, endTime).
|
||
Count(&totalRequests)
|
||
|
||
// 统计总成本
|
||
var totalCost float64
|
||
h.DB.Model(&models.RequestLog{}).
|
||
Where("request_timestamp BETWEEN ? AND ?", startTime, endTime).
|
||
Select("COALESCE(SUM(cost), 0)").
|
||
Scan(&totalCost)
|
||
|
||
// 统计总token数
|
||
var totalTokens struct {
|
||
RequestTokens int64
|
||
ResponseTokens int64
|
||
}
|
||
h.DB.Model(&models.RequestLog{}).
|
||
Where("request_timestamp BETWEEN ? AND ?", startTime, endTime).
|
||
Select("COALESCE(SUM(request_tokens), 0) as request_tokens, COALESCE(SUM(response_tokens), 0) as response_tokens").
|
||
Scan(&totalTokens)
|
||
|
||
// 按虚拟模型统计
|
||
type ModelStats struct {
|
||
ModelName string `json:"model_name"`
|
||
Count int64 `json:"count"`
|
||
TotalCost float64 `json:"total_cost"`
|
||
}
|
||
var virtualModelStats []ModelStats
|
||
h.DB.Model(&models.RequestLog{}).
|
||
Where("request_timestamp BETWEEN ? AND ?", startTime, endTime).
|
||
Select("virtual_model_name as model_name, COUNT(*) as count, COALESCE(SUM(cost), 0) as total_cost").
|
||
Group("virtual_model_name").
|
||
Order("count DESC").
|
||
Limit(10).
|
||
Scan(&virtualModelStats)
|
||
|
||
// 按后端模型统计
|
||
var backendModelStats []ModelStats
|
||
h.DB.Model(&models.RequestLog{}).
|
||
Where("request_timestamp BETWEEN ? AND ?", startTime, endTime).
|
||
Select("backend_model_name as model_name, COUNT(*) as count, COALESCE(SUM(cost), 0) as total_cost").
|
||
Group("backend_model_name").
|
||
Order("count DESC").
|
||
Limit(10).
|
||
Scan(&backendModelStats)
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"start_time": startTime,
|
||
"end_time": endTime,
|
||
"total_requests": totalRequests,
|
||
"total_cost": totalCost,
|
||
"total_request_tokens": totalTokens.RequestTokens,
|
||
"total_response_tokens": totalTokens.ResponseTokens,
|
||
"virtual_model_stats": virtualModelStats,
|
||
"backend_model_stats": backendModelStats,
|
||
})
|
||
}
|
||
|
||
// ClearRequestLogsHandler 清空请求日志
|
||
func (h *APIHandler) ClearRequestLogsHandler(c *gin.Context) {
|
||
// 获取查询参数
|
||
olderThan := c.Query("older_than") // 清空多少天前的日志
|
||
|
||
// 先查询当前总记录数(包括软删除的记录)
|
||
var totalCountBefore int64
|
||
h.DB.Unscoped().Model(&models.RequestLog{}).Count(&totalCountBefore)
|
||
|
||
// 查询活跃记录数(不包括软删除的记录)
|
||
var activeCountBefore int64
|
||
h.DB.Model(&models.RequestLog{}).Count(&activeCountBefore)
|
||
|
||
var startTime time.Time
|
||
if olderThan != "" {
|
||
// 解析天数
|
||
var days int
|
||
if _, err := fmt.Sscanf(olderThan, "%d", &days); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid older_than parameter, must be a number representing days"})
|
||
return
|
||
}
|
||
|
||
if days < 0 {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "older_than must be 0 or greater"})
|
||
return
|
||
}
|
||
|
||
// 如果是0天,要删除所有日志,不需要时间限制
|
||
if days == 0 {
|
||
// 使用Unscoped进行硬删除,真正删除记录而非软删除
|
||
result := h.DB.Unscoped().Where("1 = 1").Delete(&models.RequestLog{})
|
||
if result.Error != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to clear logs"})
|
||
return
|
||
}
|
||
|
||
// 执行VACUUM来回收数据库空间
|
||
vacuumError := ""
|
||
if err := h.DB.Exec("VACUUM").Error; err != nil {
|
||
vacuumError = fmt.Sprintf("VACUUM failed: %v", err)
|
||
fmt.Printf("Warning: %s\n", vacuumError)
|
||
}
|
||
|
||
// 查询删除后的记录数
|
||
var totalCountAfter int64
|
||
h.DB.Unscoped().Model(&models.RequestLog{}).Count(&totalCountAfter)
|
||
|
||
// 返回详细的删除信息
|
||
response := gin.H{
|
||
"message": "All logs cleared successfully with hard delete and vacuum",
|
||
"deleted_count": result.RowsAffected,
|
||
"older_than_days": olderThan,
|
||
"total_count_before": totalCountBefore,
|
||
"active_count_before": activeCountBefore,
|
||
"total_count_after": totalCountAfter,
|
||
"hard_delete": true,
|
||
"vacuum_executed": vacuumError == "",
|
||
}
|
||
if vacuumError != "" {
|
||
response["vacuum_error"] = vacuumError
|
||
}
|
||
|
||
c.JSON(http.StatusOK, response)
|
||
return
|
||
}
|
||
|
||
startTime = time.Now().AddDate(0, 0, -days)
|
||
} else {
|
||
// 如果没有指定天数,默认清空30天前的日志
|
||
startTime = time.Now().AddDate(0, 0, -30)
|
||
}
|
||
|
||
// 使用Unscoped执行硬删除操作,真正删除记录而非软删除
|
||
result := h.DB.Unscoped().Where("created_at < ?", startTime).Delete(&models.RequestLog{})
|
||
if result.Error != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to clear logs"})
|
||
return
|
||
}
|
||
|
||
// 执行VACUUM来回收数据库空间
|
||
vacuumError := ""
|
||
if err := h.DB.Exec("VACUUM").Error; err != nil {
|
||
vacuumError = fmt.Sprintf("VACUUM failed: %v", err)
|
||
fmt.Printf("Warning: %s\n", vacuumError)
|
||
}
|
||
|
||
// 查询删除后的记录数
|
||
var totalCountAfter int64
|
||
h.DB.Unscoped().Model(&models.RequestLog{}).Count(&totalCountAfter)
|
||
|
||
// 返回详细的删除信息
|
||
response := gin.H{
|
||
"message": "Logs cleared successfully with hard delete and vacuum",
|
||
"deleted_count": result.RowsAffected,
|
||
"older_than_days": olderThan,
|
||
"cutoff_time": startTime,
|
||
"total_count_before": totalCountBefore,
|
||
"active_count_before": activeCountBefore,
|
||
"total_count_after": totalCountAfter,
|
||
"hard_delete": true,
|
||
"vacuum_executed": vacuumError == "",
|
||
}
|
||
if vacuumError != "" {
|
||
response["vacuum_error"] = vacuumError
|
||
}
|
||
|
||
c.JSON(http.StatusOK, response)
|
||
}
|
||
|
||
|
||
// GetLogCleanerStatusHandler 获取日志清理器状态
|
||
func (h *APIHandler) GetLogCleanerStatusHandler(c *gin.Context) {
|
||
if h.LogCleaner == nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"enabled": false,
|
||
"message": "Log cleaner is not initialized",
|
||
})
|
||
return
|
||
}
|
||
|
||
config := h.LogCleaner.GetConfig()
|
||
nextExecuteTime := h.LogCleaner.GetNextExecuteTime()
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"enabled": config.Enabled,
|
||
"execute_time": config.ExecuteTime,
|
||
"retention_days": config.RetentionDays,
|
||
"check_interval": config.CheckInterval,
|
||
"next_execute_time": nextExecuteTime.Format("2006-01-02 15:04:05"),
|
||
"time_until_next": time.Until(nextExecuteTime).String(),
|
||
})
|
||
}
|
||
|
||
// ForceLogCleanupHandler 手动触发日志清理
|
||
func (h *APIHandler) ForceLogCleanupHandler(c *gin.Context) {
|
||
if h.LogCleaner == nil {
|
||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||
"error": "Log cleaner is not initialized",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 执行清理
|
||
report := h.LogCleaner.ForceCleanup()
|
||
|
||
// 返回清理报告
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "Manual log cleanup completed",
|
||
"execute_time": report.ExecuteTime.Format("2006-01-02 15:04:05"),
|
||
"duration": report.Duration.String(),
|
||
"deleted_count": report.DeletedCount,
|
||
"total_count_before": report.TotalCountBefore,
|
||
"active_count_before": report.ActiveCountBefore,
|
||
"total_count_after": report.TotalCountAfter,
|
||
"cutoff_time": report.CutoffTime.Format("2006-01-02 15:04:05"),
|
||
"vacuum_duration": report.VacuumDuration.String(),
|
||
"vacuum_error": report.VacuumError,
|
||
"success": report.Success,
|
||
})
|
||
}
|
||
|
||
|
||
// ============ API key management handlers ============
|
||
|
||
// APIKeyListResponse is the JSON representation of one API key, used both
// in the key listing and in the create-key response.
type APIKeyListResponse struct {
	// ID is the database identifier of the key.
	ID uint `json:"id"`
	// Key is the key string itself, returned in full.
	Key string `json:"key"`
	// CreatedAt is the creation time of the key record.
	CreatedAt time.Time `json:"created_at"`
}
|
||
|
||
// GetAPIKeysHandler 获取所有API Key列表
|
||
func (h *APIHandler) GetAPIKeysHandler(c *gin.Context) {
|
||
var apiKeys []models.APIKey
|
||
|
||
// 查询所有API Key
|
||
if err := h.DB.Order("created_at DESC").Find(&apiKeys).Error; err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch API keys"})
|
||
return
|
||
}
|
||
|
||
// 转换为响应格式
|
||
response := make([]APIKeyListResponse, len(apiKeys))
|
||
for i, key := range apiKeys {
|
||
response[i] = APIKeyListResponse{
|
||
ID: key.ID,
|
||
Key: key.Key,
|
||
CreatedAt: key.CreatedAt,
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"api_keys": response,
|
||
"total": len(response),
|
||
})
|
||
}
|
||
|
||
// CreateAPIKeyRequest is the request body for creating an API key.
type CreateAPIKeyRequest struct {
	// Key is the key string to register; required.
	Key string `json:"key" binding:"required"`
}
|
||
|
||
// CreateAPIKeyHandler 创建新的API Key
|
||
func (h *APIHandler) CreateAPIKeyHandler(c *gin.Context) {
|
||
var req CreateAPIKeyRequest
|
||
|
||
// 解析请求
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
// 验证Key不为空
|
||
if req.Key == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "API key cannot be empty"})
|
||
return
|
||
}
|
||
|
||
// 检查Key是否已存在
|
||
var existingKey models.APIKey
|
||
if err := h.DB.Where("key = ?", req.Key).First(&existingKey).Error; err == nil {
|
||
c.JSON(http.StatusConflict, gin.H{"error": "API key already exists"})
|
||
return
|
||
}
|
||
|
||
// 创建新的API Key
|
||
newAPIKey := models.APIKey{
|
||
Key: req.Key,
|
||
}
|
||
|
||
if err := h.DB.Create(&newAPIKey).Error; err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create API key"})
|
||
return
|
||
}
|
||
|
||
// 返回创建的API Key
|
||
c.JSON(http.StatusCreated, gin.H{
|
||
"message": "API key created successfully",
|
||
"api_key": APIKeyListResponse{
|
||
ID: newAPIKey.ID,
|
||
Key: newAPIKey.Key,
|
||
CreatedAt: newAPIKey.CreatedAt,
|
||
},
|
||
})
|
||
}
|
||
|
||
// DeleteAPIKeyHandler 删除指定的API Key
|
||
func (h *APIHandler) DeleteAPIKeyHandler(c *gin.Context) {
|
||
// 获取API Key ID
|
||
keyID := c.Param("id")
|
||
if keyID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "API key ID is required"})
|
||
return
|
||
}
|
||
|
||
// 查找API Key
|
||
var apiKey models.APIKey
|
||
if err := h.DB.First(&apiKey, keyID).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "API key not found"})
|
||
return
|
||
}
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find API key"})
|
||
return
|
||
}
|
||
|
||
// 删除API Key
|
||
if err := h.DB.Delete(&apiKey).Error; err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete API key"})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "API key deleted successfully",
|
||
"id": apiKey.ID,
|
||
})
|
||
}
|