diff --git a/backend/internal/billing/calculator.go b/backend/internal/billing/calculator.go index e6a4988..cad5c2f 100644 --- a/backend/internal/billing/calculator.go +++ b/backend/internal/billing/calculator.go @@ -76,11 +76,11 @@ func (cc *CostCalculator) CalculateCost( ) float64 { switch billingMethod { case models.BillingMethodToken: - return float64(requestTokenCount)*promptTokenPrice + float64(responseTokenCount)*completionTokenPrice + cost := float64(requestTokenCount)*promptTokenPrice/1_000_000 + float64(responseTokenCount)*completionTokenPrice/1_000_000 + return cost case models.BillingMethodRequest: return fixedPrice default: - log.Printf("Unknown billing method: %s, defaulting to 0 cost", billingMethod) return 0 } } diff --git a/backend/internal/router/selector.go b/backend/internal/router/selector.go index bda72f6..48ae552 100644 --- a/backend/internal/router/selector.go +++ b/backend/internal/router/selector.go @@ -1,6 +1,7 @@ package router import ( + "ai-gateway/internal/billing" "ai-gateway/internal/models" "errors" "sort" @@ -48,24 +49,15 @@ func SelectBackendModel(db *gorm.DB, virtualModelName string, requestTokenCount }) // 选择合适的模型(考虑每个后端模型的成本阈值) - // 估算响应token数(假设等于请求token数) - estimatedResponseTokens := requestTokenCount - var selectedModel *models.BackendModel // 按优先级遍历模型,选择第一个满足成本阈值的模型 + costCalculator := billing.NewCostCalculator() for i := range suitableModels { model := &suitableModels[i] - // 计算估算成本 - var estimatedCost float64 - switch model.BillingMethod { - case models.BillingMethodToken: - estimatedCost = float64(requestTokenCount)*model.PromptTokenPrice + - float64(estimatedResponseTokens)*model.CompletionTokenPrice - case models.BillingMethodRequest: - estimatedCost = model.FixedPrice - } + // 使用封装的计算器计算估算成本 + estimatedCost := costCalculator.CalculateModelCost(model, requestTokenCount, 0) // 如果CostThreshold不为0,则表示设置了成本阈值 if model.CostThreshold != 0 {