This commit is contained in:
2025-11-08 13:57:34 +08:00
commit f35c236fd0
11 changed files with 802 additions and 0 deletions

279
backend/api/handlers.go Normal file
View File

@@ -0,0 +1,279 @@
package api
import (
"ai-gateway/internal/logger"
"ai-gateway/internal/models"
"ai-gateway/internal/router"
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"gorm.io/gorm"
)
// APIHandler bundles the dependencies shared by the gateway's HTTP handlers.
type APIHandler struct {
	// DB is the GORM handle the handlers use for model lookups and for
	// persisting request logs.
	DB *gorm.DB
}
// ModelListResponse is the envelope returned by the model listing endpoint.
// Its JSON shape follows the OpenAI /v1/models response format.
type ModelListResponse struct {
	// Object is the envelope type marker; ListModels sets it to "list".
	Object string `json:"object"`
	// Data holds one entry per model exposed by the gateway.
	Data []ModelData `json:"data"`
}
// ModelData describes a single model inside a ModelListResponse.
type ModelData struct {
	// ID is the model identifier clients put in the "model" request field.
	ID string `json:"id"`
	// Object is the entry type marker; ListModels sets it to "model".
	Object string `json:"object"`
	// Created is the creation time as Unix seconds.
	Created int64 `json:"created"`
	// OwnedBy names the owner reported for the model.
	OwnedBy string `json:"owned_by"`
}
// ChatCompletionRequest is the subset of the chat-completion request body that
// the gateway decodes and re-encodes when forwarding to a backend.
type ChatCompletionRequest struct {
	// Model is the model name sent by the client.
	Model string `json:"model"`
	// Messages is the conversation, in order.
	Messages []ChatCompletionMessage `json:"messages"`
	// Stream is the client's streaming flag; it is omitted from JSON when false.
	Stream bool `json:"stream,omitempty"`
}
// ChatCompletionMessage is one message of a chat conversation.
type ChatCompletionMessage struct {
	// Role is the author role of the message.
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
// ListModels handles GET /models.
//
// It loads every virtual model from the database and returns them in the
// OpenAI model-list shape. A database failure is answered with HTTP 500 and an
// OpenAI-style error object.
func (h *APIHandler) ListModels(c *gin.Context) {
	var stored []models.VirtualModel
	if result := h.DB.Find(&stored); result.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": "Failed to retrieve models",
				"type":    "internal_error",
			},
		})
		return
	}

	// Start from an empty, non-nil slice so "data" encodes as [] rather than
	// null when no virtual models exist.
	entries := make([]ModelData, 0, len(stored))
	for idx := range stored {
		entries = append(entries, ModelData{
			ID:      stored[idx].Name,
			Object:  "model",
			Created: stored[idx].CreatedAt.Unix(),
			OwnedBy: "ai-gateway",
		})
	}

	c.JSON(http.StatusOK, ModelListResponse{
		Object: "list",
		Data:   entries,
	})
}
// ChatCompletions 处理 POST /v1/chat/completions 请求
func (h *APIHandler) ChatCompletions(c *gin.Context) {
// 记录请求开始时间
requestTimestamp := time.Now()
var req ChatCompletionRequest
// 解析请求体
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "Invalid request format",
"type": "invalid_request_error",
},
})
return
}
// 使用tiktoken精确计算请求token数
requestTokenCount := calculateTokenCount(req.Messages)
// 选择后端模型
backendModel, err := router.SelectBackendModel(h.DB, req.Model, requestTokenCount)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "invalid_request_error",
},
})
return
}
// 准备转发请求
requestBody, err := json.Marshal(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to process request",
"type": "internal_error",
},
})
return
}
// 构建后端API URL
backendURL := backendModel.Provider.BaseURL + "/v1/chat/completions"
// 创建HTTP请求
httpReq, err := http.NewRequest("POST", backendURL, bytes.NewReader(requestBody))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to create backend request",
"type": "internal_error",
},
})
return
}
// 设置请求头
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+backendModel.Provider.ApiKey)
// 复制原始请求的其他相关头部
if userAgent := c.GetHeader("User-Agent"); userAgent != "" {
httpReq.Header.Set("User-Agent", userAgent)
}
// 执行请求
client := &http.Client{
Timeout: 120 * time.Second,
}
resp, err := client.Do(httpReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"message": "Failed to connect to backend service",
"type": "service_unavailable",
},
})
return
}
defer resp.Body.Close()
// 记录响应时间
responseTimestamp := time.Now()
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Failed to read backend response",
"type": "internal_error",
},
})
return
}
// 计算响应token数
responseTokenCount := 0
var responseData map[string]interface{}
if err := json.Unmarshal(responseBody, &responseData); err == nil {
if choices, ok := responseData["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := choice["message"].(map[string]interface{}); ok {
if content, ok := message["content"].(string); ok {
responseTokenCount = calculateTokenCountFromText(content)
}
}
}
}
}
// 计算费用
var cost float64
switch backendModel.BillingMethod {
case models.BillingMethodToken:
cost = float64(requestTokenCount)*backendModel.PromptTokenPrice + float64(responseTokenCount)*backendModel.CompletionTokenPrice
case models.BillingMethodRequest:
cost = backendModel.FixedPrice
}
// 从上下文获取API密钥
apiKeyValue, exists := c.Get("apiKey")
var apiKeyID uint
if exists {
if apiKey, ok := apiKeyValue.(models.APIKey); ok {
apiKeyID = apiKey.ID
}
}
// 创建日志记录
logEntry := &models.RequestLog{
APIKeyID: apiKeyID,
VirtualModelName: req.Model,
BackendModelName: backendModel.Name,
RequestTimestamp: requestTimestamp,
ResponseTimestamp: responseTimestamp,
RequestTokens: requestTokenCount,
ResponseTokens: responseTokenCount,
Cost: cost,
RequestBody: string(requestBody),
ResponseBody: string(responseBody),
}
// 异步记录日志
logger.LogRequest(h.DB, logEntry)
// 复制响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 设置响应状态码并返回响应体
c.Status(resp.StatusCode)
c.Writer.Write(responseBody)
}
// calculateTokenCount estimates the number of prompt tokens in a chat
// conversation using the cl100k_base encoding.
//
// Each message contributes a fixed overhead of 4 tokens plus the encoded
// lengths of its role and content; 2 more tokens are added once for the
// conversation as a whole. If the encoding cannot be loaded the failure is
// logged and 0 is returned.
func calculateTokenCount(messages []ChatCompletionMessage) int {
	const (
		perMessageOverhead   = 4
		conversationOverhead = 2
	)

	enc, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		log.Printf("Failed to get tiktoken encoding: %v", err)
		return 0
	}

	count := 0
	for i := range messages {
		count += perMessageOverhead
		count += len(enc.Encode(messages[i].Role, nil, nil))
		count += len(enc.Encode(messages[i].Content, nil, nil))
	}
	return count + conversationOverhead
}
// calculateTokenCountFromText returns the number of cl100k_base tokens in
// text. If the encoding cannot be loaded the failure is logged and 0 is
// returned.
func calculateTokenCountFromText(text string) int {
	enc, err := tiktoken.GetEncoding("cl100k_base")
	if err == nil {
		return len(enc.Encode(text, nil, nil))
	}
	log.Printf("Failed to get tiktoken encoding: %v", err)
	return 0
}

47
backend/docs/AGENTS.md Normal file
View File

@@ -0,0 +1,47 @@
# AGENTS.md
This file provides guidance to agents when working with code in this repository.
## Project Overview
This project is an AI gateway that routes requests to various backend AI models. It provides a unified API that is compatible with the OpenAI API format.
### Technology Stack
* **Language:** Go
* **Web Framework:** Gin
* **ORM:** GORM
* **Database:** SQLite
### How to Run
To run the project, use the following command:
```bash
go run main.go
```
The server will start on port `8080`.
### Architecture
The project is divided into the following packages:
* `api`: Contains the API handlers that process incoming requests.
* `router`: Implements the logic for selecting the appropriate backend model for a given request.
* `db`: Handles the database initialization and migrations.
* `models`: Defines the database schema.
* `middleware`: Contains the authentication middleware.
* `logger`: Provides asynchronous logging of requests to the database.
### Authentication
The gateway uses API key authentication. The API key must be provided in the `Authorization` header as a Bearer token.
### Routing
The gateway uses a virtual model system to route requests. Each virtual model can be associated with multiple backend models. When a request is received for a virtual model, the gateway keeps only the backend models whose maximum context length can accommodate the request's token count and selects the one with the highest priority (the lowest `Priority` value).
### Logging
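Chat completion requests that receive a response from a backend are logged to the database through the `RequestLog` model. Requests rejected before that point (failed authentication, invalid body, no suitable backend model, unreachable backend) are not logged. Logging is performed asynchronously to avoid impacting request processing time.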

BIN
backend/gateway.db Normal file

Binary file not shown.

45
backend/go.mod Normal file
View File

@@ -0,0 +1,45 @@
module ai-gateway
go 1.24
require (
github.com/gin-gonic/gin v1.10.1
github.com/pkoukk/tiktoken-go v0.1.8
gorm.io/driver/sqlite v1.5.4
gorm.io/gorm v1.25.5
)
require (
github.com/bytedance/sonic v1.13.3 // indirect
github.com/bytedance/sonic/loader v0.2.4 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
github.com/gin-contrib/cors v1.7.6
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.26.0 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.0 // indirect
golang.org/x/arch v0.18.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

138
backend/go.sum Normal file
View File

@@ -0,0 +1,138 @@
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/bytedance/sonic v1.13.3 h1:MS8gmaH16Gtirygw7jV91pDCN33NyMrPbN7qiYhEsF0=
github.com/bytedance/sonic v1.13.3/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY=
github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
github.com/gin-contrib/cors v1.7.6 h1:3gQ8GMzs1Ylpf70y8bMw4fVpycXIeX1ZemuSQIsnQQY=
github.com/gin-contrib/cors v1.7.6/go.mod h1:Ulcl+xN4jel9t1Ry8vqph23a60FwH9xVLd+3ykmTjOk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ=
github.com/gin-gonic/gin v1.10.1/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.18.0 h1:WN9poc33zL4AzGxqf8VtpKUnGvMi8O9lhNyBMF/85qc=
golang.org/x/arch v0.18.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -0,0 +1,31 @@
package db
import (
"ai-gateway/internal/models"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// InitDB opens the SQLite database file gateway.db and auto-migrates the
// schema for every gateway model. It returns the ready-to-use handle, or the
// first error raised while opening or migrating.
func InitDB() (*gorm.DB, error) {
	conn, err := gorm.Open(sqlite.Open("gateway.db"), &gorm.Config{})
	if err != nil {
		return nil, err
	}

	// Every model whose table AutoMigrate should create or update.
	schema := []interface{}{
		&models.Provider{},
		&models.APIKey{},
		&models.VirtualModel{},
		&models.BackendModel{},
		&models.RequestLog{},
	}
	if err := conn.AutoMigrate(schema...); err != nil {
		return nil, err
	}

	return conn, nil
}

View File

@@ -0,0 +1,17 @@
package logger
import (
"ai-gateway/internal/models"
"log"
"gorm.io/gorm"
)
// LogRequest persists logEntry in a background goroutine so the caller does
// not wait for the database write. A failed insert is reported through the
// standard logger only; nothing is returned to the caller.
func LogRequest(db *gorm.DB, logEntry *models.RequestLog) {
	persist := func() {
		result := db.Create(logEntry)
		if result.Error == nil {
			return
		}
		log.Printf("Failed to save request log: %v", result.Error)
	}
	go persist()
}

View File

@@ -0,0 +1,77 @@
package middleware
import (
"ai-gateway/internal/models"
"errors"
"log"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// maskAPIKey returns a log-safe rendering of an API key: the first four
// characters followed by "****", or just "****" when the key is too short for
// a prefix to be shown without revealing most of it.
func maskAPIKey(key string) string {
	const visible = 4
	runes := []rune(key)
	if len(runes) <= visible*2 {
		return "****"
	}
	return string(runes[:visible]) + "****"
}

// AuthMiddleware returns a Gin middleware that authenticates requests by API
// key.
//
// The key must be supplied as "Authorization: Bearer <key>" and must match a
// row in the APIKey table. On success the matching models.APIKey value is
// stored in the context under "apiKey" and the next handler runs. Otherwise
// the request is aborted with an OpenAI-style error object:
//   - 401 when the header is missing or malformed, the key is empty, or the
//     key is unknown;
//   - 500 when the lookup fails for any other reason.
//
// Keys are masked before being written to the server log so that credentials
// never end up in log files.
func AuthMiddleware(db *gorm.DB) gin.HandlerFunc {
	// reject aborts the chain and writes the error object in one step.
	reject := func(c *gin.Context, status int, message, errType string) {
		c.AbortWithStatusJSON(status, gin.H{
			"error": gin.H{
				"message": message,
				"type":    errType,
			},
		})
	}

	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")

		// An empty header fails the prefix check too, so one test covers both.
		if !strings.HasPrefix(authHeader, "Bearer ") {
			log.Printf("[Auth] Failed: Missing or invalid Authorization header. IP: %s", c.ClientIP())
			reject(c, http.StatusUnauthorized, "Missing or invalid Authorization header", "authentication_error")
			return
		}

		token := strings.TrimPrefix(authHeader, "Bearer ")
		if token == "" {
			log.Printf("[Auth] Failed: Missing API key. IP: %s", c.ClientIP())
			reject(c, http.StatusUnauthorized, "Missing API key", "authentication_error")
			return
		}

		var apiKey models.APIKey
		if err := db.Where("key = ?", token).First(&apiKey).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				log.Printf("[Auth] Failed: Invalid API key. Key: %s, IP: %s", maskAPIKey(token), c.ClientIP())
				reject(c, http.StatusUnauthorized, "Invalid API key", "authentication_error")
				return
			}
			log.Printf("[Auth] Failed: Database error during authentication. Key: %s, IP: %s, Error: %v", maskAPIKey(token), c.ClientIP(), err)
			reject(c, http.StatusInternalServerError, "Failed to authenticate", "internal_error")
			return
		}

		// Expose the authenticated key to downstream handlers.
		c.Set("apiKey", apiKey)
		c.Next()
	}
}

View File

@@ -0,0 +1,64 @@
package models
import (
"time"
"gorm.io/gorm"
)
// Billing methods a backend model can be charged by.
const (
	// BillingMethodToken bills per prompt and completion token.
	BillingMethodToken = "token"
	// BillingMethodRequest bills a fixed price per request.
	BillingMethodRequest = "request"
)
// Provider is an upstream AI service that backend models are served from.
type Provider struct {
	gorm.Model
	// Name is the provider's name; it is unique and required.
	Name string `gorm:"uniqueIndex;not null"`
	// BaseURL is the root URL of the provider's API.
	BaseURL string `gorm:"not null"`
	// ApiKey is the credential the gateway presents to the provider.
	ApiKey string `gorm:"not null"`
}
// APIKey is a credential that clients use to authenticate against the gateway
// itself.
type APIKey struct {
	gorm.Model
	// Key is the API key string; it is unique and required.
	Key string `gorm:"uniqueIndex;not null"`
}
// VirtualModel is the model that clients address. It maps to one or more
// backend models.
type VirtualModel struct {
	gorm.Model
	// Name is the virtual model's name; it is unique and required.
	Name string `gorm:"uniqueIndex;not null"`
	// BackendModels lists the backend models linked through VirtualModelID.
	BackendModels []BackendModel `gorm:"foreignKey:VirtualModelID"`
}
// BackendModel is a concrete model hosted by a provider that can serve
// requests for a virtual model.
type BackendModel struct {
	gorm.Model
	// VirtualModelID is the ID of the virtual model this backend belongs to.
	VirtualModelID uint `gorm:"index;not null"`
	// ProviderID is the ID of the provider hosting this model.
	ProviderID uint `gorm:"index;not null"`
	// Provider is the GORM association resolved through ProviderID.
	Provider Provider `gorm:"foreignKey:ProviderID"`
	// Name is the backend model's name.
	Name string `gorm:"not null"`
	// Priority orders candidate backends; a smaller number means a higher
	// priority.
	Priority int `gorm:"not null"`
	// MaxContextLength is the maximum context length the model accepts.
	MaxContextLength int `gorm:"not null"`
	// BillingMethod is the billing method (see the BillingMethod* constants).
	BillingMethod string `gorm:"not null"`
	// PromptTokenPrice is the unit price of one input token.
	PromptTokenPrice float64 `gorm:"type:decimal(10,6)"`
	// CompletionTokenPrice is the unit price of one output token.
	CompletionTokenPrice float64 `gorm:"type:decimal(10,6)"`
	// FixedPrice is the flat price charged under per-request billing.
	FixedPrice float64 `gorm:"type:decimal(10,2)"`
}
// RequestLog records the details of one API request handled by the gateway.
type RequestLog struct {
	gorm.Model
	// APIKeyID is the ID of the API key that made the request.
	APIKeyID uint `gorm:"index"`
	// VirtualModelName is the name of the virtual model that was requested.
	VirtualModelName string `gorm:"index"`
	// BackendModelName is the name of the backend model that served it.
	BackendModelName string `gorm:"index"`
	// RequestTimestamp is when the request was received.
	RequestTimestamp time.Time `gorm:"index;not null"`
	// ResponseTimestamp is when the backend response arrived.
	ResponseTimestamp time.Time `gorm:"not null"`
	// RequestTokens is the number of tokens counted in the request.
	RequestTokens int `gorm:"default:0"`
	// ResponseTokens is the number of tokens counted in the response.
	ResponseTokens int `gorm:"default:0"`
	// Cost is the cost computed for the request.
	Cost float64 `gorm:"type:decimal(10,6)"`
	// RequestBody is the request body.
	RequestBody string `gorm:"type:text"`
	// ResponseBody is the response body.
	ResponseBody string `gorm:"type:text"`
}

View File

@@ -0,0 +1,52 @@
package router
import (
"ai-gateway/internal/models"
"errors"
"sort"
"gorm.io/gorm"
)
// SelectBackendModel picks the backend model that should serve a request for
// the virtual model named virtualModelName.
//
// Candidates are the virtual model's backend models whose MaxContextLength is
// at least requestTokenCount. Among them the one with the smallest Priority
// value wins. Candidates with equal Priority keep the order in which the
// database returned them, so the choice is deterministic for a given result
// set.
//
// The returned model has its Provider association loaded. An error is
// returned when the virtual model does not exist, has no backend models, has
// none that fit the token count, or when the database query fails.
func SelectBackendModel(db *gorm.DB, virtualModelName string, requestTokenCount int) (*models.BackendModel, error) {
	// Load the virtual model together with its backend models and each
	// backend's provider.
	var virtualModel models.VirtualModel
	err := db.Where("name = ?", virtualModelName).
		Preload("BackendModels.Provider").
		Preload("BackendModels").
		First(&virtualModel).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("virtual model not found: " + virtualModelName)
		}
		return nil, err
	}

	if len(virtualModel.BackendModels) == 0 {
		return nil, errors.New("no backend models configured for virtual model: " + virtualModelName)
	}

	// Keep only the backends whose context window fits the request.
	candidates := make([]models.BackendModel, 0, len(virtualModel.BackendModels))
	for i := range virtualModel.BackendModels {
		if virtualModel.BackendModels[i].MaxContextLength >= requestTokenCount {
			candidates = append(candidates, virtualModel.BackendModels[i])
		}
	}
	if len(candidates) == 0 {
		return nil, errors.New("no suitable backend model found")
	}

	// A smaller Priority value means a higher priority. The stable sort keeps
	// the relative order of equal priorities, so ties are not broken at random.
	sort.SliceStable(candidates, func(i, j int) bool {
		return candidates[i].Priority < candidates[j].Priority
	})

	return &candidates[0], nil
}

52
backend/main.go Normal file
View File

@@ -0,0 +1,52 @@
package main
import (
"ai-gateway/api"
"ai-gateway/internal/db"
"ai-gateway/internal/middleware"
"log"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
// main wires the gateway together: it opens the database, builds the Gin
// engine with a permissive CORS policy, registers the API-key-protected
// routes and serves HTTP on :8080. Any startup failure terminates the process.
func main() {
	database, err := db.InitDB()
	if err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}

	// Deliberately permissive CORS policy: any origin, the common HTTP
	// methods and the headers API clients are expected to send.
	// NOTE(review): browsers refuse credentialed requests when the allowed
	// origin is the "*" wildcard — confirm that combining AllowAllOrigins with
	// AllowCredentials behaves as intended with gin-contrib/cors.
	corsConfig := cors.DefaultConfig()
	corsConfig.AllowAllOrigins = true
	corsConfig.AllowWildcard = true
	corsConfig.AllowCredentials = true
	corsConfig.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
	corsConfig.AllowHeaders = []string{"Origin", "Content-Length", "Content-Type", "Authorization"}

	engine := gin.Default()
	engine.Use(cors.New(corsConfig))

	handler := &api.APIHandler{DB: database}

	// Every route in this group requires a valid API key.
	protected := engine.Group("/", middleware.AuthMiddleware(database))
	protected.GET("/models", handler.ListModels)
	protected.POST("/v1/chat/completions", handler.ChatCompletions)

	log.Println("Starting AI Gateway server on :8080")
	if err := engine.Run(":8080"); err != nil {
		log.Fatalf("Failed to start server: %v", err)
	}
}