diff --git a/backend/api/handlers.go b/backend/api/handlers.go
index fbd99ad..563a4b9 100644
--- a/backend/api/handlers.go
+++ b/backend/api/handlers.go
@@ -1,6 +1,7 @@
 package api
 
 import (
+	"ai-gateway/internal/billing"
 	"ai-gateway/internal/db"
 	"ai-gateway/internal/logger"
 	"ai-gateway/internal/models"
@@ -16,7 +17,6 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
-	"github.com/pkoukk/tiktoken-go"
 	"gorm.io/gorm"
 )
 
@@ -146,7 +146,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
 	}
 
 	// 使用tiktoken精确计算请求token数
-	requestTokenCount := calculateTokenCount(req.Messages)
+	messages := convertToTikTokenMessages(req.Messages)
+	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
 
 	// 选择后端模型
 	backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
@@ -273,7 +274,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
 			if choice, ok := choices[0].(map[string]interface{}); ok {
 				if message, ok := choice["message"].(map[string]interface{}); ok {
 					if content, ok := message["content"].(string); ok {
-						responseTokenCount = calculateTokenCountFromText(content)
+						responseTokenCount = billing.CalculateTextTokensSimple(content)
 					}
 				}
 			}
@@ -281,13 +282,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
 	}
 
 	// 计算费用
-	var cost float64
-	switch backendModel.BillingMethod {
-	case models.BillingMethodToken:
-		cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
-	case models.BillingMethodRequest:
-		cost = backendModel.FixedPrice
-	}
+	costCalculator := billing.NewCostCalculator()
+	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
 
 	// 创建日志记录
 	logEntry := &models.RequestLog{
@@ -325,7 +321,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
 
 	var req ResponsesRequest
 
-	// 解析请求\u4f53
+	// Parse the request body.
 	if err := c.ShouldBindJSON(&req); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{
 			"error": gin.H{
@@ -337,7 +333,8 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
 	}
 
 	// 使用tiktoken精确计算请求token数
-	requestTokenCount := calculateTokenCount(req.Messages)
+	messages := convertToTikTokenMessages(req.Messages)
+	requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
 
 	// 选择后端模型
 	backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
@@ -426,7 +423,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
 			if choice, ok := choices[0].(map[string]interface{}); ok {
 				if message, ok := choice["message"].(map[string]interface{}); ok {
 					if content, ok := message["content"].(string); ok {
-						responseTokenCount = calculateTokenCountFromText(content)
+						responseTokenCount = billing.CalculateTextTokensSimple(content)
 					}
 				}
 			}
@@ -434,13 +431,8 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
 	}
 
 	// 计算费用
-	var cost float64
-	switch backendModel.BillingMethod {
-	case models.BillingMethodToken:
-		cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
-	case models.BillingMethodRequest:
-		cost = backendModel.FixedPrice
-	}
+	costCalculator := billing.NewCostCalculator()
+	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
 
 	// 从上下文获取API密钥
 	apiKeyValue, exists := c.Get("apiKey")
@@ -480,38 +472,16 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
 	c.Writer.Write(responseBody)
 }
 
