重构token计算逻辑,新增费用计算器,优化消息和文本token计算

This commit is contained in:
2025-11-08 23:31:48 +08:00
parent a223b3d9a1
commit 36e2ace568
2 changed files with 156 additions and 59 deletions

View File

@@ -1,6 +1,7 @@
package api
import (
"ai-gateway/internal/billing"
"ai-gateway/internal/db"
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
@@ -16,7 +17,6 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"gorm.io/gorm"
)
@@ -146,7 +146,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
}
// 使用tiktoken精确计算请求token数
requestTokenCount := calculateTokenCount(req.Messages)
messages := convertToTikTokenMessages(req.Messages)
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
@@ -273,7 +274,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
if choice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := choice["message"].(map[string]interface{}); ok {
if content, ok := message["content"].(string); ok {
responseTokenCount = calculateTokenCountFromText(content)
responseTokenCount = billing.CalculateTextTokensSimple(content)
}
}
}
@@ -281,13 +282,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
}
// 计算费用
var cost float64
switch backendModel.BillingMethod {
case models.BillingMethodToken:
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
case models.BillingMethodRequest:
cost = backendModel.FixedPrice
}
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建日志记录
logEntry := &models.RequestLog{
@@ -325,7 +321,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
var req ResponsesRequest
// 解析请求\u4f53
// 解析请求
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
@@ -337,7 +333,8 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
}
// 使用tiktoken精确计算请求token数
requestTokenCount := calculateTokenCount(req.Messages)
messages := convertToTikTokenMessages(req.Messages)
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
@@ -426,7 +423,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
if choice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := choice["message"].(map[string]interface{}); ok {
if content, ok := message["content"].(string); ok {
responseTokenCount = calculateTokenCountFromText(content)
responseTokenCount = billing.CalculateTextTokensSimple(content)
}
}
}
@@ -434,13 +431,8 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
}
// 计算费用
var cost float64
switch backendModel.BillingMethod {
case models.BillingMethodToken:
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
case models.BillingMethodRequest:
cost = backendModel.FixedPrice
}
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 从上下文获取API密钥
apiKeyValue, exists := c.Get("apiKey")
@@ -480,38 +472,16 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
c.Writer.Write(responseBody)
}
// calculateTokenCount 计算消息列表的token总数
func calculateTokenCount(messages []ChatCompletionMessage) int {
encoding, err := tiktoken.GetEncoding("cl100k_base")
if err != nil {
log.Printf("Failed to get tiktoken encoding: %v", err)
return 0
// convertToTikTokenMessages 将ChatCompletionMessage转换为billing包的消息格式
func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatCompletionMessage {
result := make([]billing.ChatCompletionMessage, len(messages))
for i, msg := range messages {
result[i] = billing.ChatCompletionMessage{
Role: msg.Role,
Content: msg.Content,
}
}
totalTokens := 0
for _, msg := range messages {
// 每条消息的基础开销role + 分隔符等)
totalTokens += 4
// role的token数
totalTokens += len(encoding.Encode(msg.Role, nil, nil))
// content的token数
totalTokens += len(encoding.Encode(msg.Content, nil, nil))
}
// 对话的基础开销
totalTokens += 2
return totalTokens
}
// calculateTokenCountFromText 从文本计算token数
func calculateTokenCountFromText(text string) int {
encoding, err := tiktoken.GetEncoding("cl100k_base")
if err != nil {
log.Printf("Failed to get tiktoken encoding: %v", err)
return 0
}
return len(encoding.Encode(text, nil, nil))
return result
}
// handleStreamingResponse 处理流式响应
@@ -607,16 +577,11 @@ func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimesta
}
// 计算响应 token 数
responseTokenCount := calculateTokenCountFromText(fullContent.String())
responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
// 计算费用
var cost float64
switch backendModel.BillingMethod {
case models.BillingMethodToken:
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
case models.BillingMethodRequest:
cost = backendModel.FixedPrice
}
costCalculator := billing.NewCostCalculator()
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
// 创建日志记录
logEntry := &models.RequestLog{

View File

@@ -0,0 +1,132 @@
package billing
import (
"ai-gateway/internal/models"
"log"
"github.com/pkoukk/tiktoken-go"
)
// TokenCalculator 提供token计算功能
type TokenCalculator struct {
encoding *tiktoken.Tiktoken
}
// NewTokenCalculator 创建一个新的TokenCalculator实例
func NewTokenCalculator() (*TokenCalculator, error) {
encoding, err := tiktoken.GetEncoding("cl100k_base")
if err != nil {
return nil, err
}
return &TokenCalculator{encoding: encoding}, nil
}
// CalculateMessagesTokens 计算消息列表的token总数
// 根据OpenAI的token计算规则
// - 每条消息有4个token的基础开销role + 分隔符等)
// - role字段的token数
// - content字段的token数
// - 整个对话有2个token的基础开销
func (tc *TokenCalculator) CalculateMessagesTokens(messages []ChatCompletionMessage) int {
totalTokens := 0
for _, msg := range messages {
// 每条消息的基础开销role + 分隔符等)
totalTokens += 4
// role的token数
totalTokens += len(tc.encoding.Encode(msg.Role, nil, nil))
// content的token数
totalTokens += len(tc.encoding.Encode(msg.Content, nil, nil))
}
// 对话的基础开销
totalTokens += 2
return totalTokens
}
// CalculateTextTokens 从文本计算token数
func (tc *TokenCalculator) CalculateTextTokens(text string) int {
return len(tc.encoding.Encode(text, nil, nil))
}
// ChatCompletionMessage 聊天消息结构
type ChatCompletionMessage struct {
Role string
Content string
}
// CostCalculator 提供费用计算功能
type CostCalculator struct{}
// NewCostCalculator 创建一个新的CostCalculator实例
func NewCostCalculator() *CostCalculator {
return &CostCalculator{}
}
// CalculateCost 根据计费方式计算费用
// 支持两种计费方式:
// - token: 按token计费分别计算输入和输出token的费用
// - request: 按请求次数计费,使用固定价格
func (cc *CostCalculator) CalculateCost(
billingMethod string,
requestTokenCount int,
responseTokenCount int,
promptTokenPrice float64,
completionTokenPrice float64,
fixedPrice float64,
) float64 {
switch billingMethod {
case models.BillingMethodToken:
return float64(requestTokenCount)*promptTokenPrice + float64(responseTokenCount)*completionTokenPrice
case models.BillingMethodRequest:
return fixedPrice
default:
log.Printf("Unknown billing method: %s, defaulting to 0 cost", billingMethod)
return 0
}
}
// CalculateModelCost 根据后端模型配置计算费用
func (cc *CostCalculator) CalculateModelCost(
backendModel *models.BackendModel,
requestTokenCount int,
responseTokenCount int,
) float64 {
return cc.CalculateCost(
backendModel.BillingMethod,
requestTokenCount,
responseTokenCount,
backendModel.PromptTokenPrice,
backendModel.CompletionTokenPrice,
backendModel.FixedPrice,
)
}
// CalculateMessagesTokensSimple 简化的消息token计算函数不需要创建calculator实例
func CalculateMessagesTokensSimple(messages []ChatCompletionMessage) int {
encoding, err := tiktoken.GetEncoding("cl100k_base")
if err != nil {
log.Printf("Failed to get tiktoken encoding: %v", err)
return 0
}
totalTokens := 0
for _, msg := range messages {
totalTokens += 4
totalTokens += len(encoding.Encode(msg.Role, nil, nil))
totalTokens += len(encoding.Encode(msg.Content, nil, nil))
}
totalTokens += 2
return totalTokens
}
// CalculateTextTokensSimple 简化的文本token计算函数不需要创建calculator实例
func CalculateTextTokensSimple(text string) int {
encoding, err := tiktoken.GetEncoding("cl100k_base")
if err != nil {
log.Printf("Failed to get tiktoken encoding: %v", err)
return 0
}
return len(encoding.Encode(text, nil, nil))
}