Files
AIRouter/backend/internal/router/selector.go

91 lines
2.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package router
import (
"ai-gateway/internal/models"
"errors"
"sort"
"gorm.io/gorm"
)
// SelectBackendModel picks the backend model that should serve a request for
// the virtual model named virtualModelName.
//
// Selection works in three steps:
//  1. Only backend models whose MaxContextLength is at least
//     requestTokenCount are considered.
//  2. The candidates are ordered by Priority ascending (a smaller value means
//     a higher priority). The sort is stable, so candidates with equal
//     Priority keep their relative order from the loaded BackendModels,
//     which makes the choice deterministic.
//  3. The first candidate whose estimated cost does not exceed its own
//     CostThreshold is returned; a CostThreshold of 0 or less means the model
//     has no cost limit. If every candidate is over its threshold, the
//     lowest-priority candidate (the last one after sorting) is returned as a
//     fallback instead of failing.
//
// The cost estimate assumes the response is as long as the request:
// token-billed models cost requestTokenCount*PromptTokenPrice +
// requestTokenCount*CompletionTokenPrice, request-billed models cost
// FixedPrice, and any other BillingMethod is estimated at 0.
//
// An error is returned when the virtual model does not exist, the database
// query fails, no backend models are attached to it, or none of them has a
// large enough context window.
func SelectBackendModel(db *gorm.DB, virtualModelName string, requestTokenCount int) (*models.BackendModel, error) {
	// Preloading the nested path also loads BackendModels itself, so no
	// separate Preload("BackendModels") is needed.
	var virtualModel models.VirtualModel
	err := db.Where("name = ?", virtualModelName).
		Preload("BackendModels.Provider").
		First(&virtualModel).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("virtual model not found: " + virtualModelName)
		}
		return nil, err
	}

	if len(virtualModel.BackendModels) == 0 {
		return nil, errors.New("no backend models configured for virtual model: " + virtualModelName)
	}

	// Keep only the models whose context window can hold the request.
	suitableModels := make([]models.BackendModel, 0, len(virtualModel.BackendModels))
	for _, backendModel := range virtualModel.BackendModels {
		if backendModel.MaxContextLength >= requestTokenCount {
			suitableModels = append(suitableModels, backendModel)
		}
	}
	if len(suitableModels) == 0 {
		return nil, errors.New("no suitable backend model found")
	}

	// Lower Priority value = higher priority. SliceStable keeps ties in their
	// original order; sort.Slice would order equal-priority models
	// arbitrarily, making both the pick and the fallback nondeterministic.
	sort.SliceStable(suitableModels, func(i, j int) bool {
		return suitableModels[i].Priority < suitableModels[j].Priority
	})

	// Rough estimate: assume the response is as long as the request.
	estimatedResponseTokens := requestTokenCount

	// Walk candidates in priority order and return the first one whose
	// estimated cost fits within its own threshold.
	for i := range suitableModels {
		model := &suitableModels[i]

		var estimatedCost float64
		switch model.BillingMethod {
		case models.BillingMethodToken:
			estimatedCost = float64(requestTokenCount)*model.PromptTokenPrice +
				float64(estimatedResponseTokens)*model.CompletionTokenPrice
		case models.BillingMethodRequest:
			estimatedCost = model.FixedPrice
		}

		// A non-positive CostThreshold means the model has no cost limit.
		if model.CostThreshold > 0 && estimatedCost > model.CostThreshold {
			continue
		}
		return model, nil
	}

	// Every candidate exceeded its threshold: fall back to the
	// lowest-priority one rather than failing the request.
	return &suitableModels[len(suitableModels)-1], nil
}