重构API处理逻辑,新增provider和virtual model处理功能,优化请求体解析和日志记录
This commit is contained in:
@@ -2,18 +2,11 @@ package api
|
||||
|
||||
import (
|
||||
"ai-gateway/internal/billing"
|
||||
"ai-gateway/internal/db"
|
||||
"ai-gateway/internal/logger"
|
||||
"ai-gateway/internal/models"
|
||||
"ai-gateway/internal/router"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -88,328 +81,6 @@ type UpdateVirtualModelRequest struct {
|
||||
BackendModels []BackendModelAssociation `json:"backend_models"`
|
||||
}
|
||||
|
||||
// ListModels 处理 GET /models 请求
|
||||
func (h *APIHandler) ListModels(c *gin.Context) {
|
||||
var virtualModels []models.VirtualModel
|
||||
|
||||
// 查询所有虚拟模型
|
||||
if err := h.DB.Find(&virtualModels).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to retrieve models",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 格式化为OpenAI API响应格式
|
||||
modelData := make([]ModelData, len(virtualModels))
|
||||
for i, vm := range virtualModels {
|
||||
modelData[i] = ModelData{
|
||||
ID: vm.Name,
|
||||
Object: "model",
|
||||
Created: vm.CreatedAt.Unix(),
|
||||
OwnedBy: "ai-gateway",
|
||||
}
|
||||
}
|
||||
|
||||
response := ModelListResponse{
|
||||
Object: "list",
|
||||
Data: modelData,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// ChatCompletions 处理 POST /v1/chat/completions 请求
|
||||
func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
requestTimestamp := time.Now()
|
||||
|
||||
// 读取请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
// 解析为 map,只解析一次
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
log.Printf("Failed to unmarshal JSON: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid request format",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从 map 中提取需要的字段
|
||||
modelName, _ := jsonBody["model"].(string)
|
||||
isStream, _ := jsonBody["stream"].(bool)
|
||||
|
||||
// 提取 messages 并转换为计算 token 所需的格式
|
||||
var messages []billing.ChatCompletionMessage
|
||||
if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
|
||||
messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
|
||||
for _, msgRaw := range messagesRaw {
|
||||
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, _ := msgMap["content"].(string)
|
||||
messages = append(messages, billing.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算请求token数
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 修改 model 字段为后端模型名称
|
||||
jsonBody["model"] = backendModel.Name
|
||||
|
||||
// Marshal 请求体用于转发
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to process request",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建后端API URL
|
||||
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
|
||||
|
||||
// 转发请求到后端
|
||||
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to connect to backend service",
|
||||
"type": "service_unavailable",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 记录响应时间
|
||||
responseTimestamp := time.Now()
|
||||
|
||||
// 获取API Key ID
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 处理流式响应
|
||||
if isStream {
|
||||
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
|
||||
apiKeyID, modelName, backendModel, requestTokenCount,
|
||||
string(requestBody), h.DB)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理非流式响应
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to read backend response",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Backend Response Status: %s", resp.Status)
|
||||
log.Printf("Backend Response Body: %s", string(responseBody))
|
||||
|
||||
// 计算响应token数
|
||||
responseTokenCount := extractResponseTokenCount(responseBody)
|
||||
|
||||
// 计算费用
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
|
||||
// 复制响应头并返回响应
|
||||
copyResponseHeaders(c, resp)
|
||||
c.Status(resp.StatusCode)
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
|
||||
// ResponsesCompletions 处理 POST /v1/responses 请求
|
||||
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
requestTimestamp := time.Now()
|
||||
|
||||
// 读取请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
// 解析为 map,只解析一次
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid request format",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从 map 中提取需要的字段
|
||||
modelName, _ := jsonBody["model"].(string)
|
||||
|
||||
// 提取 input 并转换为计算 token 所需的格式
|
||||
var messages []billing.ChatCompletionMessage
|
||||
if inputRaw, ok := jsonBody["input"].([]interface{}); ok {
|
||||
messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw))
|
||||
for _, msgRaw := range inputRaw {
|
||||
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, _ := msgMap["content"].(string)
|
||||
messages = append(messages, billing.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算请求token数
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 修改 model 字段为后端模型名称
|
||||
jsonBody["model"] = backendModel.Name
|
||||
|
||||
// Marshal 请求体用于转发
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to process request",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建后端API URL
|
||||
backendURL := backendModel.Provider.BaseURL + "/v1/responses"
|
||||
|
||||
// 转发请求到后端
|
||||
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to connect to backend service",
|
||||
"type": "service_unavailable",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 记录响应时间
|
||||
responseTimestamp := time.Now()
|
||||
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to read backend response",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算响应token数
|
||||
responseTokenCount := extractResponseTokenCount(responseBody)
|
||||
|
||||
// 计算费用
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 获取API Key ID
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
|
||||
// 复制响应头并返回响应
|
||||
copyResponseHeaders(c, resp)
|
||||
c.Status(resp.StatusCode)
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
|
||||
// convertToTikTokenMessages 将ChatCompletionMessage转换为billing包的消息格式
|
||||
func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatCompletionMessage {
|
||||
result := make([]billing.ChatCompletionMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
result[i] = billing.ChatCompletionMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// prepareBackendRequest 准备后端请求体,替换model字段
|
||||
func prepareBackendRequest(bodyBytes []byte, backendModelName string) ([]byte, error) {
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal request body: %w", err)
|
||||
}
|
||||
|
||||
// 修改model字段为后端模型名称
|
||||
jsonBody["model"] = backendModelName
|
||||
|
||||
// 重新marshal
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
// forwardRequest 转发请求到后端API
|
||||
func forwardRequest(backendURL string, apiKey string, requestBody []byte, userAgent string) (*http.Response, error) {
|
||||
// 创建HTTP请求
|
||||
@@ -501,297 +172,3 @@ func copyResponseHeaders(c *gin.Context, resp *http.Response) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingResponse 处理流式响应
|
||||
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
|
||||
apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
|
||||
requestBody string, database *gorm.DB) {
|
||||
|
||||
// 设置流式响应\u5934
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// 复制其他响应头
|
||||
for key, values := range resp.Header {
|
||||
if key != "Content-Length" && key != "Transfer-Encoding" {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
// 用于累积完整响应内容
|
||||
var fullContent strings.Builder
|
||||
var responseBody strings.Builder
|
||||
|
||||
// 创建一个 scanner 来逐行读取流
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// 将原始行写入响应体缓冲区
|
||||
responseBody.WriteString(line)
|
||||
responseBody.WriteString("\n")
|
||||
|
||||
// SSE 格式:data: {...}
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
// 检查是否是结束标记
|
||||
if data == "[DONE]" {
|
||||
_, err := c.Writer.Write([]byte(line + "\n\n"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 解析 JSON 数据以提取内容
|
||||
var chunk map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
|
||||
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if delta, ok := choice["delta"].(map[string]interface{}); ok {
|
||||
if content, ok := delta["content"].(string); ok {
|
||||
fullContent.WriteString(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 转发数据到客户端
|
||||
_, err := c.Writer.Write([]byte(line + "\n"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是空行(SSE 消息分隔符),刷新
|
||||
if line == "" {
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 确保发送最后的数据
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// 扫描可能出现的错误
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error reading stream: %v", err)
|
||||
}
|
||||
|
||||
// 计算响应 token 数
|
||||
responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
|
||||
|
||||
// 计算费用
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 创建日志记录
|
||||
logEntry := &models.RequestLog{
|
||||
APIKeyID: apiKeyID,
|
||||
VirtualModelName: virtualModelName,
|
||||
BackendModelName: backendModel.Name,
|
||||
RequestTimestamp: requestTimestamp,
|
||||
ResponseTimestamp: time.Now(), // 使用实际结束时间
|
||||
RequestTokens: requestTokenCount,
|
||||
ResponseTokens: responseTokenCount,
|
||||
Cost: cost,
|
||||
RequestBody: requestBody,
|
||||
ResponseBody: responseBody.String(),
|
||||
}
|
||||
|
||||
// 异步记录日志
|
||||
logger.LogRequest(database, logEntry)
|
||||
}
|
||||
|
||||
// GetProvidersHandler 处理 GET /api/providers 请求
|
||||
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
|
||||
providers, err := db.GetProviders(h.DB)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to retrieve providers",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, providers)
|
||||
}
|
||||
|
||||
// GetProviderHandler 处理 GET /api/providers/:id 请求
|
||||
func (h *APIHandler) GetProviderHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var providerID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := db.GetProviderByID(h.DB, providerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, provider)
|
||||
}
|
||||
|
||||
// CreateProviderHandler 处理 POST /api/providers 请求
|
||||
func (h *APIHandler) CreateProviderHandler(c *gin.Context) {
|
||||
var provider models.Provider
|
||||
if err := c.ShouldBindJSON(&provider); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.CreateProvider(h.DB, &provider); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
// UpdateProviderHandler 处理 PUT /api/providers/:id 请求
|
||||
func (h *APIHandler) UpdateProviderHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var providerID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var provider models.Provider
|
||||
if err := c.ShouldBindJSON(&provider); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
provider.ID = providerID
|
||||
|
||||
if err := db.UpdateProvider(h.DB, &provider); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, provider)
|
||||
}
|
||||
|
||||
// DeleteProviderHandler 处理 DELETE /api/providers/:id 请求
|
||||
func (h *APIHandler) DeleteProviderHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var providerID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.DeleteProvider(h.DB, providerID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
// CreateVirtualModelHandler 处理 POST /api/virtual-models 请求
|
||||
func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) {
|
||||
var vm models.VirtualModel
|
||||
if err := c.ShouldBindJSON(&vm); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.CreateVirtualModel(h.DB, &vm); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create virtual model"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, vm)
|
||||
}
|
||||
|
||||
// GetVirtualModelsHandler 处理 GET /api/virtual-models 请求
|
||||
func (h *APIHandler) GetVirtualModelsHandler(c *gin.Context) {
|
||||
virtualModels, err := db.GetVirtualModels(h.DB)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve virtual models"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, virtualModels)
|
||||
}
|
||||
|
||||
// GetVirtualModelHandler 处理 GET /api/virtual-models/:id 请求
|
||||
func (h *APIHandler) GetVirtualModelHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var virtualModelID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
|
||||
return
|
||||
}
|
||||
|
||||
virtualModel, err := db.GetVirtualModelByID(h.DB, virtualModelID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Virtual model not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, virtualModel)
|
||||
}
|
||||
|
||||
// UpdateVirtualModelHandler 处理 PUT /api/virtual-models/:id 请求
|
||||
func (h *APIHandler) UpdateVirtualModelHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var virtualModelID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var vm models.VirtualModel
|
||||
if err := c.ShouldBindJSON(&vm); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
vm.ID = virtualModelID
|
||||
|
||||
if err := db.UpdateVirtualModel(h.DB, &vm); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update virtual model"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, vm)
|
||||
}
|
||||
|
||||
// DeleteVirtualModelHandler 处理 DELETE /api/virtual-models/:id 请求
|
||||
func (h *APIHandler) DeleteVirtualModelHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var virtualModelID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.DeleteVirtualModel(h.DB, virtualModelID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete virtual model"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
426
backend/api/openai_handlers.go
Normal file
426
backend/api/openai_handlers.go
Normal file
@@ -0,0 +1,426 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ai-gateway/internal/billing"
|
||||
"ai-gateway/internal/logger"
|
||||
"ai-gateway/internal/models"
|
||||
"ai-gateway/internal/router"
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ListModels 处理 GET /models 请求
|
||||
func (h *APIHandler) ListModels(c *gin.Context) {
|
||||
var virtualModels []models.VirtualModel
|
||||
|
||||
// 查询所有虚拟模型
|
||||
if err := h.DB.Find(&virtualModels).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to retrieve models",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 格式化为OpenAI API响应格式
|
||||
modelData := make([]ModelData, len(virtualModels))
|
||||
for i, vm := range virtualModels {
|
||||
modelData[i] = ModelData{
|
||||
ID: vm.Name,
|
||||
Object: "model",
|
||||
Created: vm.CreatedAt.Unix(),
|
||||
OwnedBy: "ai-gateway",
|
||||
}
|
||||
}
|
||||
|
||||
response := ModelListResponse{
|
||||
Object: "list",
|
||||
Data: modelData,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// ChatCompletions 处理 POST /v1/chat/completions 请求
|
||||
func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
requestTimestamp := time.Now()
|
||||
|
||||
// 读取请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
// 解析为 map,只解析一次
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
log.Printf("Failed to unmarshal JSON: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid request format",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从 map 中提取需要的字段
|
||||
modelName, _ := jsonBody["model"].(string)
|
||||
isStream, _ := jsonBody["stream"].(bool)
|
||||
|
||||
// 提取 messages 并转换为计算 token 所需的格式
|
||||
var messages []billing.ChatCompletionMessage
|
||||
if messagesRaw, ok := jsonBody["messages"].([]interface{}); ok {
|
||||
messages = make([]billing.ChatCompletionMessage, 0, len(messagesRaw))
|
||||
for _, msgRaw := range messagesRaw {
|
||||
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, _ := msgMap["content"].(string)
|
||||
messages = append(messages, billing.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算请求token数
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 修改 model 字段为后端模型名称
|
||||
jsonBody["model"] = backendModel.Name
|
||||
|
||||
// Marshal 请求体用于转发
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to process request",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建后端API URL
|
||||
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
|
||||
|
||||
// 转发请求到后端
|
||||
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to connect to backend service",
|
||||
"type": "service_unavailable",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 记录响应时间
|
||||
responseTimestamp := time.Now()
|
||||
|
||||
// 获取API Key ID
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 处理流式响应
|
||||
if isStream {
|
||||
handleStreamingResponse(c, resp, requestTimestamp, responseTimestamp,
|
||||
apiKeyID, modelName, backendModel, requestTokenCount,
|
||||
string(requestBody), h.DB)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理非流式响应
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to read backend response",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Backend Response Status: %s", resp.Status)
|
||||
log.Printf("Backend Response Body: %s", string(responseBody))
|
||||
|
||||
// 计算响应token数
|
||||
responseTokenCount := extractResponseTokenCount(responseBody)
|
||||
|
||||
// 计算费用
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
|
||||
// 复制响应头并返回响应
|
||||
copyResponseHeaders(c, resp)
|
||||
c.Status(resp.StatusCode)
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
|
||||
// ResponsesCompletions 处理 POST /v1/responses 请求
|
||||
func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
requestTimestamp := time.Now()
|
||||
|
||||
// 读取请求体
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
// 解析为 map,只解析一次
|
||||
var jsonBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid request format",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从 map 中提取需要的字段
|
||||
modelName, _ := jsonBody["model"].(string)
|
||||
|
||||
// 提取 input 并转换为计算 token 所需的格式
|
||||
var messages []billing.ChatCompletionMessage
|
||||
if inputRaw, ok := jsonBody["input"].([]interface{}); ok {
|
||||
messages = make([]billing.ChatCompletionMessage, 0, len(inputRaw))
|
||||
for _, msgRaw := range inputRaw {
|
||||
if msgMap, ok := msgRaw.(map[string]interface{}); ok {
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, _ := msgMap["content"].(string)
|
||||
messages = append(messages, billing.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算请求token数
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, modelName, requestTokenCount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 修改 model 字段为后端模型名称
|
||||
jsonBody["model"] = backendModel.Name
|
||||
|
||||
// Marshal 请求体用于转发
|
||||
requestBody, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to process request",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建后端API URL
|
||||
backendURL := backendModel.Provider.BaseURL + "/v1/responses"
|
||||
|
||||
// 转发请求到后端
|
||||
resp, err := forwardRequest(backendURL, backendModel.Provider.ApiKey, requestBody, c.GetHeader("User-Agent"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to connect to backend service",
|
||||
"type": "service_unavailable",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 记录响应时间
|
||||
responseTimestamp := time.Now()
|
||||
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to read backend response",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算响应token数
|
||||
responseTokenCount := extractResponseTokenCount(responseBody)
|
||||
|
||||
// 计算费用
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 获取API Key ID
|
||||
apiKeyID := getAPIKeyID(c)
|
||||
|
||||
// 创建并记录日志
|
||||
logEntry := createRequestLog(apiKeyID, modelName, backendModel,
|
||||
requestTimestamp, responseTimestamp, requestTokenCount, responseTokenCount,
|
||||
cost, string(requestBody), string(responseBody))
|
||||
logger.LogRequest(h.DB, logEntry)
|
||||
|
||||
// 复制响应头并返回响应
|
||||
copyResponseHeaders(c, resp)
|
||||
c.Status(resp.StatusCode)
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
|
||||
// handleStreamingResponse 处理流式响应
|
||||
func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimestamp, responseTimestamp time.Time,
|
||||
apiKeyID uint, virtualModelName string, backendModel *models.BackendModel, requestTokenCount int,
|
||||
requestBody string, database *gorm.DB) {
|
||||
|
||||
// 设置流式响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// 复制其他响应头
|
||||
for key, values := range resp.Header {
|
||||
if key != "Content-Length" && key != "Transfer-Encoding" {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
// 用于累积完整响应内容
|
||||
var fullContent strings.Builder
|
||||
var responseBody strings.Builder
|
||||
|
||||
// 创建一个 scanner 来逐行读取流
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// 将原始行写入响应体缓冲区
|
||||
responseBody.WriteString(line)
|
||||
responseBody.WriteString("\n")
|
||||
|
||||
// SSE 格式:data: {...}
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
// 检查是否是结束标记
|
||||
if data == "[DONE]" {
|
||||
_, err := c.Writer.Write([]byte(line + "\n\n"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 解析 JSON 数据以提取内容
|
||||
var chunk map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
|
||||
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if delta, ok := choice["delta"].(map[string]interface{}); ok {
|
||||
if content, ok := delta["content"].(string); ok {
|
||||
fullContent.WriteString(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 转发数据到客户端
|
||||
_, err := c.Writer.Write([]byte(line + "\n"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是空行(SSE 消息分隔符),刷新
|
||||
if line == "" {
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 确保发送最后的数据
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// 扫描可能出现的错误
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error reading stream: %v", err)
|
||||
}
|
||||
|
||||
// 计算响应 token 数
|
||||
responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
|
||||
|
||||
// 计算费用
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 创建日志记录
|
||||
logEntry := &models.RequestLog{
|
||||
APIKeyID: apiKeyID,
|
||||
VirtualModelName: virtualModelName,
|
||||
BackendModelName: backendModel.Name,
|
||||
RequestTimestamp: requestTimestamp,
|
||||
ResponseTimestamp: time.Now(), // 使用实际结束时间
|
||||
RequestTokens: requestTokenCount,
|
||||
ResponseTokens: responseTokenCount,
|
||||
Cost: cost,
|
||||
RequestBody: requestBody,
|
||||
ResponseBody: responseBody.String(),
|
||||
}
|
||||
|
||||
// 异步记录日志
|
||||
logger.LogRequest(database, logEntry)
|
||||
}
|
||||
101
backend/api/provider_handlers.go
Normal file
101
backend/api/provider_handlers.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ai-gateway/internal/db"
|
||||
"ai-gateway/internal/models"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetProvidersHandler 处理 GET /api/providers 请求
|
||||
func (h *APIHandler) GetProvidersHandler(c *gin.Context) {
|
||||
providers, err := db.GetProviders(h.DB)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to retrieve providers",
|
||||
"type": "internal_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, providers)
|
||||
}
|
||||
|
||||
// GetProviderHandler 处理 GET /api/providers/:id 请求
|
||||
func (h *APIHandler) GetProviderHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var providerID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := db.GetProviderByID(h.DB, providerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, provider)
|
||||
}
|
||||
|
||||
// CreateProviderHandler 处理 POST /api/providers 请求
|
||||
func (h *APIHandler) CreateProviderHandler(c *gin.Context) {
|
||||
var provider models.Provider
|
||||
if err := c.ShouldBindJSON(&provider); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.CreateProvider(h.DB, &provider); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
// UpdateProviderHandler 处理 PUT /api/providers/:id 请求
|
||||
func (h *APIHandler) UpdateProviderHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var providerID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var provider models.Provider
|
||||
if err := c.ShouldBindJSON(&provider); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
provider.ID = providerID
|
||||
|
||||
if err := db.UpdateProvider(h.DB, &provider); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, provider)
|
||||
}
|
||||
|
||||
// DeleteProviderHandler 处理 DELETE /api/providers/:id 请求
|
||||
func (h *APIHandler) DeleteProviderHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var providerID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.DeleteProvider(h.DB, providerID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNoContent, nil)
|
||||
}
|
||||
96
backend/api/virtual_model_handlers.go
Normal file
96
backend/api/virtual_model_handlers.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ai-gateway/internal/db"
|
||||
"ai-gateway/internal/models"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CreateVirtualModelHandler 处理 POST /api/virtual-models 请求
|
||||
func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) {
|
||||
var vm models.VirtualModel
|
||||
if err := c.ShouldBindJSON(&vm); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.CreateVirtualModel(h.DB, &vm); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create virtual model"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, vm)
|
||||
}
|
||||
|
||||
// GetVirtualModelsHandler 处理 GET /api/virtual-models 请求
|
||||
func (h *APIHandler) GetVirtualModelsHandler(c *gin.Context) {
|
||||
virtualModels, err := db.GetVirtualModels(h.DB)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve virtual models"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, virtualModels)
|
||||
}
|
||||
|
||||
// GetVirtualModelHandler 处理 GET /api/virtual-models/:id 请求
|
||||
func (h *APIHandler) GetVirtualModelHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var virtualModelID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
|
||||
return
|
||||
}
|
||||
|
||||
virtualModel, err := db.GetVirtualModelByID(h.DB, virtualModelID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Virtual model not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, virtualModel)
|
||||
}
|
||||
|
||||
// UpdateVirtualModelHandler 处理 PUT /api/virtual-models/:id 请求
|
||||
func (h *APIHandler) UpdateVirtualModelHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var virtualModelID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var vm models.VirtualModel
|
||||
if err := c.ShouldBindJSON(&vm); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
vm.ID = virtualModelID
|
||||
|
||||
if err := db.UpdateVirtualModel(h.DB, &vm); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update virtual model"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, vm)
|
||||
}
|
||||
|
||||
// DeleteVirtualModelHandler 处理 DELETE /api/virtual-models/:id 请求
|
||||
func (h *APIHandler) DeleteVirtualModelHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var virtualModelID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &virtualModelID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid virtual model ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.DeleteVirtualModel(h.DB, virtualModelID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete virtual model"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNoContent, nil)
|
||||
}
|
||||
Reference in New Issue
Block a user