-// calculateTokenCount 计算消息列表的token总数
-func calculateTokenCount(messages []ChatCompletionMessage) int {
-	encoding, err := tiktoken.GetEncoding("cl100k_base")
-	if err != nil {
-		log.Printf("Failed to get tiktoken encoding: %v", err)
-		return 0
+// convertToTikTokenMessages converts ChatCompletionMessage values into the billing package's message type.
+func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatCompletionMessage {
+	result := make([]billing.ChatCompletionMessage, len(messages))
+	for i, msg := range messages {
+		result[i] = billing.ChatCompletionMessage{
+			Role:    msg.Role,
+			Content: msg.Content,
+		}
 	}
-
-	totalTokens := 0
-	for _, msg := range messages {
-		// 每条消息的基础开销(role + 分隔符等)
-		totalTokens += 4
-		// role的token数
-		totalTokens += len(encoding.Encode(msg.Role, nil, nil))
-		// content的token数
-		totalTokens += len(encoding.Encode(msg.Content, nil, nil))
-	}
-	// 对话的基础开销
-	totalTokens += 2
-
-	return totalTokens
-}
-
-// calculateTokenCountFromText 从文本计算token数
-func calculateTokenCountFromText(text string) int {
-	encoding, err := tiktoken.GetEncoding("cl100k_base")
-	if err != nil {
-		log.Printf("Failed to get tiktoken encoding: %v", err)
-		return 0
-	}
-
-	return len(encoding.Encode(text, nil, nil))
+	return result
 }
 
 // handleStreamingResponse 处理流式响应
@@ -607,16 +577,11 @@ func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimesta
 	}
 
 	// 计算响应 token 数
-	responseTokenCount := calculateTokenCountFromText(fullContent.String())
+	responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
 
 	// 计算费用
-	var cost float64
-	switch backendModel.BillingMethod {
-	case models.BillingMethodToken:
-		cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
-	case models.BillingMethodRequest:
-		cost = backendModel.FixedPrice
-	}
+	costCalculator := billing.NewCostCalculator()
+	cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
 
 	// 创建日志记录
 	logEntry := &models.RequestLog{
diff --git a/backend/internal/billing/calculator.go b/backend/internal/billing/calculator.go
new file mode 100644
index 0000000..e6a4988
--- /dev/null
+++ b/backend/internal/billing/calculator.go
@@ -0,0 +1,132 @@
+// Package billing provides token counting and cost calculation helpers.
+package billing
+
+import (
+	"ai-gateway/internal/models"
+	"log"
+
+	"github.com/pkoukk/tiktoken-go"
+)
+
+// TokenCalculator counts tokens with the cl100k_base tiktoken encoding.
+type TokenCalculator struct {
+	encoding *tiktoken.Tiktoken
+}
+
+// NewTokenCalculator returns a TokenCalculator backed by the cl100k_base
+// encoding, or the error reported by tiktoken.GetEncoding.
+func NewTokenCalculator() (*TokenCalculator, error) {
+	encoding, err := tiktoken.GetEncoding("cl100k_base")
+	if err != nil {
+		return nil, err
+	}
+	return &TokenCalculator{encoding: encoding}, nil
+}
+
+// CalculateMessagesTokens returns the total token count of messages.
+//
+// The total is the sum of:
+//   - 4 tokens of fixed overhead per message (role, separators, etc.),
+//   - the encoded length of each message's Role,
+//   - the encoded length of each message's Content,
+//   - 2 tokens of fixed overhead for the conversation as a whole.
+func (tc *TokenCalculator) CalculateMessagesTokens(messages []ChatCompletionMessage) int {
+	totalTokens := 0
+	for _, msg := range messages {
+		// Fixed per-message overhead (role, separators, etc.).
+		totalTokens += 4
+		totalTokens += len(tc.encoding.Encode(msg.Role, nil, nil))
+		totalTokens += len(tc.encoding.Encode(msg.Content, nil, nil))
+	}
+	// Fixed per-conversation overhead.
+	totalTokens += 2
+
+	return totalTokens
+}
+
+// CalculateTextTokens returns the number of tokens that text encodes to.
+func (tc *TokenCalculator) CalculateTextTokens(text string) int {
+	return len(tc.encoding.Encode(text, nil, nil))
+}
+
+// ChatCompletionMessage is the role/content pair used for token counting.
+type ChatCompletionMessage struct {
+	Role    string
+	Content string
+}
+
+// CostCalculator computes request costs from billing settings.
+type CostCalculator struct{}
+
+// NewCostCalculator returns a new CostCalculator.
+func NewCostCalculator() *CostCalculator {
+	return &CostCalculator{}
+}
+
+// CalculateCost returns the cost of one request for billingMethod.
+//
+// Two billing methods are supported:
+//   - models.BillingMethodToken: requestTokenCount*promptTokenPrice plus
+//     responseTokenCount*completionTokenPrice.
+//   - models.BillingMethodRequest: fixedPrice, whatever the token counts.
+//
+// Any other billing method is logged and costs 0.
+func (cc *CostCalculator) CalculateCost(
+	billingMethod string,
+	requestTokenCount int,
+	responseTokenCount int,
+	promptTokenPrice float64,
+	completionTokenPrice float64,
+	fixedPrice float64,
+) float64 {
+	switch billingMethod {
+	case models.BillingMethodToken:
+		return float64(requestTokenCount)*promptTokenPrice + float64(responseTokenCount)*completionTokenPrice
+	case models.BillingMethodRequest:
+		return fixedPrice
+	default:
+		log.Printf("Unknown billing method: %s, defaulting to 0 cost", billingMethod)
+		return 0
+	}
+}
+
+// CalculateModelCost returns the cost of one request using the billing method
+// and prices read from backendModel, which must not be nil.
+func (cc *CostCalculator) CalculateModelCost(
+	backendModel *models.BackendModel,
+	requestTokenCount int,
+	responseTokenCount int,
+) float64 {
+	return cc.CalculateCost(
+		backendModel.BillingMethod,
+		requestTokenCount,
+		responseTokenCount,
+		backendModel.PromptTokenPrice,
+		backendModel.CompletionTokenPrice,
+		backendModel.FixedPrice,
+	)
+}
+
+// CalculateMessagesTokensSimple counts the tokens in messages without the
+// caller having to create a TokenCalculator. If the encoding cannot be loaded,
+// the failure is logged and 0 is returned.
+func CalculateMessagesTokensSimple(messages []ChatCompletionMessage) int {
+	tc, err := NewTokenCalculator()
+	if err != nil {
+		log.Printf("Failed to get tiktoken encoding: %v", err)
+		return 0
+	}
+	return tc.CalculateMessagesTokens(messages)
+}
+
+// CalculateTextTokensSimple counts the tokens in text without the caller
+// having to create a TokenCalculator. If the encoding cannot be loaded, the
+// failure is logged and 0 is returned.
+func CalculateTextTokensSimple(text string) int {
+	tc, err := NewTokenCalculator()
+	if err != nil {
+		log.Printf("Failed to get tiktoken encoding: %v", err)
+		return 0
+	}
+	return tc.CalculateTextTokens(text)
+}