Files
AIRouter/backend/internal/billing/calculator.go

133 lines
3.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package billing
import (
"ai-gateway/internal/models"
"log"
"github.com/pkoukk/tiktoken-go"
)
// TokenCalculator counts tokens in plain text and in chat messages using a
// tiktoken encoding.
type TokenCalculator struct {
	// encoding is the tokenizer every counting method calls Encode on; it is
	// populated by NewTokenCalculator.
	encoding *tiktoken.Tiktoken
}
// NewTokenCalculator builds a TokenCalculator backed by the "cl100k_base"
// tiktoken encoding.
//
// When tiktoken.GetEncoding reports an error, that error is returned
// unchanged together with a nil calculator.
func NewTokenCalculator() (*TokenCalculator, error) {
	bpe, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		return nil, err
	}

	calc := new(TokenCalculator)
	calc.encoding = bpe
	return calc, nil
}
// CalculateMessagesTokens returns the total token count for a list of chat
// messages.
//
// The count follows OpenAI's token accounting rules:
//   - every message carries 4 tokens of fixed overhead (role, separators,
//     and so on),
//   - plus the tokens of its Role,
//   - plus the tokens of its Content,
//   - and the conversation as a whole carries 2 tokens of fixed overhead.
//
// A nil or empty slice therefore yields 2.
func (tc *TokenCalculator) CalculateMessagesTokens(messages []ChatCompletionMessage) int {
	const (
		perMessageOverhead   = 4
		conversationOverhead = 2
	)

	total := conversationOverhead
	for i := range messages {
		roleTokens := tc.encoding.Encode(messages[i].Role, nil, nil)
		contentTokens := tc.encoding.Encode(messages[i].Content, nil, nil)
		total += perMessageOverhead + len(roleTokens) + len(contentTokens)
	}
	return total
}
// CalculateTextTokens returns how many tokens the calculator's encoding
// produces for text.
func (tc *TokenCalculator) CalculateTextTokens(text string) int {
	tokens := tc.encoding.Encode(text, nil, nil)
	return len(tokens)
}
// ChatCompletionMessage is a single chat message as seen by the token
// counting functions in this package.
type ChatCompletionMessage struct {
	// Role is the message's role string; its tokens are counted alongside
	// the content.
	Role string
	// Content is the message text whose tokens are counted.
	Content string
}
// CostCalculator provides cost calculation. It is an empty struct and holds
// no state of its own.
type CostCalculator struct{}
// NewCostCalculator returns a pointer to a new CostCalculator.
func NewCostCalculator() *CostCalculator {
	return new(CostCalculator)
}
// CalculateCost returns the charge for one request under the given billing
// method.
//
// Two billing methods are supported:
//   - models.BillingMethodToken: prompt and completion tokens are charged
//     separately; promptTokenPrice and completionTokenPrice are prices per
//     1,000,000 tokens.
//   - models.BillingMethodRequest: a flat per-request charge of fixedPrice;
//     the token counts and token prices are ignored.
//
// Any other billingMethod yields 0.
func (cc *CostCalculator) CalculateCost(
	billingMethod string,
	requestTokenCount int,
	responseTokenCount int,
	promptTokenPrice float64,
	completionTokenPrice float64,
	fixedPrice float64,
) float64 {
	// Token prices are quoted per this many tokens.
	const tokensPerPriceUnit = 1_000_000

	if billingMethod == models.BillingMethodToken {
		promptCost := float64(requestTokenCount) * promptTokenPrice / tokensPerPriceUnit
		completionCost := float64(responseTokenCount) * completionTokenPrice / tokensPerPriceUnit
		return promptCost + completionCost
	}
	if billingMethod == models.BillingMethodRequest {
		return fixedPrice
	}
	return 0
}
// CalculateModelCost computes the charge for one request from a backend
// model's billing configuration.
//
// It reads BillingMethod, PromptTokenPrice, CompletionTokenPrice and
// FixedPrice from backendModel and forwards them, together with the token
// counts, to CalculateCost. backendModel is dereferenced without a nil check,
// so it must be non-nil.
func (cc *CostCalculator) CalculateModelCost(
	backendModel *models.BackendModel,
	requestTokenCount int,
	responseTokenCount int,
) float64 {
	method := backendModel.BillingMethod
	promptPrice := backendModel.PromptTokenPrice
	completionPrice := backendModel.CompletionTokenPrice
	flatPrice := backendModel.FixedPrice

	return cc.CalculateCost(method, requestTokenCount, responseTokenCount, promptPrice, completionPrice, flatPrice)
}
// CalculateMessagesTokensSimple counts the tokens of a chat message list
// without requiring the caller to create a TokenCalculator.
//
// It builds a calculator with NewTokenCalculator on every call and delegates
// to TokenCalculator.CalculateMessagesTokens, so the counting rules live in a
// single place. If the encoding cannot be obtained, the failure is logged and
// 0 is returned; a successful count is always at least 2, so 0 only ever
// signals that failure.
func CalculateMessagesTokensSimple(messages []ChatCompletionMessage) int {
	tc, err := NewTokenCalculator()
	if err != nil {
		log.Printf("Failed to get tiktoken encoding: %v", err)
		return 0
	}
	return tc.CalculateMessagesTokens(messages)
}
// CalculateTextTokensSimple 简化的文本token计算函数不需要创建calculator实例
func CalculateTextTokensSimple(text string) int {
encoding, err := tiktoken.GetEncoding("cl100k_base")
if err != nil {
log.Printf("Failed to get tiktoken encoding: %v", err)
return 0
}
return len(encoding.Encode(text, nil, nil))
}