实现后端模型选择和筛选功能
This commit is contained in:
151
backend/api/provider_models_handlers.go
Normal file
151
backend/api/provider_models_handlers.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ai-gateway/internal/db"
|
||||
"ai-gateway/internal/models"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetProviderModelsHandler 处理 GET /api/providers/:id/models 请求
|
||||
// 获取指定提供商的可用模型列表
|
||||
func (h *APIHandler) GetProviderModelsHandler(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var providerID uint
|
||||
if _, err := fmt.Sscanf(id, "%d", &providerID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取提供商信息
|
||||
provider, err := db.GetProviderByID(h.DB, providerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// 从提供商获取模型列表
|
||||
models, err := fetchProviderModels(provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "provider_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"provider_id": provider.ID,
|
||||
"provider_name": provider.Name,
|
||||
"models": models,
|
||||
})
|
||||
}
|
||||
|
||||
// fetchProviderModels 从提供商API获取模型列表
|
||||
func fetchProviderModels(provider *models.Provider) ([]string, error) {
|
||||
// 构建获取模型列表的URL
|
||||
modelsURL := provider.BaseURL + "/v1/models"
|
||||
|
||||
// 创建HTTP请求
|
||||
req, err := http.NewRequest("GET", modelsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// 设置认证头
|
||||
req.Header.Set("Authorization", "Bearer "+provider.ApiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to provider: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("provider returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read provider response: %v", err)
|
||||
}
|
||||
|
||||
// 解析响应 - 尝试多种可能的格式
|
||||
var models []string
|
||||
|
||||
// 首先尝试标准的OpenAI格式
|
||||
var openaiResult struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &openaiResult); err == nil && openaiResult.Object == "list" {
|
||||
for _, model := range openaiResult.Data {
|
||||
if model.ID != "" {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
}
|
||||
if len(models) > 0 {
|
||||
return models, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试其他可能的格式
|
||||
// 尝试直接是数组格式
|
||||
var arrayResult []struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &arrayResult); err == nil {
|
||||
for _, model := range arrayResult {
|
||||
if model.ID != "" {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
}
|
||||
if len(models) > 0 {
|
||||
return models, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试map格式
|
||||
var mapResult map[string]interface{}
|
||||
if err := json.Unmarshal(body, &mapResult); err == nil {
|
||||
// 查找包含模型信息的字段
|
||||
for key, value := range mapResult {
|
||||
if key == "models" || key == "data" || key == "list" {
|
||||
if modelList, ok := value.([]interface{}); ok {
|
||||
for _, item := range modelList {
|
||||
if modelMap, ok := item.(map[string]interface{}); ok {
|
||||
if id, ok := modelMap["id"].(string); ok && id != "" {
|
||||
models = append(models, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(models) > 0 {
|
||||
return models, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有获取到模型,返回错误
|
||||
if len(models) == 0 {
|
||||
return nil, fmt.Errorf("no models found in provider response")
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
@@ -101,6 +101,7 @@ func main() {
|
||||
api_.POST("/providers", handler.CreateProviderHandler)
|
||||
api_.PUT("/providers/:id", handler.UpdateProviderHandler)
|
||||
api_.DELETE("/providers/:id", handler.DeleteProviderHandler)
|
||||
api_.GET("/providers/:id/models", handler.GetProviderModelsHandler)
|
||||
|
||||
// Virtual Models
|
||||
api_.POST("/virtual-models", handler.CreateVirtualModelHandler)
|
||||
|
||||
164
frontend/src/components/ui/FuzzySearchSelect.jsx
Normal file
164
frontend/src/components/ui/FuzzySearchSelect.jsx
Normal file
@@ -0,0 +1,164 @@
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
|
||||
const FuzzySearchSelect = ({
|
||||
value,
|
||||
onChange,
|
||||
options = [],
|
||||
placeholder = "请选择...",
|
||||
emptyText = "无匹配选项",
|
||||
loadOptions = null,
|
||||
loadingText = "加载中...",
|
||||
errorText = "加载失败"
|
||||
}) => {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [searchTerm, setSearchTerm] = useState(value || '');
|
||||
const [filteredOptions, setFilteredOptions] = useState(options);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState(null);
|
||||
const wrapperRef = useRef(null);
|
||||
const loadedOptionsRef = useRef([]);
|
||||
const hasLoadedRef = useRef(false);
|
||||
const lastLoadOptionsRef = useRef(null);
|
||||
|
||||
const fuzzySearch = (query, items) => {
|
||||
if (!query) return items;
|
||||
const lowerQuery = query.toLowerCase();
|
||||
return items.filter(item => {
|
||||
const lowerItem = item.toLowerCase();
|
||||
let queryIndex = 0;
|
||||
for (let i = 0; i < lowerItem.length && queryIndex < lowerQuery.length; i++) {
|
||||
if (lowerItem[i] === lowerQuery[queryIndex]) {
|
||||
queryIndex++;
|
||||
}
|
||||
}
|
||||
return queryIndex === lowerQuery.length;
|
||||
});
|
||||
};
|
||||
|
||||
// 仅处理初始加载 - 当下拉框打开且从未加载过时
|
||||
useEffect(() => {
|
||||
if (loadOptions && isOpen && !hasLoadedRef.current && !loading) {
|
||||
const loadData = async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const loadedOptions = await loadOptions();
|
||||
loadedOptionsRef.current = loadedOptions || [];
|
||||
hasLoadedRef.current = true;
|
||||
setFilteredOptions(loadedOptions || []);
|
||||
} catch (err) {
|
||||
setError(err.message || '加载选项失败');
|
||||
loadedOptionsRef.current = [];
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
loadData();
|
||||
}
|
||||
}, [isOpen]); // 仅依赖 isOpen
|
||||
|
||||
// 当 value prop 改变时,同步到 searchTerm(用于初始化或外部更新,但不打开下拉框)
|
||||
useEffect(() => {
|
||||
if (value !== undefined && value !== null && value !== searchTerm) {
|
||||
setSearchTerm(value);
|
||||
// 不打开下拉框,只更新搜索词显示
|
||||
}
|
||||
}, [value, searchTerm]);
|
||||
|
||||
// 当 loadOptions 函数引用改变时(提供商改变),重置加载状态
|
||||
useEffect(() => {
|
||||
if (lastLoadOptionsRef.current !== loadOptions) {
|
||||
hasLoadedRef.current = false;
|
||||
loadedOptionsRef.current = [];
|
||||
setFilteredOptions([]);
|
||||
setError(null);
|
||||
// 不清空 searchTerm,保持用户输入或传入的 value
|
||||
lastLoadOptionsRef.current = loadOptions;
|
||||
}
|
||||
}, [loadOptions]);
|
||||
|
||||
// 实时搜索过滤 - 当搜索词改变时
|
||||
useEffect(() => {
|
||||
if (loadOptions && hasLoadedRef.current) {
|
||||
// 异步加载模式:从已加载的选项中搜索
|
||||
const filtered = fuzzySearch(searchTerm, loadedOptionsRef.current);
|
||||
setFilteredOptions(filtered);
|
||||
} else if (!loadOptions && options.length > 0) {
|
||||
// 同步模式:从传入的 options 中搜索
|
||||
const filtered = fuzzySearch(searchTerm, options);
|
||||
setFilteredOptions(filtered);
|
||||
}
|
||||
}, [searchTerm]); // 仅依赖 searchTerm,不依赖 loadOptions 或 options
|
||||
|
||||
useEffect(() => {
|
||||
function handleClickOutside(event) {
|
||||
if (wrapperRef.current && !wrapperRef.current.contains(event.target)) {
|
||||
setIsOpen(false);
|
||||
}
|
||||
}
|
||||
document.addEventListener("mousedown", handleClickOutside);
|
||||
return () => {
|
||||
document.removeEventListener("mousedown", handleClickOutside);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleSelect = (option) => {
|
||||
onChange(option);
|
||||
setSearchTerm(option);
|
||||
setIsOpen(false);
|
||||
};
|
||||
|
||||
const handleInputChange = (e) => {
|
||||
const value = e.target.value;
|
||||
setSearchTerm(value);
|
||||
setIsOpen(true);
|
||||
};
|
||||
|
||||
const handleInputFocus = () => {
|
||||
setIsOpen(true);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="relative" ref={wrapperRef}>
|
||||
<input
|
||||
type="text"
|
||||
value={searchTerm}
|
||||
onChange={handleInputChange}
|
||||
onFocus={handleInputFocus}
|
||||
placeholder={placeholder}
|
||||
className="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm"
|
||||
/>
|
||||
|
||||
{isOpen && (
|
||||
<div className="absolute z-10 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg max-h-60 overflow-auto">
|
||||
{loading ? (
|
||||
<div className="px-3 py-2 text-gray-500 text-sm">
|
||||
{loadingText}
|
||||
</div>
|
||||
) : error ? (
|
||||
<div className="px-3 py-2 text-red-500 text-sm">
|
||||
{errorText}: {error}
|
||||
</div>
|
||||
) : filteredOptions.length > 0 ? (
|
||||
filteredOptions.map((option, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="px-3 py-2 hover:bg-gray-100 cursor-pointer text-sm"
|
||||
onClick={() => handleSelect(option)}
|
||||
>
|
||||
{option}
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<div className="px-3 py-2 text-gray-500 text-sm">
|
||||
{emptyText}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default FuzzySearchSelect;
|
||||
16
frontend/src/features/providers/api/providerModels.js
Normal file
16
frontend/src/features/providers/api/providerModels.js
Normal file
@@ -0,0 +1,16 @@
|
||||
import apiClient from '../../../lib/api';
|
||||
|
||||
// 获取提供商的模型列表
|
||||
export const getProviderModels = async (providerId) => {
|
||||
try {
|
||||
const response = await apiClient.get(`/api/providers/${providerId}/models`);
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
console.error('获取提供商模型列表失败:', error);
|
||||
const errorMessage =
|
||||
error.response?.data?.error?.message ||
|
||||
error.message ||
|
||||
'获取提供商模型列表失败';
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
};
|
||||
@@ -1,8 +1,14 @@
|
||||
import React from 'react';
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import { useSortable } from '@dnd-kit/sortable';
|
||||
import { CSS } from '@dnd-kit/utilities';
|
||||
import FuzzySearchSelect from '../../../components/ui/FuzzySearchSelect';
|
||||
import { getProviderModels } from '../../providers/api/providerModels';
|
||||
|
||||
const DraggableBackendModelCard = ({ index, bm, providers, onModelChange, onRemove }) => {
|
||||
const [modelOptions, setModelOptions] = useState([]);
|
||||
const [loadingModels, setLoadingModels] = useState(false);
|
||||
const [modelError, setModelError] = useState(null);
|
||||
|
||||
const {
|
||||
attributes,
|
||||
listeners,
|
||||
@@ -18,6 +24,30 @@ const DraggableBackendModelCard = ({ index, bm, providers, onModelChange, onRemo
|
||||
opacity: isDragging ? 0.5 : 1,
|
||||
};
|
||||
|
||||
// 加载提供商模型列表
|
||||
const loadProviderModels = useCallback(async () => {
|
||||
if (!bm.ProviderID || bm.ProviderID === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await getProviderModels(bm.ProviderID);
|
||||
return response.models || [];
|
||||
} catch (error) {
|
||||
console.error('加载提供商模型失败:', error);
|
||||
setModelError(error.message);
|
||||
return [];
|
||||
}
|
||||
}, [bm.ProviderID]);
|
||||
|
||||
// 处理提供商变更
|
||||
const handleProviderChange = (e) => {
|
||||
const newProviderID = parseInt(e.target.value);
|
||||
onModelChange(index, 'ProviderID', newProviderID);
|
||||
// 清空模型名称,因为提供商变了
|
||||
onModelChange(index, 'Name', '');
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={setNodeRef}
|
||||
@@ -55,7 +85,7 @@ const DraggableBackendModelCard = ({ index, bm, providers, onModelChange, onRemo
|
||||
<label className="block text-sm font-medium text-gray-700">提供商</label>
|
||||
<select
|
||||
value={bm.ProviderID}
|
||||
onChange={(e) => onModelChange(index, 'ProviderID', parseInt(e.target.value))}
|
||||
onChange={handleProviderChange}
|
||||
className="mt-1 block w-full rounded-md border-gray-300 shadow-sm"
|
||||
>
|
||||
<option value={0}>选择提供商</option>
|
||||
@@ -69,11 +99,14 @@ const DraggableBackendModelCard = ({ index, bm, providers, onModelChange, onRemo
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700">后端模型名称</label>
|
||||
<input
|
||||
type="text"
|
||||
<FuzzySearchSelect
|
||||
value={bm.Name}
|
||||
onChange={(e) => onModelChange(index, 'Name', e.target.value)}
|
||||
className="mt-1 block w-full rounded-md border-gray-300 shadow-sm"
|
||||
onChange={(value) => onModelChange(index, 'Name', value)}
|
||||
loadOptions={loadProviderModels}
|
||||
placeholder="请选择或搜索模型..."
|
||||
emptyText="请先选择提供商"
|
||||
loadingText="正在加载模型列表..."
|
||||
errorText="加载模型失败"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -12,6 +12,7 @@ export default defineConfig({
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: 'http://localhost:8080', // 后端服务器地址
|
||||
// target: 'http://10.1.39.104:9130',
|
||||
changeOrigin: true,
|
||||
},
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user