Просмотр исходного кода

refactor: enhance API security with read-only token authentication and improved rate limiting

CaIon 6 дней назад
Родитель
Сommit
d814d62e2f
6 измененных файлов с 119 добавлено и 153 удалено
  1. 1 1
      common/constants.go
  2. 25 25
      controller/log.go
  3. 0 88
      controller/secure_verification.go
  4. 57 0
      middleware/auth.go
  5. 31 33
      model/log.go
  6. 5 6
      router/api-router.go

+ 1 - 1
common/constants.go

@@ -39,7 +39,7 @@ var OptionMap map[string]string
 var OptionMapRWMutex sync.RWMutex
 
 var ItemsPerPage = 10
-var MaxRecentItems = 100
+var MaxRecentItems = 1000
 
 var PasswordLoginEnabled = true
 var PasswordRegisterEnabled = true

+ 25 - 25
controller/log.go

@@ -53,40 +53,32 @@ func GetUserLogs(c *gin.Context) {
 	return
 }
 
+// Deprecated: SearchAllLogs 已废弃,前端未使用该接口。
 func SearchAllLogs(c *gin.Context) {
-	keyword := c.Query("keyword")
-	logs, err := model.SearchAllLogs(keyword)
-	if err != nil {
-		common.ApiError(c, err)
-		return
-	}
 	c.JSON(http.StatusOK, gin.H{
-		"success": true,
-		"message": "",
-		"data":    logs,
+		"success": false,
+		"message": "该接口已废弃",
 	})
-	return
 }
 
