diff --git a/backend/api/provider_models_handlers.go b/backend/api/provider_models_handlers.go new file mode 100644 index 0000000..d2b2c19 --- /dev/null +++ b/backend/api/provider_models_handlers.go @@ -0,0 +1,151 @@ +package api + +import ( + "ai-gateway/internal/db" + "ai-gateway/internal/models" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" +) + +// GetProviderModelsHandler 处理 GET /api/providers/:id/models 请求 +// 获取指定提供商的可用模型列表 +func (h *APIHandler) GetProviderModelsHandler(c *gin.Context) { + id := c.Param("id") + var providerID uint + if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) + return + } + + // 获取提供商信息 + provider, err := db.GetProviderByID(h.DB, providerID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"}) + return + } + + // 从提供商获取模型列表 + models, err := fetchProviderModels(provider) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "provider_error", + }, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "provider_id": provider.ID, + "provider_name": provider.Name, + "models": models, + }) +} + +// fetchProviderModels 从提供商API获取模型列表 +func fetchProviderModels(provider *models.Provider) ([]string, error) { + // 构建获取模型列表的URL + modelsURL := provider.BaseURL + "/v1/models" + + // 创建HTTP请求 + req, err := http.NewRequest("GET", modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %v", err) + } + + // 设置认证头 + req.Header.Set("Authorization", "Bearer "+provider.ApiKey) + req.Header.Set("Content-Type", "application/json") + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to connect to provider: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("provider returned status %d", resp.StatusCode) + } + + // 读取响应 + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read provider response: %v", err) + } + + // 解析响应 - 尝试多种可能的格式 + var models []string + + // 首先尝试标准的OpenAI格式 + var openaiResult struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + Object string `json:"object"` + } + + if err := json.Unmarshal(body, &openaiResult); err == nil && openaiResult.Object == "list" { + for _, model := range openaiResult.Data { + if model.ID != "" { + models = append(models, model.ID) + } + } + if len(models) > 0 { + return models, nil + } + } + + // 尝试其他可能的格式 + // 尝试直接是数组格式 + var arrayResult []struct { + ID string `json:"id"` + } + if err := json.Unmarshal(body, &arrayResult); err == nil { + for _, model := range arrayResult { + if model.ID != "" { + models = append(models, model.ID) + } + } + if len(models) > 0 { + return models, nil + } + } + + // 尝试map格式 + var mapResult map[string]interface{} + if err := json.Unmarshal(body, &mapResult); err == nil { + // 查找包含模型信息的字段 + for key, value := range mapResult { + if key == "models" || key == "data" || key == "list" { + if modelList, ok := value.([]interface{}); ok { + for _, item := range modelList { + if modelMap, ok := item.(map[string]interface{}); ok { + if id, ok := modelMap["id"].(string); ok && id != "" { + models = append(models, id) + } + } + } + } + } + } + if len(models) > 0 { + return models, nil + } + } + + // 如果没有获取到模型,返回错误 + if len(models) == 0 { + return nil, fmt.Errorf("no models found in provider response") + } + + return models, nil +} diff --git a/backend/main.go b/backend/main.go index a68ce2d..f0895e3 100644 --- a/backend/main.go +++ b/backend/main.go @@ -101,6 +101,7 @@ func main() { api_.POST("/providers", handler.CreateProviderHandler) api_.PUT("/providers/:id", handler.UpdateProviderHandler) api_.DELETE("/providers/:id", handler.DeleteProviderHandler) + api_.GET("/providers/:id/models", handler.GetProviderModelsHandler) // Virtual Models api_.POST("/virtual-models", handler.CreateVirtualModelHandler) diff --git a/frontend/src/components/ui/FuzzySearchSelect.jsx b/frontend/src/components/ui/FuzzySearchSelect.jsx new file mode 100644 index 0000000..62dbc7a --- /dev/null +++ b/frontend/src/components/ui/FuzzySearchSelect.jsx @@ -0,0 +1,164 @@ +import React, { useState, useEffect, useRef } from 'react'; + +const FuzzySearchSelect = ({ + value, + onChange, + options = [], + placeholder = "请选择...", + emptyText = "无匹配选项", + loadOptions = null, + loadingText = "加载中...", + errorText = "加载失败" +}) => { + const [isOpen, setIsOpen] = useState(false); + const [searchTerm, setSearchTerm] = useState(value || ''); + const [filteredOptions, setFilteredOptions] = useState(options); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const wrapperRef = useRef(null); + const loadedOptionsRef = useRef([]); + const hasLoadedRef = useRef(false); + const lastLoadOptionsRef = useRef(null); + + const fuzzySearch = (query, items) => { + if (!query) return items; + const lowerQuery = query.toLowerCase(); + return items.filter(item => { + const lowerItem = item.toLowerCase(); + let queryIndex = 0; + for (let i = 0; i < lowerItem.length && queryIndex < lowerQuery.length; i++) { + if (lowerItem[i] === lowerQuery[queryIndex]) { + queryIndex++; + } + } + return queryIndex === lowerQuery.length; + }); + }; + + // 仅处理初始加载 - 当下拉框打开且从未加载过时 + useEffect(() => { + if (loadOptions && isOpen && !hasLoadedRef.current && !loading) { + const loadData = async () => { + setLoading(true); + setError(null); + try { + const loadedOptions = await loadOptions(); + loadedOptionsRef.current = loadedOptions || []; + hasLoadedRef.current = true; + setFilteredOptions(loadedOptions || []); + } catch (err) { + setError(err.message || '加载选项失败'); + loadedOptionsRef.current = []; + } finally { + setLoading(false); + } + }; + + loadData(); + } + }, [isOpen]); // 仅依赖 isOpen + + // 当 value prop 改变时,同步到 searchTerm(用于初始化或外部更新,但不打开下拉框) + useEffect(() => { + if (value !== undefined && value !== null && value !== searchTerm) { + setSearchTerm(value); + // 不打开下拉框,只更新搜索词显示 + } + }, [value, searchTerm]); + + // 当 loadOptions 函数引用改变时(提供商改变),重置加载状态 + useEffect(() => { + if (lastLoadOptionsRef.current !== loadOptions) { + hasLoadedRef.current = false; + loadedOptionsRef.current = []; + setFilteredOptions([]); + setError(null); + // 不清空 searchTerm,保持用户输入或传入的 value + lastLoadOptionsRef.current = loadOptions; + } + }, [loadOptions]); + + // 实时搜索过滤 - 当搜索词改变时 + useEffect(() => { + if (loadOptions && hasLoadedRef.current) { + // 异步加载模式:从已加载的选项中搜索 + const filtered = fuzzySearch(searchTerm, loadedOptionsRef.current); + setFilteredOptions(filtered); + } else if (!loadOptions && options.length > 0) { + // 同步模式:从传入的 options 中搜索 + const filtered = fuzzySearch(searchTerm, options); + setFilteredOptions(filtered); + } + }, [searchTerm]); // 仅依赖 searchTerm,不依赖 loadOptions 或 options + + useEffect(() => { + function handleClickOutside(event) { + if (wrapperRef.current && !wrapperRef.current.contains(event.target)) { + setIsOpen(false); + } + } + document.addEventListener("mousedown", handleClickOutside); + return () => { + document.removeEventListener("mousedown", handleClickOutside); + }; + }, []); + + const handleSelect = (option) => { + onChange(option); + setSearchTerm(option); + setIsOpen(false); + }; + + const handleInputChange = (e) => { + const value = e.target.value; + setSearchTerm(value); + setIsOpen(true); + }; + + const handleInputFocus = () => { + setIsOpen(true); + }; + + return ( +
+ + + {isOpen && ( +
+ {loading ? ( +
+ {loadingText} +
+ ) : error ? ( +
+ {errorText}: {error} +
+ ) : filteredOptions.length > 0 ? ( + filteredOptions.map((option, index) => ( +
handleSelect(option)} + > + {option} +
+ )) + ) : ( +
+ {emptyText} +
+ )} +
+ )} +
+ ); +}; + +export default FuzzySearchSelect; diff --git a/frontend/src/features/providers/api/providerModels.js b/frontend/src/features/providers/api/providerModels.js new file mode 100644 index 0000000..e5b06cf --- /dev/null +++ b/frontend/src/features/providers/api/providerModels.js @@ -0,0 +1,16 @@ +import apiClient from '../../../lib/api'; + +// 获取提供商的模型列表 +export const getProviderModels = async (providerId) => { + try { + const response = await apiClient.get(`/api/providers/${providerId}/models`); + return response.data; + } catch (error) { + console.error('获取提供商模型列表失败:', error); + const errorMessage = + error.response?.data?.error?.message || + error.message || + '获取提供商模型列表失败'; + throw new Error(errorMessage); + } +}; \ No newline at end of file diff --git a/frontend/src/features/virtual-models/components/DraggableBackendModelCard.jsx b/frontend/src/features/virtual-models/components/DraggableBackendModelCard.jsx index 2b7aa37..5f36428 100644 --- a/frontend/src/features/virtual-models/components/DraggableBackendModelCard.jsx +++ b/frontend/src/features/virtual-models/components/DraggableBackendModelCard.jsx @@ -1,8 +1,14 @@ -import React from 'react'; +import React, { useState, useCallback } from 'react'; import { useSortable } from '@dnd-kit/sortable'; import { CSS } from '@dnd-kit/utilities'; +import FuzzySearchSelect from '../../../components/ui/FuzzySearchSelect'; +import { getProviderModels } from '../../providers/api/providerModels'; const DraggableBackendModelCard = ({ index, bm, providers, onModelChange, onRemove }) => { + const [modelOptions, setModelOptions] = useState([]); + const [loadingModels, setLoadingModels] = useState(false); + const [modelError, setModelError] = useState(null); + const { attributes, listeners, @@ -18,6 +24,30 @@ const DraggableBackendModelCard = ({ index, bm, providers, onModelChange, onRemo opacity: isDragging ? 0.5 : 1, }; + // 加载提供商模型列表 + const loadProviderModels = useCallback(async () => { + if (!bm.ProviderID || bm.ProviderID === 0) { + return []; + } + + try { + const response = await getProviderModels(bm.ProviderID); + return response.models || []; + } catch (error) { + console.error('加载提供商模型失败:', error); + setModelError(error.message); + return []; + } + }, [bm.ProviderID]); + + // 处理提供商变更 + const handleProviderChange = (e) => { + const newProviderID = parseInt(e.target.value); + onModelChange(index, 'ProviderID', newProviderID); + // 清空模型名称,因为提供商变了 + onModelChange(index, 'Name', ''); + }; + return (
提供商 onModelChange(index, 'Name', e.target.value)} - className="mt-1 block w-full rounded-md border-gray-300 shadow-sm" + onChange={(value) => onModelChange(index, 'Name', value)} + loadOptions={loadProviderModels} + placeholder="请选择或搜索模型..." + emptyText="请先选择提供商" + loadingText="正在加载模型列表..." + errorText="加载模型失败" />
diff --git a/frontend/vite.config.js b/frontend/vite.config.js index 7d20252..7795784 100644 --- a/frontend/vite.config.js +++ b/frontend/vite.config.js @@ -12,6 +12,7 @@ export default defineConfig({ proxy: { '/api': { target: 'http://localhost:8080', // 后端服务器地址 + // target: 'http://10.1.39.104:9130', changeOrigin: true, }, },