Jelajahi Sumber

refactor: token cache logic

CalciumIon 1 tahun lalu
induk
melakukan
bb5e032dd2
15 mengubah file dengan 414 tambahan dan 193 penghapusan
  1. 18 1
      common/crypto.go
  2. 185 25
      common/redis.go
  3. 6 1
      constant/cache_key.go
  4. 0 3
      main.go
  5. 0 83
      model/cache.go
  6. 0 10
      model/log.go
  7. 14 0
      model/main.go
  8. 96 33
      model/token.go
  9. 64 0
      model/token_cache.go
  10. 12 12
      model/user.go
  11. 7 7
      model/user_cache.go
  12. 4 0
      model/utils.go
  13. 3 13
      relay/relay-audio.go
  14. 4 4
      relay/relay-text.go
  15. 1 1
      service/quota.go

+ 18 - 1
common/crypto.go

@@ -1,6 +1,23 @@
 package common
 
-import "golang.org/x/crypto/bcrypt"
+import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
+	"golang.org/x/crypto/bcrypt"
+)
+
+func GenerateHMACWithKey(key []byte, data string) string {
+	h := hmac.New(sha256.New, key)
+	h.Write([]byte(data))
+	return hex.EncodeToString(h.Sum(nil))
+}
+
+func GenerateHMAC(data string) string {
+	h := hmac.New(sha256.New, []byte(SessionSecret))
+	h.Write([]byte(data))
+	return hex.EncodeToString(h.Sum(nil))
+}
 
 func Password2Hash(password string) (string, error) {
 	passwordBytes := []byte(password)

+ 185 - 25
common/redis.go

@@ -2,11 +2,15 @@ package common
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"os"
+	"reflect"
+	"strconv"
 	"time"
 
 	"github.com/go-redis/redis/v8"
+	"gorm.io/gorm"
 )
 
 var RDB *redis.Client
@@ -58,39 +62,167 @@ func RedisGet(key string) (string, error) {
 	return RDB.Get(ctx, key).Result()
 }
 
