91 lines
2.7 KiB
Go
91 lines
2.7 KiB
Go
package router
|
||
|
||
import (
|
||
"ai-gateway/internal/models"
|
||
"errors"
|
||
"sort"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// SelectBackendModel 根据虚拟模型名称和请求token数量选择合适的后端模型
|
||
func SelectBackendModel(db *gorm.DB, virtualModelName string, requestTokenCount int) (*models.BackendModel, error) {
|
||
// 查找虚拟模型并预加载关联的后端模型及其服务商信息
|
||
var virtualModel models.VirtualModel
|
||
err := db.Where("name = ?", virtualModelName).
|
||
Preload("BackendModels.Provider").
|
||
Preload("BackendModels").
|
||
First(&virtualModel).Error
|
||
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("virtual model not found: " + virtualModelName)
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
// 如果没有关联的后端模型
|
||
if len(virtualModel.BackendModels) == 0 {
|
||
return nil, errors.New("no backend models configured for virtual model: " + virtualModelName)
|
||
}
|
||
|
||
// 筛选满足上下文长度要求的模型
|
||
var suitableModels []models.BackendModel
|
||
for _, backendModel := range virtualModel.BackendModels {
|
||
if backendModel.MaxContextLength >= requestTokenCount {
|
||
suitableModels = append(suitableModels, backendModel)
|
||
}
|
||
}
|
||
|
||
// 如果没有满足条件的模型
|
||
if len(suitableModels) == 0 {
|
||
return nil, errors.New("no suitable backend model found")
|
||
}
|
||
|
||
// 按优先级排序(Priority值越小优先级越高)
|
||
sort.Slice(suitableModels, func(i, j int) bool {
|
||
return suitableModels[i].Priority < suitableModels[j].Priority
|
||
})
|
||
|
||
// 选择合适的模型(考虑每个后端模型的成本阈值)
|
||
// 估算响应token数(假设等于请求token数)
|
||
estimatedResponseTokens := requestTokenCount
|
||
|
||
var selectedModel *models.BackendModel
|
||
|
||
// 按优先级遍历模型,选择第一个满足成本阈值的模型
|
||
for i := range suitableModels {
|
||
model := &suitableModels[i]
|
||
|
||
// 计算估算成本
|
||
var estimatedCost float64
|
||
switch model.BillingMethod {
|
||
case models.BillingMethodToken:
|
||
estimatedCost = float64(requestTokenCount)*model.PromptTokenPrice +
|
||
float64(estimatedResponseTokens)*model.CompletionTokenPrice
|
||
case models.BillingMethodRequest:
|
||
estimatedCost = model.FixedPrice
|
||
}
|
||
|
||
// 如果该模型设置了成本阈值,检查成本是否超过阈值
|
||
if model.CostThreshold > 0 {
|
||
// 如果成本超过该模型的阈值,跳过该模型
|
||
if estimatedCost > model.CostThreshold {
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 找到第一个满足条件的模型(未设置阈值或成本在阈值内)
|
||
selectedModel = model
|
||
break
|
||
}
|
||
|
||
// 如果所有模型都超过了各自的阈值,返回最后一个模型作为兜底
|
||
if selectedModel == nil {
|
||
selectedModel = &suitableModels[len(suitableModels)-1]
|
||
}
|
||
|
||
// 返回选中的模型
|
||
return selectedModel, nil
|
||
}
|