重构token计算逻辑,新增费用计算器,优化消息和文本token计算
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ai-gateway/internal/billing"
|
||||
"ai-gateway/internal/db"
|
||||
"ai-gateway/internal/logger"
|
||||
"ai-gateway/internal/models"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -146,7 +146,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用tiktoken精确计算请求token数
|
||||
requestTokenCount := calculateTokenCount(req.Messages)
|
||||
messages := convertToTikTokenMessages(req.Messages)
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
|
||||
@@ -273,7 +274,7 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if message, ok := choice["message"].(map[string]interface{}); ok {
|
||||
if content, ok := message["content"].(string); ok {
|
||||
responseTokenCount = calculateTokenCountFromText(content)
|
||||
responseTokenCount = billing.CalculateTextTokensSimple(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -281,13 +282,8 @@ func (h *APIHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算费用
|
||||
var cost float64
|
||||
switch backendModel.BillingMethod {
|
||||
case models.BillingMethodToken:
|
||||
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
|
||||
case models.BillingMethodRequest:
|
||||
cost = backendModel.FixedPrice
|
||||
}
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 创建日志记录
|
||||
logEntry := &models.RequestLog{
|
||||
@@ -325,7 +321,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
|
||||
var req ResponsesRequest
|
||||
|
||||
// 解析请求\u4f53
|
||||
// 解析请求体
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
@@ -337,7 +333,8 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用tiktoken精确计算请求token数
|
||||
requestTokenCount := calculateTokenCount(req.Messages)
|
||||
messages := convertToTikTokenMessages(req.Messages)
|
||||
requestTokenCount := billing.CalculateMessagesTokensSimple(messages)
|
||||
|
||||
// 选择后端模型
|
||||
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
|
||||
@@ -426,7 +423,7 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if message, ok := choice["message"].(map[string]interface{}); ok {
|
||||
if content, ok := message["content"].(string); ok {
|
||||
responseTokenCount = calculateTokenCountFromText(content)
|
||||
responseTokenCount = billing.CalculateTextTokensSimple(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -434,13 +431,8 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算费用
|
||||
var cost float64
|
||||
switch backendModel.BillingMethod {
|
||||
case models.BillingMethodToken:
|
||||
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
|
||||
case models.BillingMethodRequest:
|
||||
cost = backendModel.FixedPrice
|
||||
}
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 从上下文获取API密钥
|
||||
apiKeyValue, exists := c.Get("apiKey")
|
||||
@@ -480,38 +472,16 @@ func (h *APIHandler) ResponsesCompletions(c *gin.Context) {
|
||||
c.Writer.Write(responseBody)
|
||||
}
|
||||
|
||||
// calculateTokenCount 计算消息列表的token总数
|
||||
func calculateTokenCount(messages []ChatCompletionMessage) int {
|
||||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
log.Printf("Failed to get tiktoken encoding: %v", err)
|
||||
return 0
|
||||
// convertToTikTokenMessages 将ChatCompletionMessage转换为billing包的消息格式
|
||||
func convertToTikTokenMessages(messages []ChatCompletionMessage) []billing.ChatCompletionMessage {
|
||||
result := make([]billing.ChatCompletionMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
result[i] = billing.ChatCompletionMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
}
|
||||
}
|
||||
|
||||
totalTokens := 0
|
||||
for _, msg := range messages {
|
||||
// 每条消息的基础开销(role + 分隔符等)
|
||||
totalTokens += 4
|
||||
// role的token数
|
||||
totalTokens += len(encoding.Encode(msg.Role, nil, nil))
|
||||
// content的token数
|
||||
totalTokens += len(encoding.Encode(msg.Content, nil, nil))
|
||||
}
|
||||
// 对话的基础开销
|
||||
totalTokens += 2
|
||||
|
||||
return totalTokens
|
||||
}
|
||||
|
||||
// calculateTokenCountFromText 从文本计算token数
|
||||
func calculateTokenCountFromText(text string) int {
|
||||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
log.Printf("Failed to get tiktoken encoding: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(encoding.Encode(text, nil, nil))
|
||||
return result
|
||||
}
|
||||
|
||||
// handleStreamingResponse 处理流式响应
|
||||
@@ -607,16 +577,11 @@ func handleStreamingResponse(c *gin.Context, resp *http.Response, requestTimesta
|
||||
}
|
||||
|
||||
// 计算响应 token 数
|
||||
responseTokenCount := calculateTokenCountFromText(fullContent.String())
|
||||
responseTokenCount := billing.CalculateTextTokensSimple(fullContent.String())
|
||||
|
||||
// 计算费用
|
||||
var cost float64
|
||||
switch backendModel.BillingMethod {
|
||||
case models.BillingMethodToken:
|
||||
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
|
||||
case models.BillingMethodRequest:
|
||||
cost = backendModel.FixedPrice
|
||||
}
|
||||
costCalculator := billing.NewCostCalculator()
|
||||
cost := costCalculator.CalculateModelCost(backendModel, requestTokenCount, responseTokenCount)
|
||||
|
||||
// 创建日志记录
|
||||
logEntry := &models.RequestLog{
|
||||
|
||||
132
backend/internal/billing/calculator.go
Normal file
132
backend/internal/billing/calculator.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"ai-gateway/internal/models"
|
||||
"log"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
// TokenCalculator 提供token计算功能
|
||||
type TokenCalculator struct {
|
||||
encoding *tiktoken.Tiktoken
|
||||
}
|
||||
|
||||
// NewTokenCalculator 创建一个新的TokenCalculator实例
|
||||
func NewTokenCalculator() (*TokenCalculator, error) {
|
||||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenCalculator{encoding: encoding}, nil
|
||||
}
|
||||
|
||||
// CalculateMessagesTokens 计算消息列表的token总数
|
||||
// 根据OpenAI的token计算规则:
|
||||
// - 每条消息有4个token的基础开销(role + 分隔符等)
|
||||
// - role字段的token数
|
||||
// - content字段的token数
|
||||
// - 整个对话有2个token的基础开销
|
||||
func (tc *TokenCalculator) CalculateMessagesTokens(messages []ChatCompletionMessage) int {
|
||||
totalTokens := 0
|
||||
for _, msg := range messages {
|
||||
// 每条消息的基础开销(role + 分隔符等)
|
||||
totalTokens += 4
|
||||
// role的token数
|
||||
totalTokens += len(tc.encoding.Encode(msg.Role, nil, nil))
|
||||
// content的token数
|
||||
totalTokens += len(tc.encoding.Encode(msg.Content, nil, nil))
|
||||
}
|
||||
// 对话的基础开销
|
||||
totalTokens += 2
|
||||
|
||||
return totalTokens
|
||||
}
|
||||
|
||||
// CalculateTextTokens 从文本计算token数
|
||||
func (tc *TokenCalculator) CalculateTextTokens(text string) int {
|
||||
return len(tc.encoding.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
// ChatCompletionMessage 聊天消息结构
|
||||
type ChatCompletionMessage struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
// CostCalculator 提供费用计算功能
|
||||
type CostCalculator struct{}
|
||||
|
||||
// NewCostCalculator 创建一个新的CostCalculator实例
|
||||
func NewCostCalculator() *CostCalculator {
|
||||
return &CostCalculator{}
|
||||
}
|
||||
|
||||
// CalculateCost 根据计费方式计算费用
|
||||
// 支持两种计费方式:
|
||||
// - token: 按token计费,分别计算输入和输出token的费用
|
||||
// - request: 按请求次数计费,使用固定价格
|
||||
func (cc *CostCalculator) CalculateCost(
|
||||
billingMethod string,
|
||||
requestTokenCount int,
|
||||
responseTokenCount int,
|
||||
promptTokenPrice float64,
|
||||
completionTokenPrice float64,
|
||||
fixedPrice float64,
|
||||
) float64 {
|
||||
switch billingMethod {
|
||||
case models.BillingMethodToken:
|
||||
return float64(requestTokenCount)*promptTokenPrice + float64(responseTokenCount)*completionTokenPrice
|
||||
case models.BillingMethodRequest:
|
||||
return fixedPrice
|
||||
default:
|
||||
log.Printf("Unknown billing method: %s, defaulting to 0 cost", billingMethod)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateModelCost 根据后端模型配置计算费用
|
||||
func (cc *CostCalculator) CalculateModelCost(
|
||||
backendModel *models.BackendModel,
|
||||
requestTokenCount int,
|
||||
responseTokenCount int,
|
||||
) float64 {
|
||||
return cc.CalculateCost(
|
||||
backendModel.BillingMethod,
|
||||
requestTokenCount,
|
||||
responseTokenCount,
|
||||
backendModel.PromptTokenPrice,
|
||||
backendModel.CompletionTokenPrice,
|
||||
backendModel.FixedPrice,
|
||||
)
|
||||
}
|
||||
|
||||
// CalculateMessagesTokensSimple 简化的消息token计算函数(不需要创建calculator实例)
|
||||
func CalculateMessagesTokensSimple(messages []ChatCompletionMessage) int {
|
||||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
log.Printf("Failed to get tiktoken encoding: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
totalTokens := 0
|
||||
for _, msg := range messages {
|
||||
totalTokens += 4
|
||||
totalTokens += len(encoding.Encode(msg.Role, nil, nil))
|
||||
totalTokens += len(encoding.Encode(msg.Content, nil, nil))
|
||||
}
|
||||
totalTokens += 2
|
||||
|
||||
return totalTokens
|
||||
}
|
||||
|
||||
// CalculateTextTokensSimple 简化的文本token计算函数(不需要创建calculator实例)
|
||||
func CalculateTextTokensSimple(text string) int {
|
||||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
log.Printf("Failed to get tiktoken encoding: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(encoding.Encode(text, nil, nil))
|
||||
}
|
||||
Reference in New Issue
Block a user