diff --git a/backend/api/handlers.go b/backend/api/handlers.go index 084573d..e29a844 100644 --- a/backend/api/handlers.go +++ b/backend/api/handlers.go @@ -57,6 +57,27 @@ type ResponsesRequest struct { Stream bool `json:"stream,omitempty"` } +// BackendModelAssociation 代表一个后端模型关联及其配置 +type BackendModelAssociation struct { + BackendModelID uint `json:"backend_model_id" binding:"required"` + Priority int `json:"priority" binding:"required"` + CostThreshold float64 `json:"cost_threshold"` +} + +// CreateVirtualModelRequest 创建虚拟模型的请求结构 +type CreateVirtualModelRequest struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + BackendModels []BackendModelAssociation `json:"backend_models"` +} + +// UpdateVirtualModelRequest 更新虚拟模型的请求结构 +type UpdateVirtualModelRequest struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + BackendModels []BackendModelAssociation `json:"backend_models"` +} + // ListModels 处理 GET /models 请求 func (h *APIHandler) ListModels(c *gin.Context) { var virtualModels []models.VirtualModel @@ -465,6 +486,81 @@ func (h *APIHandler) GetProvidersHandler(c *gin.Context) { c.JSON(http.StatusOK, providers) } +// GetProviderHandler 处理 GET /api/providers/:id 请求 +func (h *APIHandler) GetProviderHandler(c *gin.Context) { + id := c.Param("id") + var providerID uint + if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) + return + } + + provider, err := db.GetProviderByID(h.DB, providerID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"}) + return + } + + c.JSON(http.StatusOK, provider) +} + +// CreateProviderHandler 处理 POST /api/providers 请求 +func (h *APIHandler) CreateProviderHandler(c *gin.Context) { + var provider models.Provider + if err := c.ShouldBindJSON(&provider); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := db.CreateProvider(h.DB, &provider); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"}) + return + } + + c.JSON(http.StatusCreated, provider) +} + +// UpdateProviderHandler 处理 PUT /api/providers/:id 请求 +func (h *APIHandler) UpdateProviderHandler(c *gin.Context) { + id := c.Param("id") + var providerID uint + if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) + return + } + + var provider models.Provider + if err := c.ShouldBindJSON(&provider); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + provider.ID = providerID + + if err := db.UpdateProvider(h.DB, &provider); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"}) + return + } + + c.JSON(http.StatusOK, provider) +} + +// DeleteProviderHandler 处理 DELETE /api/providers/:id 请求 +func (h *APIHandler) DeleteProviderHandler(c *gin.Context) { + id := c.Param("id") + var providerID uint + if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) + return + } + + if err := db.DeleteProvider(h.DB, providerID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"}) + return + } + + c.JSON(http.StatusNoContent, nil) +} + // CreateVirtualModelHandler 处理 POST /api/virtual-models 请求 func (h *APIHandler) CreateVirtualModelHandler(c *gin.Context) { var vm models.VirtualModel diff --git a/backend/gateway.db b/backend/gateway.db index d09b811..b07dd92 100644 Binary files a/backend/gateway.db and b/backend/gateway.db differ diff --git a/backend/internal/db/database.go b/backend/internal/db/database.go index 5ea402f..5e4adc1 100644 --- a/backend/internal/db/database.go +++ b/backend/internal/db/database.go @@ -41,6 +41,30 @@ func GetProviders(db *gorm.DB) ([]models.Provider, error) { return providers, nil } +// GetProviderByID 通过ID获取服务商 +func GetProviderByID(db *gorm.DB, id uint) (*models.Provider, error) { + var provider models.Provider + if err := db.First(&provider, id).Error; err != nil { + return nil, err + } + return &provider, nil +} + +// CreateProvider 创建一个新的服务商 +func CreateProvider(db *gorm.DB, provider *models.Provider) error { + return db.Create(provider).Error +} + +// UpdateProvider 更新一个服务商 +func UpdateProvider(db *gorm.DB, provider *models.Provider) error { + return db.Save(provider).Error +} + +// DeleteProvider 删除一个服务商 +func DeleteProvider(db *gorm.DB, id uint) error { + return db.Delete(&models.Provider{}, id).Error +} + // CreateVirtualModel 创建一个新的虚拟模型 func CreateVirtualModel(db *gorm.DB, virtualModel *models.VirtualModel) error { return db.Create(virtualModel).Error diff --git a/backend/internal/models/schema.go b/backend/internal/models/schema.go index a79c39d..da8ce0d 100644 --- a/backend/internal/models/schema.go +++ b/backend/internal/models/schema.go @@ -29,23 +29,24 @@ type APIKey struct { // VirtualModel 用户与之交互的虚拟模型 type VirtualModel struct { gorm.Model - Name string `gorm:"uniqueIndex;not null"` // 虚拟模型名称,唯一索引 + Name string `gorm:"uniqueIndex;not null"` // 虚拟模型名称,唯一索引 BackendModels []BackendModel `gorm:"foreignKey:VirtualModelID"` // 关联的后端模型列表 } // BackendModel 实际的后端AI模型 type BackendModel struct { gorm.Model - VirtualModelID uint `gorm:"index;not null"` // 关联的虚拟模型ID - ProviderID uint `gorm:"index;not null"` // 关联的服务商ID - Provider Provider `gorm:"foreignKey:ProviderID"` // GORM关联 - Name string `gorm:"not null"` // 后端模型名称 - Priority int `gorm:"not null"` // 优先级(数字越小优先级越高) - MaxContextLength int `gorm:"not null"` // 最大上下文长度 - BillingMethod string `gorm:"not null"` // 计费方式 - PromptTokenPrice float64 `gorm:"type:decimal(10,6)"` // 输入token单价 - CompletionTokenPrice float64 `gorm:"type:decimal(10,6)"` // 输出token单价 - FixedPrice float64 `gorm:"type:decimal(10,2)"` // 固定价格(按次计费) + VirtualModelID uint `gorm:"index;not null"` // 关联的虚拟模型ID + ProviderID uint `gorm:"index;not null"` // 关联的服务商ID + Provider Provider `gorm:"foreignKey:ProviderID"` // GORM关联 + Name string `gorm:"not null"` // 后端模型名称 + Priority int `gorm:"not null"` // 优先级(数字越小优先级越高) + MaxContextLength int `gorm:"not null"` // 最大上下文长度 + BillingMethod string `gorm:"not null"` // 计费方式 + PromptTokenPrice float64 `gorm:"type:decimal(10,6)"` // 输入token单价 + CompletionTokenPrice float64 `gorm:"type:decimal(10,6)"` // 输出token单价 + FixedPrice float64 `gorm:"type:decimal(10,2)"` // 固定价格(按次计费) + CostThreshold float64 `gorm:"type:decimal(10,6)"` // 成本阈值 } // RequestLog 记录每次API请求的详细信息 @@ -61,4 +62,4 @@ type RequestLog struct { Cost float64 `gorm:"type:decimal(10,6)"` // 成本 RequestBody string `gorm:"type:text"` // 请求体 ResponseBody string `gorm:"type:text"` // 响应体 -} \ No newline at end of file +} diff --git a/backend/internal/router/selector.go b/backend/internal/router/selector.go index a6a65bd..a7cf55b 100644 --- a/backend/internal/router/selector.go +++ b/backend/internal/router/selector.go @@ -47,6 +47,44 @@ func SelectBackendModel(db *gorm.DB, virtualModelName string, requestTokenCount return suitableModels[i].Priority < suitableModels[j].Priority }) - // 返回优先级最高的模型 - return &suitableModels[0], nil -} \ No newline at end of file + // 选择合适的模型(考虑每个后端模型的成本阈值) + // 估算响应token数(假设等于请求token数) + estimatedResponseTokens := requestTokenCount + + var selectedModel *models.BackendModel + + // 按优先级遍历模型,选择第一个满足成本阈值的模型 + for i := range suitableModels { + model := &suitableModels[i] + + // 计算估算成本 + var estimatedCost float64 + switch model.BillingMethod { + case models.BillingMethodToken: + estimatedCost = float64(requestTokenCount)*model.PromptTokenPrice + + float64(estimatedResponseTokens)*model.CompletionTokenPrice + case models.BillingMethodRequest: + estimatedCost = model.FixedPrice + } + + // 如果该模型设置了成本阈值,检查成本是否超过阈值 + if model.CostThreshold > 0 { + // 如果成本超过该模型的阈值,跳过该模型 + if estimatedCost > model.CostThreshold { + continue + } + } + + // 找到第一个满足条件的模型(未设置阈值或成本在阈值内) + selectedModel = model + break + } + + // 如果所有模型都超过了各自的阈值,返回最后一个模型作为兜底 + if selectedModel == nil { + selectedModel = &suitableModels[len(suitableModels)-1] + } + + // 返回选中的模型 + return selectedModel, nil +} diff --git a/backend/main.go b/backend/main.go index 19b38c4..1c7d073 100644 --- a/backend/main.go +++ b/backend/main.go @@ -52,7 +52,12 @@ func main() { api_ := router.Group("/api") api_.Use(middleware.AuthMiddleware(database)) { + // Providers api_.GET("/providers", handler.GetProvidersHandler) + api_.GET("/providers/:id", handler.GetProviderHandler) + api_.POST("/providers", handler.CreateProviderHandler) + api_.PUT("/providers/:id", handler.UpdateProviderHandler) + api_.DELETE("/providers/:id", handler.DeleteProviderHandler) // Virtual Models api_.POST("/virtual-models", handler.CreateVirtualModelHandler) diff --git a/frontend/src/features/providers/api/index.js b/frontend/src/features/providers/api/index.js index b5aa637..6d571c2 100644 --- a/frontend/src/features/providers/api/index.js +++ b/frontend/src/features/providers/api/index.js @@ -13,4 +13,43 @@ export const getProviders = async () => { '获取服务商列表失败'; throw new Error(errorMessage); } +}; + +export const getProvider = async (id) => { + try { + const response = await apiClient.get(`/api/providers/${id}`); + return response.data; + } catch (error) { + console.error('获取服务商详情失败:', error); + throw new Error(error.response?.data?.error || '获取服务商详情失败'); + } +}; + +export const createProvider = async (providerData) => { + try { + const response = await apiClient.post('/api/providers', providerData); + return response.data; + } catch (error) { + console.error('创建服务商失败:', error); + throw new Error(error.response?.data?.error || '创建服务商失败'); + } +}; + +export const updateProvider = async (id, providerData) => { + try { + const response = await apiClient.put(`/api/providers/${id}`, providerData); + return response.data; + } catch (error) { + console.error('更新服务商失败:', error); + throw new Error(error.response?.data?.error || '更新服务商失败'); + } +}; + +export const deleteProvider = async (id) => { + try { + await apiClient.delete(`/api/providers/${id}`); + } catch (error) { + console.error('删除服务商失败:', error); + throw new Error(error.response?.data?.error || '删除服务商失败'); + } }; \ No newline at end of file diff --git a/frontend/src/features/providers/components/ProviderForm.jsx b/frontend/src/features/providers/components/ProviderForm.jsx new file mode 100644 index 0000000..7d17cb7 --- /dev/null +++ b/frontend/src/features/providers/components/ProviderForm.jsx @@ -0,0 +1,94 @@ +import React, { useState, useEffect } from 'react'; + +const ProviderForm = ({ provider, onSave, onCancel }) => { + const [name, setName] = useState(''); + const [baseURL, setBaseURL] = useState(''); + const [apiKey, setAPIKey] = useState(''); + const [apiVersion, setAPIVersion] = useState(''); + + useEffect(() => { + if (provider) { + setName(provider.Name || ''); + setBaseURL(provider.BaseURL || ''); + setAPIKey(provider.APIKey || ''); + setAPIVersion(provider.APIVersion || ''); + } + }, [provider]); + + const handleSubmit = (e) => { + e.preventDefault(); + onSave({ + Name: name, + BaseURL: baseURL, + APIKey: apiKey, + APIVersion: apiVersion, + }); + }; + + return ( +
+
+ + setName(e.target.value)} + required + className="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm px-3 py-2 border" + /> +
+ +
+ + setBaseURL(e.target.value)} + required + placeholder="https://api.example.com" + className="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm px-3 py-2 border" + /> +
+ +
+ + setAPIKey(e.target.value)} + required + className="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm px-3 py-2 border" + /> +
+ +
+ + setAPIVersion(e.target.value)} + placeholder="v1" + className="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm px-3 py-2 border" + /> +
+ +
+ + +
+
+ ); +}; + +export default ProviderForm; \ No newline at end of file diff --git a/frontend/src/features/providers/components/ProviderList.jsx b/frontend/src/features/providers/components/ProviderList.jsx index 761639b..e5e0acc 100644 --- a/frontend/src/features/providers/components/ProviderList.jsx +++ b/frontend/src/features/providers/components/ProviderList.jsx @@ -1,29 +1,64 @@ import React, { useState, useEffect } from 'react'; -import { getProviders } from '../api'; +import { getProviders, createProvider, updateProvider, deleteProvider } from '../api'; +import ProviderForm from './ProviderForm'; const ProviderList = () => { const [providers, setProviders] = useState([]); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); + const [editingProvider, setEditingProvider] = useState(null); + const [showCreateForm, setShowCreateForm] = useState(false); + + const fetchProviders = async () => { + try { + setLoading(true); + const data = await getProviders(); + setProviders(data); + setError(null); + } catch (err) { + setError(err.message || '获取提供商列表失败'); + console.error('Error fetching providers:', err); + } finally { + setLoading(false); + } + }; useEffect(() => { - const fetchProviders = async () => { - try { - setLoading(true); - const data = await getProviders(); - setProviders(data); - setError(null); - } catch (err) { - setError(err.message || '获取提供商列表失败'); - console.error('Error fetching providers:', err); - } finally { - setLoading(false); - } - }; - fetchProviders(); }, []); + const handleCreate = async (providerData) => { + try { + await createProvider(providerData); + setShowCreateForm(false); + await fetchProviders(); + } catch (err) { + setError(err.message || '创建提供商失败'); + } + }; + + const handleUpdate = async (providerData) => { + try { + await updateProvider(editingProvider.ID, providerData); + setEditingProvider(null); + await fetchProviders(); + } catch (err) { + setError(err.message || '更新提供商失败'); + } + }; + + const handleDelete = async (id) => { + if (!window.confirm('确定要删除这个提供商吗?')) { + return; + } + try { + await deleteProvider(id); + await fetchProviders(); + } catch (err) { + setError(err.message || '删除提供商失败'); + } + }; + if (loading) { return (
@@ -40,8 +75,43 @@ const ProviderList = () => { ); } + if (showCreateForm) { + return ( +
+

创建新提供商

+ setShowCreateForm(false)} + /> +
+ ); + } + + if (editingProvider) { + return ( +
+

编辑提供商

+ setEditingProvider(null)} + /> +
+ ); + } + return ( -
+
+
+ +
+ +
@@ -49,7 +119,7 @@ const ProviderList = () => { Name ) : ( providers.map((provider, index) => ( - + @@ -95,6 +165,7 @@ const ProviderList = () => { )}
- Type + Base URL Actions @@ -65,7 +135,7 @@ const ProviderList = () => {
{provider.Name || 'N/A'} @@ -79,15 +149,15 @@ const ProviderList = () => {
+
); }; diff --git a/frontend/src/features/virtual-models/components/VirtualModelForm.jsx b/frontend/src/features/virtual-models/components/VirtualModelForm.jsx index 3018aba..7231ff1 100644 --- a/frontend/src/features/virtual-models/components/VirtualModelForm.jsx +++ b/frontend/src/features/virtual-models/components/VirtualModelForm.jsx @@ -36,6 +36,7 @@ const VirtualModelForm = ({ model, onSave, onCancel }) => { PromptTokenPrice: 0, CompletionTokenPrice: 0, FixedPrice: 0, + CostThreshold: 0, }, ]); }; @@ -103,6 +104,16 @@ const VirtualModelForm = ({ model, onSave, onCancel }) => { className="mt-1 block w-full rounded-md border-gray-300 shadow-sm" />
+
+ + handleBackendModelChange(index, 'CostThreshold', parseFloat(e.target.value) || 0)} + className="mt-1 block w-full rounded-md border-gray-300 shadow-sm" + /> +
{ 名称 + + 后端模型 + 操作 @@ -56,6 +59,19 @@ const VirtualModelList = ({ onEdit }) => { {vm.Name} + + {vm.BackendModels && vm.BackendModels.length > 0 ? ( +
+ {vm.BackendModels.map((bm, index) => ( +
+ {bm.Name} (阈值: {bm.CostThreshold || 0}) +
+ ))} +
+ ) : ( + + )} +