+// Deprecated: SearchUserLogs 已废弃,前端未使用该接口。
 func SearchUserLogs(c *gin.Context) {
-	keyword := c.Query("keyword")
-	userId := c.GetInt("id")
-	logs, err := model.SearchUserLogs(userId, keyword)
-	if err != nil {
-		common.ApiError(c, err)
-		return
-	}
 	c.JSON(http.StatusOK, gin.H{
-		"success": true,
-		"message": "",
-		"data":    logs,
+		"success": false,
+		"message": "该接口已废弃",
 	})
-	return
 }
 
 func GetLogByKey(c *gin.Context) {
-	key := c.Query("key")
-	logs, err := model.GetLogByKey(key)
+	tokenId := c.GetInt("token_id")
+	if tokenId == 0 {
+		c.JSON(200, gin.H{
+			"success": false,
+			"message": "无效的令牌",
+		})
+		return
+	}
+	logs, err := model.GetLogByTokenId(tokenId)
 	if err != nil {
 		c.JSON(200, gin.H{
 			"success": false,
@@ -110,7 +102,11 @@ func GetLogsStat(c *gin.Context) {
 	modelName := c.Query("model_name")
 	channel, _ := strconv.Atoi(c.Query("channel"))
 	group := c.Query("group")
-	stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
+	stat, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
@@ -133,7 +129,11 @@ func GetLogsSelfStat(c *gin.Context) {
 	modelName := c.Query("model_name")
 	channel, _ := strconv.Atoi(c.Query("channel"))
 	group := c.Query("group")
-	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
+	quotaNum, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 	c.JSON(200, gin.H{
 		"success": true,

+ 0 - 88
controller/secure_verification.go

@@ -133,94 +133,6 @@ func UniversalVerify(c *gin.Context) {
 	})
 }
 
-// GetVerificationStatus 获取验证状态
-func GetVerificationStatus(c *gin.Context) {
-	userId := c.GetInt("id")
-	if userId == 0 {
-		c.JSON(http.StatusUnauthorized, gin.H{
-			"success": false,
-			"message": "未登录",
-		})
-		return
-	}
-
-	session := sessions.Default(c)
-	verifiedAtRaw := session.Get(SecureVerificationSessionKey)
-
-	if verifiedAtRaw == nil {
-		c.JSON(http.StatusOK, gin.H{
-			"success": true,
-			"message": "",
-			"data": VerificationStatusResponse{
-				Verified: false,
-			},
-		})
-		return
-	}
-
-	verifiedAt, ok := verifiedAtRaw.(int64)
-	if !ok {
-		c.JSON(http.StatusOK, gin.H{
-			"success": true,
-			"message": "",
-			"data": VerificationStatusResponse{
-				Verified: false,
-			},
-		})
-		return
-	}
-
-	elapsed := time.Now().Unix() - verifiedAt
-	if elapsed >= SecureVerificationTimeout {
-		// 验证已过期
-		session.Delete(SecureVerificationSessionKey)
-		_ = session.Save()
-		c.JSON(http.StatusOK, gin.H{
-			"success": true,
-			"message": "",
-			"data": VerificationStatusResponse{
-				Verified: false,
-			},
-		})
-		return
-	}
-
-	c.JSON(http.StatusOK, gin.H{
-		"success": true,
-		"message": "",
-		"data": VerificationStatusResponse{
-			Verified:  true,
-			ExpiresAt: verifiedAt + SecureVerificationTimeout,
-		},
-	})
-}
-
-// CheckSecureVerification 检查是否已通过安全验证
-// 返回 true 表示验证有效,false 表示需要重新验证
-func CheckSecureVerification(c *gin.Context) bool {
-	session := sessions.Default(c)
-	verifiedAtRaw := session.Get(SecureVerificationSessionKey)
-
-	if verifiedAtRaw == nil {
-		return false
-	}
-
-	verifiedAt, ok := verifiedAtRaw.(int64)
-	if !ok {
-		return false
-	}
-
-	elapsed := time.Now().Unix() - verifiedAt
-	if elapsed >= SecureVerificationTimeout {
-		// 验证已过期,清除 session
-		session.Delete(SecureVerificationSessionKey)
-		_ = session.Save()
-		return false
-	}
-
-	return true
-}
-
 // PasskeyVerifyAndSetSession Passkey 验证完成后设置 session
 // 这是一个辅助函数,供 PasskeyVerifyFinish 调用
 func PasskeyVerifyAndSetSession(c *gin.Context) {

+ 57 - 0
middleware/auth.go

@@ -168,6 +168,63 @@ func WssAuth(c *gin.Context) {
 
 }
 
+// TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
+// 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
+// 即使令牌已过期、已耗尽或已禁用,也允许访问。
+// 仍然检查用户是否被封禁。
+func TokenAuthReadOnly() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		key := c.Request.Header.Get("Authorization")
+		if key == "" {
+			c.JSON(http.StatusUnauthorized, gin.H{
+				"success": false,
+				"message": "未提供 Authorization 请求头",
+			})
+			c.Abort()
+			return
+		}
+		if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
+			key = strings.TrimSpace(key[7:])
+		}
+		key = strings.TrimPrefix(key, "sk-")
+		parts := strings.Split(key, "-")
+		key = parts[0]
+
+		token, err := model.GetTokenByKey(key, false)
+		if err != nil {
+			c.JSON(http.StatusUnauthorized, gin.H{
+				"success": false,
+				"message": "无效的令牌",
+			})
+			c.Abort()
+			return
+		}
+
+		userCache, err := model.GetUserCache(token.UserId)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			c.Abort()
+			return
+		}
+		if userCache.Status != common.UserStatusEnabled {
+			c.JSON(http.StatusForbidden, gin.H{
+				"success": false,
+				"message": "用户已被封禁",
+			})
+			c.Abort()
+			return
+		}
+
+		c.Set("id", token.UserId)
+		c.Set("token_id", token.Id)
+		c.Set("token_key", token.Key)
+		c.Next()
+	}
+}
+
 func TokenAuth() func(c *gin.Context) {
 	return func(c *gin.Context) {
 		// 先检测是否为ws

+ 31 - 33
model/log.go

@@ -2,9 +2,8 @@ package model
 
 import (
 	"context"
+	"errors"
 	"fmt"
-	"os"
-	"strings"
 	"time"
 
 	"github.com/QuantumNous/new-api/common"
@@ -66,16 +65,8 @@ func formatUserLogs(logs []*Log) {
 	}
 }
 
-func GetLogByKey(key string) (logs []*Log, err error) {
-	if os.Getenv("LOG_SQL_DSN") != "" {
-		var tk Token
-		if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
-			return nil, err
-		}
-		err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
-	} else {
-		err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
-	}
+func GetLogByTokenId(tokenId int) (logs []*Log, err error) {
+	err = LOG_DB.Model(&Log{}).Where("token_id = ?", tokenId).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
 	formatUserLogs(logs)
 	return logs, err
 }
@@ -276,6 +267,8 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 	return logs, total, err
 }
 
+const logSearchCountLimit = 10000
+
 func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) {
 	var tx *gorm.DB
 	if logType == LogTypeUnknown {
@@ -285,7 +278,11 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
 	}
 
 	if modelName != "" {
-		tx = tx.Where("logs.model_name like ?", modelName)
+		modelNamePattern, err := sanitizeLikePattern(modelName)
+		if err != nil {
+			return nil, 0, err
+		}
+		tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern)
 	}
 	if tokenName != "" {
 		tx = tx.Where("logs.token_name = ?", tokenName)
@@ -302,37 +299,28 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
 	if group != "" {
 		tx = tx.Where("logs."+logGroupCol+" = ?", group)
 	}
-	err = tx.Model(&Log{}).Count(&total).Error
+	err = tx.Model(&Log{}).Limit(logSearchCountLimit).Count(&total).Error
 	if err != nil {
-		return nil, 0, err
+		common.SysError("failed to count user logs: " + err.Error())
+		return nil, 0, errors.New("查询日志失败")
 	}
 	err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 	if err != nil {
-		return nil, 0, err
+		common.SysError("failed to search user logs: " + err.Error())
+		return nil, 0, errors.New("查询日志失败")
 	}
 
 	formatUserLogs(logs)
 	return logs, total, err
 }
 
-func SearchAllLogs(keyword string) (logs []*Log, err error) {
-	err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
-	return logs, err
-}
-
-func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
-	err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
-	formatUserLogs(logs)
-	return logs, err
-}
-
 type Stat struct {
 	Quota int `json:"quota"`
 	Rpm   int `json:"rpm"`
 	Tpm   int `json:"tpm"`
 }
 
-func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
+func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
 	tx := LOG_DB.Table("logs").Select("sum(quota) quota")
 
 	// 为rpm和tpm创建单独的查询
@@ -353,8 +341,12 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 		tx = tx.Where("created_at <= ?", endTimestamp)
 	}
 	if modelName != "" {
-		tx = tx.Where("model_name like ?", modelName)
-		rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
+		modelNamePattern, err := sanitizeLikePattern(modelName)
+		if err != nil {
+			return stat, err
+		}
+		tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
+		rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
 	}
 	if channel != 0 {
 		tx = tx.Where("channel_id = ?", channel)
@@ -372,10 +364,16 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 	rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
 
 	// 执行查询
-	tx.Scan(&stat)
-	rpmTpmQuery.Scan(&stat)
+	if err := tx.Scan(&stat).Error; err != nil {
+		common.SysError("failed to query log stat: " + err.Error())
+		return stat, errors.New("查询统计数据失败")
+	}
+	if err := rpmTpmQuery.Scan(&stat).Error; err != nil {
+		common.SysError("failed to query rpm/tpm stat: " + err.Error())
+		return stat, errors.New("查询统计数据失败")
+	}
 
-	return stat
+	return stat, nil
 }
 
 func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {

+ 5 - 6
router/api-router.go

@@ -50,7 +50,6 @@ func SetApiRouter(router *gin.Engine) {
 
 		// Universal secure verification routes
 		apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
-		apiRouter.GET("/verify/status", middleware.UserAuth(), controller.GetVerificationStatus)
 
 		userRoute := apiRouter.Group("/user")
 		{
@@ -247,10 +246,10 @@ func SetApiRouter(router *gin.Engine) {
 		}
 
 		usageRoute := apiRouter.Group("/usage")
-		usageRoute.Use(middleware.CriticalRateLimit())
+		usageRoute.Use(middleware.CORS(), middleware.CriticalRateLimit())
 		{
 			tokenUsageRoute := usageRoute.Group("/token")
-			tokenUsageRoute.Use(middleware.TokenAuth())
+			tokenUsageRoute.Use(middleware.TokenAuthReadOnly())
 			{
 				tokenUsageRoute.GET("/", controller.GetTokenUsage)
 			}
@@ -275,15 +274,15 @@ func SetApiRouter(router *gin.Engine) {
 		logRoute.GET("/channel_affinity_usage_cache", middleware.AdminAuth(), controller.GetChannelAffinityUsageCacheStats)
 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 		logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
-		logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
+		logRoute.GET("/self/search", middleware.UserAuth(), middleware.SearchRateLimit(), controller.SearchUserLogs)
 
 		dataRoute := apiRouter.Group("/data")
 		dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates)
 		dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates)
 
-		logRoute.Use(middleware.CORS())
+		logRoute.Use(middleware.CORS(), middleware.CriticalRateLimit())
 		{
-			logRoute.GET("/token", controller.GetLogByKey)
+			logRoute.GET("/token", middleware.TokenAuthReadOnly(), controller.GetLogByKey)
 		}
 		groupRoute := apiRouter.Group("/group")
 		groupRoute.Use(middleware.AdminAuth())