Files
2025-11-17 11:42:17 +08:00

70 lines
1.8 KiB
Go

package middleware
import (
	"errors"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"gorm.io/gorm"

	"ai-gateway/internal/models"
)
// jwtSecret is the HMAC key used to sign (GenerateJWT) and verify
// (AuthMiddleware) Web UI JWTs.
//
// It is read once, at package initialization, from the JWT_SECRET environment
// variable. When that variable is unset or empty, the historical placeholder
// "your-secret-key" is used so existing deployments keep working.
// NOTE(review): the placeholder is publicly visible in source, so anyone who
// knows it can forge Web UI tokens — always set JWT_SECRET in production.
var jwtSecret = func() []byte {
	if secret := os.Getenv("JWT_SECRET"); secret != "" {
		return []byte(secret)
	}
	return []byte("your-secret-key")
}()
// AuthMiddleware returns a Gin middleware that authenticates each request
// using the bearer token from its "Authorization: Bearer <token>" header.
//
// Requests whose path starts with "/api" (the Web UI routes) must carry an
// HS256 JWT signed with jwtSecret, such as one issued by GenerateJWT. All other
// requests must carry an API key that exists in the database; the matching
// models.APIKey value is stored in the Gin context under "apiKey" for
// downstream handlers.
//
// Failures abort the handler chain with a JSON body {"error": "..."}:
// 401 for a missing, malformed or rejected credential, and 500 when the
// API-key database lookup itself fails.
//
// webUIPassword is accepted for API compatibility but is not used here.
func AuthMiddleware(db *gorm.DB, webUIPassword string) gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if !strings.HasPrefix(authHeader, "Bearer ") {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing or invalid Authorization header"})
			return
		}

		// Reject "Bearer " followed by nothing (or only whitespace) up front
		// instead of handing an empty credential to the JWT parser or the DB.
		tokenString := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
		if tokenString == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing or invalid Authorization header"})
			return
		}

		if strings.HasPrefix(c.Request.URL.Path, "/api") {
			// Web UI authentication (JWT). Accept only HS256, the method
			// GenerateJWT signs with, rather than trusting the alg header the
			// token itself declares.
			claims := &jwt.RegisteredClaims{}
			token, err := jwt.ParseWithClaims(tokenString, claims,
				func(*jwt.Token) (interface{}, error) { return jwtSecret, nil },
				jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
			)
			if err != nil || !token.Valid {
				c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
				return
			}
			c.Next()
			return
		}

		// API key authentication. A map condition lets GORM quote the column
		// name: KEY is a reserved word in MySQL, so a raw "key = ?" fragment is
		// not portable across SQL dialects. The request context lets the query
		// be cancelled if the client goes away.
		var apiKey models.APIKey
		err := db.WithContext(c.Request.Context()).
			Where(map[string]interface{}{"key": tokenString}).
			First(&apiKey).Error
		switch {
		case errors.Is(err, gorm.ErrRecordNotFound):
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
			return
		case err != nil:
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Database error"})
			return
		}

		c.Set("apiKey", apiKey)
		c.Next()
	}
}
// GenerateJWT issues a Web UI token: a JWT signed with HS256 using jwtSecret,
// whose iat and nbf claims are the current time and whose exp claim is 24
// hours later. It returns the compact serialized token, or the error from
// signing.
func GenerateJWT() (string, error) {
	const tokenTTL = 24 * time.Hour

	// Read the clock once so iat, nbf and exp are derived from the same
	// instant instead of three slightly different ones.
	now := time.Now()
	claims := jwt.RegisteredClaims{
		IssuedAt:  jwt.NewNumericDate(now),
		NotBefore: jwt.NewNumericDate(now),
		ExpiresAt: jwt.NewNumericDate(now.Add(tokenTTL)),
	}
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(jwtSecret)
}