133 lines
3.6 KiB
Go
133 lines
3.6 KiB
Go
package billing
|
||
|
||
import (
	"log"
	"sync"

	"github.com/pkoukk/tiktoken-go"

	"ai-gateway/internal/models"
)
|
||
|
||
// TokenCalculator 提供token计算功能
|
||
type TokenCalculator struct {
|
||
encoding *tiktoken.Tiktoken
|
||
}
|
||
|
||
// NewTokenCalculator 创建一个新的TokenCalculator实例
|
||
func NewTokenCalculator() (*TokenCalculator, error) {
|
||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &TokenCalculator{encoding: encoding}, nil
|
||
}
|
||
|
||
// CalculateMessagesTokens 计算消息列表的token总数
|
||
// 根据OpenAI的token计算规则:
|
||
// - 每条消息有4个token的基础开销(role + 分隔符等)
|
||
// - role字段的token数
|
||
// - content字段的token数
|
||
// - 整个对话有2个token的基础开销
|
||
func (tc *TokenCalculator) CalculateMessagesTokens(messages []ChatCompletionMessage) int {
|
||
totalTokens := 0
|
||
for _, msg := range messages {
|
||
// 每条消息的基础开销(role + 分隔符等)
|
||
totalTokens += 4
|
||
// role的token数
|
||
totalTokens += len(tc.encoding.Encode(msg.Role, nil, nil))
|
||
// content的token数
|
||
totalTokens += len(tc.encoding.Encode(msg.Content, nil, nil))
|
||
}
|
||
// 对话的基础开销
|
||
totalTokens += 2
|
||
|
||
return totalTokens
|
||
}
|
||
|
||
// CalculateTextTokens 从文本计算token数
|
||
func (tc *TokenCalculator) CalculateTextTokens(text string) int {
|
||
return len(tc.encoding.Encode(text, nil, nil))
|
||
}
|
||
|
||
// ChatCompletionMessage is the minimal chat-message shape used for token
// counting: the message's role and its text content, in that field order.
type ChatCompletionMessage struct {
	Role, Content string
}
|
||
|
||
// CostCalculator turns token counts and pricing settings into a charge.
// It carries no state, so its zero value is ready to use.
type CostCalculator struct{}
|
||
|
||
// NewCostCalculator 创建一个新的CostCalculator实例
|
||
func NewCostCalculator() *CostCalculator {
|
||
return &CostCalculator{}
|
||
}
|
||
|
||
// CalculateCost 根据计费方式计算费用
|
||
// 支持两种计费方式:
|
||
// - token: 按token计费,分别计算输入和输出token的费用
|
||
// - request: 按请求次数计费,使用固定价格
|
||
func (cc *CostCalculator) CalculateCost(
|
||
billingMethod string,
|
||
requestTokenCount int,
|
||
responseTokenCount int,
|
||
promptTokenPrice float64,
|
||
completionTokenPrice float64,
|
||
fixedPrice float64,
|
||
) float64 {
|
||
switch billingMethod {
|
||
case models.BillingMethodToken:
|
||
cost := float64(requestTokenCount)*promptTokenPrice/1_000_000 + float64(responseTokenCount)*completionTokenPrice/1_000_000
|
||
return cost
|
||
case models.BillingMethodRequest:
|
||
return fixedPrice
|
||
default:
|
||
return 0
|
||
}
|
||
}
|
||
|
||
// CalculateModelCost 根据后端模型配置计算费用
|
||
func (cc *CostCalculator) CalculateModelCost(
|
||
backendModel *models.BackendModel,
|
||
requestTokenCount int,
|
||
responseTokenCount int,
|
||
) float64 {
|
||
return cc.CalculateCost(
|
||
backendModel.BillingMethod,
|
||
requestTokenCount,
|
||
responseTokenCount,
|
||
backendModel.PromptTokenPrice,
|
||
backendModel.CompletionTokenPrice,
|
||
backendModel.FixedPrice,
|
||
)
|
||
}
|
||
|
||
// CalculateMessagesTokensSimple 简化的消息token计算函数(不需要创建calculator实例)
|
||
func CalculateMessagesTokensSimple(messages []ChatCompletionMessage) int {
|
||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||
if err != nil {
|
||
log.Printf("Failed to get tiktoken encoding: %v", err)
|
||
return 0
|
||
}
|
||
|
||
totalTokens := 0
|
||
for _, msg := range messages {
|
||
totalTokens += 4
|
||
totalTokens += len(encoding.Encode(msg.Role, nil, nil))
|
||
totalTokens += len(encoding.Encode(msg.Content, nil, nil))
|
||
}
|
||
totalTokens += 2
|
||
|
||
return totalTokens
|
||
}
|
||
|
||
// CalculateTextTokensSimple 简化的文本token计算函数(不需要创建calculator实例)
|
||
func CalculateTextTokensSimple(text string) int {
|
||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||
if err != nil {
|
||
log.Printf("Failed to get tiktoken encoding: %v", err)
|
||
return 0
|
||
}
|
||
|
||
return len(encoding.Encode(text, nil, nil))
|
||
}
|