Quellcode durchsuchen

feat: do not access database before response return (close #158)

JustSong vor 2 Jahren
Ursprung
Commit
6d961064d2
3 geänderte Dateien mit 56 neuen und 5 gelöschten Zeilen
  1. 11 2
      controller/relay-text.go
  2. 1 1
      middleware/auth.go
  3. 44 2
      model/cache.go

+ 11 - 2
controller/relay-text.go

@@ -16,6 +16,7 @@ import (
 func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
+	userId := c.GetInt("id")
 	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
 	var textRequest GeneralOpenAIRequest
@@ -73,7 +74,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	groupRatio := common.GetGroupRatio(group)
 	ratio := modelRatio * groupRatio
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
-	if consumeQuota {
+	userQuota, err := model.CacheGetUserQuota(userId)
+	if err != nil {
+		return errorWrapper(err, "get_user_quota_failed", http.StatusOK)
+	}
+	if userQuota > 10*preConsumedQuota {
+		// in this case, we do not pre-consume quota
+		// because the user has enough quota
+		preConsumedQuota = 0
+	}
+	if consumeQuota && preConsumedQuota > 0 {
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
@@ -133,7 +143,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 				common.SysError("Error consuming token remain quota: " + err.Error())
 			}
 			tokenName := c.GetString("token_name")
-			userId := c.GetInt("id")
 			model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, common.LogQuota(quota), modelRatio, groupRatio))
 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 			channelId := c.GetInt("channel_id")

+ 1 - 1
middleware/auth.go

@@ -100,7 +100,7 @@ func TokenAuth() func(c *gin.Context) {
 			c.Abort()
 			return
 		}
-		if !model.IsUserEnabled(token.UserId) {
+		if !model.CacheIsUserEnabled(token.UserId) {
 			c.JSON(http.StatusOK, gin.H{
 				"error": gin.H{
 					"message": "用户已被封禁",

+ 44 - 2
model/cache.go

@@ -6,14 +6,17 @@ import (
 	"fmt"
 	"math/rand"
 	"one-api/common"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
 )
 
 const (
-	TokenCacheSeconds        = 60 * 60
-	UserId2GroupCacheSeconds = 60 * 60
+	TokenCacheSeconds         = 60 * 60
+	UserId2GroupCacheSeconds  = 60 * 60
+	UserId2QuotaCacheSeconds  = 10 * 60
+	UserId2StatusCacheSeconds = 60 * 60
 )
 
 func CacheGetTokenByKey(key string) (*Token, error) {
@@ -60,6 +63,45 @@ func CacheGetUserGroup(id int) (group string, err error) {
 	return group, err
 }
 
+func CacheGetUserQuota(id int) (quota int, err error) {
+	if !common.RedisEnabled {
+		return GetUserQuota(id)
+	}
+	quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
+	if err != nil {
+		quota, err = GetUserQuota(id)
+		if err != nil {
+			return 0, err
+		}
+		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
+		if err != nil {
+			common.SysError("Redis set user quota error: " + err.Error())
+		}
+		return quota, err
+	}
+	quota, err = strconv.Atoi(quotaString)
+	return quota, err
+}
+
+func CacheIsUserEnabled(userId int) bool {
+	if !common.RedisEnabled {
+		return IsUserEnabled(userId)
+	}
+	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
+	if err != nil {
+		status := common.UserStatusDisabled
+		if IsUserEnabled(userId) {
+			status = common.UserStatusEnabled
+		}
+		enabled = fmt.Sprintf("%d", status)
+		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second)
+		if err != nil {
+			common.SysError("Redis set user enabled error: " + err.Error())
+		}
+	}
+	return enabled == "1"
+}
+
 var group2model2channels map[string]map[string][]*Channel
 var channelSyncLock sync.RWMutex