Преглед изворни кода

chore: pass through error out

JustSong пре 2 година
родитељ
комит
621eb91b46
5 измењених фајлова са 56 додато и 37 уклоњено
  1. 0 1
      controller/relay-text.go
  2. 12 1
      middleware/auth.go
  3. 16 11
      model/cache.go
  4. 24 19
      model/token.go
  5. 4 5
      model/user.go

+ 0 - 1
controller/relay-text.go

@@ -377,7 +377,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-
 					model.UpdateChannelUsedQuota(channelId, quota)
 				}
 			}

+ 12 - 1
middleware/auth.go

@@ -100,7 +100,18 @@ func TokenAuth() func(c *gin.Context) {
 			c.Abort()
 			return
 		}
-		if !model.CacheIsUserEnabled(token.UserId) {
+		userEnabled, err := model.IsUserEnabled(token.UserId)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{
+				"error": gin.H{
+					"message": err.Error(),
+					"type":    "one_api_error",
+				},
+			})
+			c.Abort()
+			return
+		}
+		if !userEnabled {
 			c.JSON(http.StatusForbidden, gin.H{
 				"error": gin.H{
 					"message": "用户已被封禁",

+ 16 - 11
model/cache.go

@@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
 	return err
 }
 
-func CacheIsUserEnabled(userId int) bool {
+func CacheIsUserEnabled(userId int) (bool, error) {
 	if !common.RedisEnabled {
 		return IsUserEnabled(userId)
 	}
 	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
+	if err == nil {
+		return enabled == "1", nil
+	}
+
+	userEnabled, err := IsUserEnabled(userId)
 	if err != nil {
-		status := common.UserStatusDisabled
-		if IsUserEnabled(userId) {
-			status = common.UserStatusEnabled
-		}
-		enabled = fmt.Sprintf("%d", status)
-		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
-		if err != nil {
-			common.SysError("Redis set user enabled error: " + err.Error())
-		}
+		return false, err
+	}
+	enabled = "0"
+	if userEnabled {
+		enabled = "1"
+	}
+	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
+	if err != nil {
+		common.SysError("Redis set user enabled error: " + err.Error())
 	}
-	return enabled == "1"
+	return userEnabled, err
 }
 
 var group2model2channels map[string]map[string][]*Channel

+ 24 - 19
model/token.go

@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
 	}
 	token, err = CacheGetTokenByKey(key)
 	if err == nil {
+		if token.Status == common.TokenStatusExhausted {
+			return nil, errors.New("该令牌额度已用尽")
+		} else if token.Status == common.TokenStatusExpired {
+			return nil, errors.New("该令牌已过期")
+		}
 		if token.Status != common.TokenStatusEnabled {
 			return nil, errors.New("该令牌状态不可用")
 		}
 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
-			token.Status = common.TokenStatusExpired
-			err := token.SelectUpdate()
-			if err != nil {
-				common.SysError("failed to update token status" + err.Error())
+			if !common.RedisEnabled {
+				token.Status = common.TokenStatusExpired
+				err := token.SelectUpdate()
+				if err != nil {
+					common.SysError("failed to update token status" + err.Error())
+				}
 			}
 			return nil, errors.New("该令牌已过期")
 		}
 		if !token.UnlimitedQuota && token.RemainQuota <= 0 {
-			token.Status = common.TokenStatusExhausted
-			err := token.SelectUpdate()
-			if err != nil {
-				common.SysError("failed to update token status" + err.Error())
+			if !common.RedisEnabled {
+				// in this case, we can make sure the token is exhausted
+				token.Status = common.TokenStatusExhausted
+				err := token.SelectUpdate()
+				if err != nil {
+					common.SysError("failed to update token status" + err.Error())
+				}
 			}
 			return nil, errors.New("该令牌额度已用尽")
 		}
-		go func() {
-			token.AccessedTime = common.GetTimestamp()
-			err := token.SelectUpdate()
-			if err != nil {
-				common.SysError("failed to update token" + err.Error())
-			}
-		}()
 		return token, nil
 	}
 	return nil, errors.New("无效的令牌")
@@ -141,8 +144,9 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
 func increaseTokenQuota(id int, quota int) (err error) {
 	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 		map[string]interface{}{
-			"remain_quota": gorm.Expr("remain_quota + ?", quota),
-			"used_quota":   gorm.Expr("used_quota - ?", quota),
+			"remain_quota":  gorm.Expr("remain_quota + ?", quota),
+			"used_quota":    gorm.Expr("used_quota - ?", quota),
+			"accessed_time": common.GetTimestamp(),
 		},
 	).Error
 	return err
@@ -162,8 +166,9 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
 func decreaseTokenQuota(id int, quota int) (err error) {
 	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 		map[string]interface{}{
-			"remain_quota": gorm.Expr("remain_quota - ?", quota),
-			"used_quota":   gorm.Expr("used_quota + ?", quota),
+			"remain_quota":  gorm.Expr("remain_quota - ?", quota),
+			"used_quota":    gorm.Expr("used_quota + ?", quota),
+			"accessed_time": common.GetTimestamp(),
 		},
 	).Error
 	return err

+ 4 - 5
model/user.go

@@ -226,17 +226,16 @@ func IsAdmin(userId int) bool {
 	return user.Role >= common.RoleAdminUser
 }
 
-func IsUserEnabled(userId int) bool {
+func IsUserEnabled(userId int) (bool, error) {
 	if userId == 0 {
-		return false
+		return false, errors.New("user id is empty")
 	}
 	var user User
 	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
 	if err != nil {
-		common.SysError("no such user " + err.Error())
-		return false
+		return false, err
 	}
-	return user.Status == common.UserStatusEnabled
+	return user.Status == common.UserStatusEnabled, nil
 }
 
 func ValidateAccessToken(token string) (user *User) {