-func RedisExpire(key string, expiration time.Duration) error {
+//func RedisExpire(key string, expiration time.Duration) error {
+//	ctx := context.Background()
+//	return RDB.Expire(ctx, key, expiration).Err()
+//}
+//
+//func RedisGetEx(key string, expiration time.Duration) (string, error) {
+//	ctx := context.Background()
+//	return RDB.GetSet(ctx, key, expiration).Result()
+//}
+
+func RedisDel(key string) error {
 	ctx := context.Background()
-	return RDB.Expire(ctx, key, expiration).Err()
+	return RDB.Del(ctx, key).Err()
 }
 
-func RedisGetEx(key string, expiration time.Duration) (string, error) {
+func RedisHDelObj(key string) error {
 	ctx := context.Background()
-	return RDB.GetSet(ctx, key, expiration).Result()
+	return RDB.HDel(ctx, key).Err()
 }
 
-func RedisDel(key string) error {
+func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
 	ctx := context.Background()
-	return RDB.Del(ctx, key).Err()
+
+	data := make(map[string]interface{})
+
+	// 使用反射遍历结构体字段
+	v := reflect.ValueOf(obj).Elem()
+	t := v.Type()
+	for i := 0; i < v.NumField(); i++ {
+		field := t.Field(i)
+		value := v.Field(i)
+
+		// Skip DeletedAt field
+		if field.Type.String() == "gorm.DeletedAt" {
+			continue
+		}
+
+		// 处理指针类型
+		if value.Kind() == reflect.Ptr {
+			if value.IsNil() {
+				data[field.Name] = ""
+				continue
+			}
+			value = value.Elem()
+		}
+
+		// 处理布尔类型
+		if value.Kind() == reflect.Bool {
+			data[field.Name] = strconv.FormatBool(value.Bool())
+			continue
+		}
+
+		// 其他类型直接转换为字符串
+		data[field.Name] = fmt.Sprintf("%v", value.Interface())
+	}
+
+	txn := RDB.TxPipeline()
+	txn.HSet(ctx, key, data)
+	txn.Expire(ctx, key, expiration)
+
+	_, err := txn.Exec(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to execute transaction: %w", err)
+	}
+	return nil
 }
 
-func RedisDecrease(key string, value int64) error {
+func RedisHGetObj(key string, obj interface{}) error {
+	ctx := context.Background()
+
+	result, err := RDB.HGetAll(ctx, key).Result()
+	if err != nil {
+		return fmt.Errorf("failed to load hash from Redis: %w", err)
+	}
+
+	if len(result) == 0 {
+		return fmt.Errorf("key %s not found in Redis", key)
+	}
+
+	// Handle both pointer and non-pointer values
+	val := reflect.ValueOf(obj)
+	if val.Kind() != reflect.Ptr {
+		return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
+	}
+
+	v := val.Elem()
+	if v.Kind() != reflect.Struct {
+		return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
+	}
+
+	t := v.Type()
+	for i := 0; i < v.NumField(); i++ {
+		field := t.Field(i)
+		fieldName := field.Name
+		if value, ok := result[fieldName]; ok {
+			fieldValue := v.Field(i)
+
+			// Handle pointer types
+			if fieldValue.Kind() == reflect.Ptr {
+				if value == "" {
+					continue
+				}
+				if fieldValue.IsNil() {
+					fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
+				}
+				fieldValue = fieldValue.Elem()
+			}
+
+			// Enhanced type handling for Token struct
+			switch fieldValue.Kind() {
+			case reflect.String:
+				fieldValue.SetString(value)
+			case reflect.Int, reflect.Int64:
+				intValue, err := strconv.ParseInt(value, 10, 64)
+				if err != nil {
+					return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
+				}
+				fieldValue.SetInt(intValue)
+			case reflect.Bool:
+				boolValue, err := strconv.ParseBool(value)
+				if err != nil {
+					return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
+				}
+				fieldValue.SetBool(boolValue)
+			case reflect.Struct:
+				// Special handling for gorm.DeletedAt
+				if fieldValue.Type().String() == "gorm.DeletedAt" {
+					if value != "" {
+						timeValue, err := time.Parse(time.RFC3339, value)
+						if err != nil {
+							return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
+						}
+						fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
+					}
+				}
+			default:
+				return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
+			}
+		}
+	}
+
+	return nil
+}
 
+// RedisIncr Add this function to handle atomic increments
+func RedisIncr(key string, delta int64) error {
 	// 检查键的剩余生存时间
 	ttlCmd := RDB.TTL(context.Background(), key)
 	ttl, err := ttlCmd.Result()
-	if err != nil {
-		// 失败则尝试直接减少
-		return RDB.DecrBy(context.Background(), key, value).Err()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return fmt.Errorf("failed to get TTL: %w", err)
 	}
 
-	// 如果剩余生存时间大于0,则进行减少操作
+	// 只有在 key 存在且有 TTL 时才需要特殊处理
 	if ttl > 0 {
 		ctx := context.Background()
 		// 开始一个Redis事务
 		txn := RDB.TxPipeline()
 
 		// 减少余额
-		decrCmd := txn.DecrBy(ctx, key, value)
+		decrCmd := txn.IncrBy(ctx, key, delta)
 		if err := decrCmd.Err(); err != nil {
 			return err // 如果减少失败,则直接返回错误
 		}
@@ -101,26 +233,54 @@ func RedisDecrease(key string, value int64) error {
 		// 执行事务
 		_, err = txn.Exec(ctx)
 		return err
-	} else {
-		_ = RedisDel(key)
 	}
 	return nil
 }
 
-// RedisIncr Add this function to handle atomic increments
-func RedisIncr(key string, delta int) error {
-	ctx := context.Background()
+func RedisHIncrBy(key, field string, delta int64) error {
+	ttlCmd := RDB.TTL(context.Background(), key)
+	ttl, err := ttlCmd.Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return fmt.Errorf("failed to get TTL: %w", err)
+	}
 
-	// 检查键是否存在
-	exists, err := RDB.Exists(ctx, key).Result()
-	if err != nil {
+	if ttl > 0 {
+		ctx := context.Background()
+		txn := RDB.TxPipeline()
+
+		incrCmd := txn.HIncrBy(ctx, key, field, delta)
+		if err := incrCmd.Err(); err != nil {
+			return err
+		}
+
+		txn.Expire(ctx, key, ttl)
+
+		_, err = txn.Exec(ctx)
 		return err
 	}
-	if exists == 0 {
-		return fmt.Errorf("key does not exist") // 键不存在,返回错误
+	return nil
+}
+
+func RedisHSetField(key, field string, value interface{}) error {
+	ttlCmd := RDB.TTL(context.Background(), key)
+	ttl, err := ttlCmd.Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return fmt.Errorf("failed to get TTL: %w", err)
 	}
 
-	// 键存在,执行INCRBY操作
-	result := RDB.IncrBy(ctx, key, int64(delta))
-	return result.Err()
+	if ttl > 0 {
+		ctx := context.Background()
+		txn := RDB.TxPipeline()
+
+		hsetCmd := txn.HSet(ctx, key, field, value)
+		if err := hsetCmd.Err(); err != nil {
+			return err
+		}
+
+		txn.Expire(ctx, key, ttl)
+
+		_, err = txn.Exec(ctx)
+		return err
+	}
+	return nil
 }

+ 6 - 1
constant/cache_key.go

@@ -9,10 +9,15 @@ var (
 	UserId2StatusCacheSeconds = common.SyncFrequency
 )
 
+// Cache keys
 const (
-	// Cache keys
 	UserGroupKeyFmt    = "user_group:%d"
 	UserQuotaKeyFmt    = "user_quota:%d"
 	UserEnabledKeyFmt  = "user_enabled:%d"
 	UserUsernameKeyFmt = "user_name:%d"
 )
+
+const (
+	TokenFiledRemainQuota = "RemainQuota"
+	TokenFieldGroup       = "Group"
+)

+ 0 - 3
main.go

@@ -80,9 +80,6 @@ func main() {
 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
 		model.InitChannelCache()
 	}
-	if common.RedisEnabled {
-		go model.SyncTokenCache(common.SyncFrequency)
-	}
 	if common.MemoryCacheEnabled {
 		go model.SyncOptions(common.SyncFrequency)
 		go model.SyncChannelCache(common.SyncFrequency)

+ 0 - 83
model/cache.go

@@ -1,99 +1,16 @@
 package model
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"math/rand"
 	"one-api/common"
-	"one-api/constant"
 	"sort"
 	"strings"
 	"sync"
 	"time"
 )
 
-// 仅用于定时同步缓存
-var token2UserId = make(map[string]int)
-var token2UserIdLock sync.RWMutex
-
-func cacheSetToken(token *Token) error {
-	jsonBytes, err := json.Marshal(token)
-	if err != nil {
-		return err
-	}
-	err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(constant.TokenCacheSeconds)*time.Second)
-	if err != nil {
-		common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
-		return err
-	}
-	token2UserIdLock.Lock()
-	defer token2UserIdLock.Unlock()
-	token2UserId[token.Key] = token.UserId
-	return nil
-}
-
-// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
-func CacheGetTokenByKey(key string) (*Token, error) {
-	if !common.RedisEnabled {
-		return GetTokenByKey(key)
-	}
-	var token *Token
-	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
-	if err != nil {
-		// 如果缓存中不存在,则从数据库中获取
-		token, err = GetTokenByKey(key)
-		if err != nil {
-			return nil, err
-		}
-		err = cacheSetToken(token)
-		return token, nil
-	}
-	// 如果缓存中存在,则续期时间
-	err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(constant.TokenCacheSeconds)*time.Second)
-	err = json.Unmarshal([]byte(tokenObjectString), &token)
-	return token, err
-}
-
-func SyncTokenCache(frequency int) {
-	for {
-		time.Sleep(time.Duration(frequency) * time.Second)
-		common.SysLog("syncing tokens from database")
-		token2UserIdLock.Lock()
-		// 从token2UserId中获取所有的key
-		var copyToken2UserId = make(map[string]int)
-		for s, i := range token2UserId {
-			copyToken2UserId[s] = i
-		}
-		token2UserId = make(map[string]int)
-		token2UserIdLock.Unlock()
-
-		for key := range copyToken2UserId {
-			token, err := GetTokenByKey(key)
-			if err != nil {
-				// 如果数据库中不存在,则删除缓存
-				common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
-				//delete redis
-				err := common.RedisDel(fmt.Sprintf("token:%s", key))
-				if err != nil {
-					common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
-				}
-			} else {
-				// 如果数据库中存在,先检查redis
-				_, err = common.RedisGet(fmt.Sprintf("token:%s", key))
-				if err != nil {
-					// 如果redis中不存在,则跳过
-					continue
-				}
-				err = cacheSetToken(token)
-				if err != nil {
-					common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
-				}
-			}
-		}
-	}
-}
-
 //func CacheGetUserGroup(id int) (group string, err error) {
 //	if !common.RedisEnabled {
 //		return GetUserGroup(id)

+ 0 - 10
model/log.go

@@ -12,16 +12,6 @@ import (
 	"gorm.io/gorm"
 )
 
-var groupCol string
-
-func init() {
-	if common.UsingPostgreSQL {
-		groupCol = `"group"`
-	} else {
-		groupCol = "`group`"
-	}
-}
-
 type Log struct {
 	Id               int    `json:"id" gorm:"index:idx_created_at_id,priority:1"`
 	UserId           int    `json:"user_id" gorm:"index"`

+ 14 - 0
model/main.go

@@ -13,6 +13,20 @@ import (
 	"time"
 )
 
+var groupCol string
+var keyCol string
+
+func init() {
+	if common.UsingPostgreSQL {
+		groupCol = `"group"`
+		keyCol = `"key"`
+
+	} else {
+		groupCol = "`group`"
+		keyCol = "`key`"
+	}
+}
+
 var DB *gorm.DB
 
 var LOG_DB *gorm.DB

+ 96 - 33
model/token.go

@@ -3,6 +3,7 @@ package model
 import (
 	"errors"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"gorm.io/gorm"
 	"one-api/common"
 	relaycommon "one-api/relay/common"
@@ -30,6 +31,10 @@ type Token struct {
 	DeletedAt          gorm.DeletedAt `gorm:"index"`
 }
 
+func (token *Token) Clean() {
+	token.Key = ""
+}
+
 func (token *Token) GetIpLimitsMap() map[string]any {
 	// delete empty spaces
 	//split with \n
@@ -71,7 +76,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
 	if key == "" {
 		return nil, errors.New("未提供令牌")
 	}
-	token, err = CacheGetTokenByKey(key)
+	token, err = GetTokenByKey(key, false)
 	if err == nil {
 		if token.Status == common.TokenStatusExhausted {
 			keyPrefix := key[:3]
@@ -129,21 +134,37 @@ func GetTokenById(id int) (*Token, error) {
 	var err error = nil
 	err = DB.First(&token, "id = ?", id).Error
 	if err != nil {
-		if common.RedisEnabled {
-			go cacheSetToken(&token)
-		}
+		gopool.Go(func() {
+			if err := cacheSetToken(token); err != nil {
+				common.SysError("failed to update user status cache: " + err.Error())
+			}
+		})
 	}
 	return &token, err
 }
 
-func GetTokenByKey(key string) (*Token, error) {
-	keyCol := "`key`"
-	if common.UsingPostgreSQL {
-		keyCol = `"key"`
+func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
+	defer func() {
+		// Update Redis cache asynchronously on successful DB read
+		if shouldUpdateRedis(fromDB, err) && token != nil {
+			gopool.Go(func() {
+				if err := cacheSetToken(*token); err != nil {
+					common.SysError("failed to update user status cache: " + err.Error())
+				}
+			})
+		}
+	}()
+	if !fromDB && common.RedisEnabled {
+		// Try Redis first
+		token, err := cacheGetTokenByKey(key)
+		if err == nil {
+			return token, nil
+		}
+		// Don't return error - fall through to DB
 	}
-	var token Token
-	err := DB.Where(keyCol+" = ?", key).First(&token).Error
-	return &token, err
+	fromDB = true
+	err = DB.Where(keyCol+" = ?", key).First(&token).Error
+	return token, err
 }
 
 func (token *Token) Insert() error {
@@ -153,20 +174,48 @@ func (token *Token) Insert() error {
 }
 
 // Update Make sure your token's fields is completed, because this will update non-zero values
-func (token *Token) Update() error {
-	var err error
+func (token *Token) Update() (err error) {
+	defer func() {
+		if common.RedisEnabled && err == nil {
+			gopool.Go(func() {
+				err := cacheSetToken(*token)
+				if err != nil {
+					common.SysError("failed to update token cache: " + err.Error())
+				}
+			})
+		}
+	}()
 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
 		"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
 	return err
 }
 
-func (token *Token) SelectUpdate() error {
+func (token *Token) SelectUpdate() (err error) {
+	defer func() {
+		if common.RedisEnabled && err == nil {
+			gopool.Go(func() {
+				err := cacheSetToken(*token)
+				if err != nil {
+					common.SysError("failed to update token cache: " + err.Error())
+				}
+			})
+		}
+	}()
 	// This can update zero values
 	return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
 }
 
-func (token *Token) Delete() error {
-	var err error
+func (token *Token) Delete() (err error) {
+	defer func() {
+		if common.RedisEnabled && err == nil {
+			gopool.Go(func() {
+				err := cacheDeleteToken(token.Key)
+				if err != nil {
+					common.SysError("failed to delete token cache: " + err.Error())
+				}
+			})
+		}
+	}()
 	err = DB.Delete(token).Error
 	return err
 }
@@ -214,10 +263,16 @@ func DeleteTokenById(id int, userId int) (err error) {
 	return token.Delete()
 }
 
-func IncreaseTokenQuota(id int, quota int) (err error) {
+func IncreaseTokenQuota(id int, key string, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	gopool.Go(func() {
+		err := cacheIncrTokenQuota(key, int64(quota))
+		if err != nil {
+			common.SysError("failed to increase token quota: " + err.Error())
+		}
+	})
 	if common.BatchUpdateEnabled {
 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
 		return nil
@@ -236,10 +291,16 @@ func increaseTokenQuota(id int, quota int) (err error) {
 	return err
 }
 
-func DecreaseTokenQuota(id int, quota int) (err error) {
+func DecreaseTokenQuota(id int, key string, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	gopool.Go(func() {
+		err := cacheDecrTokenQuota(key, int64(quota))
+		if err != nil {
+			common.SysError("failed to decrease token quota: " + err.Error())
+		}
+	})
 	if common.BatchUpdateEnabled {
 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
 		return nil
@@ -262,20 +323,22 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
-	if !relayInfo.IsPlayground {
-		token, err := GetTokenById(relayInfo.TokenId)
-		if err != nil {
-			return err
-		}
-		if !token.UnlimitedQuota && token.RemainQuota < quota {
-			return errors.New("令牌额度不足")
-		}
+	if relayInfo.IsPlayground {
+		return nil
 	}
-	if !relayInfo.IsPlayground {
-		err := DecreaseTokenQuota(relayInfo.TokenId, quota)
-		if err != nil {
-			return err
-		}
+	//if relayInfo.TokenUnlimited {
+	//	return nil
+	//}
+	token, err := GetTokenById(relayInfo.TokenId)
+	if err != nil {
+		return err
+	}
+	if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
+		return errors.New("令牌额度不足")
+	}
+	err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+	if err != nil {
+		return err
 	}
 	return nil
 }
@@ -293,9 +356,9 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int
 
 	if !relayInfo.IsPlayground {
 		if quota > 0 {
-			err = DecreaseTokenQuota(relayInfo.TokenId, quota)
+			err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
 		} else {
-			err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
+			err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
 		}
 		if err != nil {
 			return err

+ 64 - 0
model/token_cache.go

@@ -0,0 +1,64 @@
+package model
+
+import (
+	"fmt"
+	"one-api/common"
+	"one-api/constant"
+	"time"
+)
+
+func cacheSetToken(token Token) error {
+	key := common.GenerateHMAC(token.Key)
+	token.Clean()
+	err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func cacheDeleteToken(key string) error {
+	key = common.GenerateHMAC(key)
+	err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func cacheIncrTokenQuota(key string, increment int64) error {
+	key = common.GenerateHMAC(key)
+	err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func cacheDecrTokenQuota(key string, decrement int64) error {
+	return cacheIncrTokenQuota(key, -decrement)
+}
+
+func cacheSetTokenField(key string, field string, value string) error {
+	key = common.GenerateHMAC(key)
+	err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+// CacheGetTokenByKey 从缓存中获取 token,如果缓存中不存在,则从数据库中获取
+func cacheGetTokenByKey(key string) (*Token, error) {
+	hmacKey := common.GenerateHMAC(key)
+	if !common.RedisEnabled {
+		return nil, nil
+	}
+	var token Token
+	err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
+	if err != nil {
+		return nil, err
+	}
+	token.Key = key
+	return &token, nil
+}

+ 12 - 12
model/user.go

@@ -252,7 +252,7 @@ func (user *User) Update(updatePassword bool) error {
 	}
 
 	// 更新缓存
-	return updateUserCache(user)
+	return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
 }
 
 func (user *User) Edit(updatePassword bool) error {
@@ -281,7 +281,7 @@ func (user *User) Edit(updatePassword bool) error {
 	}
 
 	// 更新缓存
-	return updateUserCache(user)
+	return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
 }
 
 func (user *User) Delete() error {
@@ -411,7 +411,7 @@ func IsAdmin(userId int) bool {
 func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
 	defer func() {
 		// Update Redis cache asynchronously on successful DB read
-		if common.RedisEnabled {
+		if shouldUpdateRedis(fromDB, err) {
 			gopool.Go(func() {
 				if err := updateUserStatusCache(id, status); err != nil {
 					common.SysError("failed to update user status cache: " + err.Error())
@@ -427,7 +427,7 @@ func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
 		}
 		// Don't return error - fall through to DB
 	}
-
+	fromDB = true
 	var user User
 	err = DB.Where("id = ?", id).Select("status").Find(&user).Error
 	if err != nil {
@@ -453,7 +453,7 @@ func ValidateAccessToken(token string) (user *User) {
 func GetUserQuota(id int, fromDB bool) (quota int, err error) {
 	defer func() {
 		// Update Redis cache asynchronously on successful DB read
-		if common.RedisEnabled && err == nil {
+		if shouldUpdateRedis(fromDB, err) {
 			gopool.Go(func() {
 				if err := updateUserQuotaCache(id, quota); err != nil {
 					common.SysError("failed to update user quota cache: " + err.Error())
@@ -469,7 +469,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
 		// Don't return error - fall through to DB
 		//common.SysError("failed to get user quota from cache: " + err.Error())
 	}
-
+	fromDB = true
 	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
 	if err != nil {
 		return 0, err
@@ -492,7 +492,7 @@ func GetUserEmail(id int) (email string, err error) {
 func GetUserGroup(id int, fromDB bool) (group string, err error) {
 	defer func() {
 		// Update Redis cache asynchronously on successful DB read
-		if common.RedisEnabled && err == nil {
+		if shouldUpdateRedis(fromDB, err) {
 			gopool.Go(func() {
 				if err := updateUserGroupCache(id, group); err != nil {
 					common.SysError("failed to update user group cache: " + err.Error())
@@ -507,7 +507,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
 		}
 		// Don't return error - fall through to DB
 	}
-
+	fromDB = true
 	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
 	if err != nil {
 		return "", err
@@ -521,7 +521,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
 		return errors.New("quota 不能为负数!")
 	}
 	gopool.Go(func() {
-		err := cacheIncrUserQuota(id, quota)
+		err := cacheIncrUserQuota(id, int64(quota))
 		if err != nil {
 			common.SysError("failed to increase user quota: " + err.Error())
 		}
@@ -546,7 +546,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 		return errors.New("quota 不能为负数!")
 	}
 	gopool.Go(func() {
-		err := cacheDecrUserQuota(id, quota)
+		err := cacheDecrUserQuota(id, int64(quota))
 		if err != nil {
 			common.SysError("failed to decrease user quota: " + err.Error())
 		}
@@ -631,7 +631,7 @@ func updateUserRequestCount(id int, count int) {
 func GetUsernameById(id int, fromDB bool) (username string, err error) {
 	defer func() {
 		// Update Redis cache asynchronously on successful DB read
-		if common.RedisEnabled && err == nil {
+		if shouldUpdateRedis(fromDB, err) {
 			gopool.Go(func() {
 				if err := updateUserNameCache(id, username); err != nil {
 					common.SysError("failed to update user name cache: " + err.Error())
@@ -646,7 +646,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
 		}
 		// Don't return error - fall through to DB
 	}
-
+	fromDB = true
 	err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
 	if err != nil {
 		return "", err

+ 7 - 7
model/user_cache.go

@@ -93,24 +93,24 @@ func updateUserNameCache(userId int, username string) error {
 }
 
 // updateUserCache updates all user cache fields
-func updateUserCache(user *User) error {
+func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
 	if !common.RedisEnabled {
 		return nil
 	}
 
-	if err := updateUserGroupCache(user.Id, user.Group); err != nil {
+	if err := updateUserGroupCache(userId, userGroup); err != nil {
 		return fmt.Errorf("update group cache: %w", err)
 	}
 
-	if err := updateUserQuotaCache(user.Id, user.Quota); err != nil {
+	if err := updateUserQuotaCache(userId, quota); err != nil {
 		return fmt.Errorf("update quota cache: %w", err)
 	}
 
-	if err := updateUserStatusCache(user.Id, user.Status == common.UserStatusEnabled); err != nil {
+	if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
 		return fmt.Errorf("update status cache: %w", err)
 	}
 
-	if err := updateUserNameCache(user.Id, user.Username); err != nil {
+	if err := updateUserNameCache(userId, username); err != nil {
 		return fmt.Errorf("update username cache: %w", err)
 	}
 
@@ -193,7 +193,7 @@ func getUserCache(userId int) (*userCache, error) {
 }
 
 // Add atomic quota operations
-func cacheIncrUserQuota(userId int, delta int) error {
+func cacheIncrUserQuota(userId int, delta int64) error {
 	if !common.RedisEnabled {
 		return nil
 	}
@@ -201,6 +201,6 @@ func cacheIncrUserQuota(userId int, delta int) error {
 	return common.RedisIncr(key, delta)
 }
 
-func cacheDecrUserQuota(userId int, delta int) error {
+func cacheDecrUserQuota(userId int, delta int64) error {
 	return cacheIncrUserQuota(userId, -delta)
 }

+ 4 - 0
model/utils.go

@@ -88,3 +88,7 @@ func RecordExist(err error) (bool, error) {
 	}
 	return false, err
 }
+
+func shouldUpdateRedis(fromDB bool, err error) bool {
+	return common.RedisEnabled && fromDB && err == nil
+}

+ 3 - 13
relay/relay-audio.go

@@ -81,19 +81,9 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
-	if userQuota-preConsumedQuota < 0 {
-		return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
-	}
-	if userQuota > 100*preConsumedQuota {
-		// in this case, we do not pre-consume quota
-		// because the user has enough quota
-		preConsumedQuota = 0
-	}
-	if preConsumedQuota > 0 {
-		err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
-		}
+	preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
 	}
 	defer func() {
 		if openaiErr != nil {

+ 4 - 4
relay/relay-text.go

@@ -291,14 +291,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	}
 
 	if preConsumedQuota > 0 {
-		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
-		if err != nil {
-			return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
-		}
 		err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
 		if err != nil {
 			return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
+		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
+		if err != nil {
+			return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+		}
 	}
 	return preConsumedQuota, userQuota, nil
 }

+ 1 - 1
service/quota.go

@@ -23,7 +23,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 		return err
 	}
 
-	token, err := model.CacheGetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"))
+	token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
 	if err != nil {
 		return err
 	}