| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327 |
- package common
- import (
- "context"
- "errors"
- "fmt"
- "os"
- "reflect"
- "strconv"
- "time"
- "github.com/go-redis/redis/v8"
- "gorm.io/gorm"
- )
- var RDB *redis.Client
- var RedisEnabled = true
- func RedisKeyCacheSeconds() int {
- return SyncFrequency
- }
- // InitRedisClient This function is called after init()
- func InitRedisClient() (err error) {
- if os.Getenv("REDIS_CONN_STRING") == "" {
- RedisEnabled = false
- SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
- return nil
- }
- if os.Getenv("SYNC_FREQUENCY") == "" {
- SysLog("SYNC_FREQUENCY not set, use default value 60")
- SyncFrequency = 60
- }
- SysLog("Redis is enabled")
- opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
- if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
- }
- opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
- RDB = redis.NewClient(opt)
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- _, err = RDB.Ping(ctx).Result()
- if err != nil {
- FatalLog("Redis ping test failed: " + err.Error())
- }
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
- SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
- }
- return err
- }
- func ParseRedisOption() *redis.Options {
- opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
- if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
- }
- return opt
- }
- func RedisSet(key string, value string, expiration time.Duration) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
- }
- ctx := context.Background()
- return RDB.Set(ctx, key, value, expiration).Err()
- }
- func RedisGet(key string) (string, error) {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis GET: key=%s", key))
- }
- ctx := context.Background()
- val, err := RDB.Get(ctx, key).Result()
- return val, err
- }
- //func RedisExpire(key string, expiration time.Duration) error {
- // ctx := context.Background()
- // return RDB.Expire(ctx, key, expiration).Err()
- //}
- //
- //func RedisGetEx(key string, expiration time.Duration) (string, error) {
- // ctx := context.Background()
- // return RDB.GetSet(ctx, key, expiration).Result()
- //}
- func RedisDel(key string) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
- }
- ctx := context.Background()
- return RDB.Del(ctx, key).Err()
- }
- func RedisDelKey(key string) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
- }
- ctx := context.Background()
- return RDB.Del(ctx, key).Err()
- }
- func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
- }
- ctx := context.Background()
- data := make(map[string]interface{})
- // 使用反射遍历结构体字段
- v := reflect.ValueOf(obj).Elem()
- t := v.Type()
- for i := 0; i < v.NumField(); i++ {
- field := t.Field(i)
- value := v.Field(i)
- // Skip DeletedAt field
- if field.Type.String() == "gorm.DeletedAt" {
- continue
- }
- // 处理指针类型
- if value.Kind() == reflect.Ptr {
- if value.IsNil() {
- data[field.Name] = ""
- continue
- }
- value = value.Elem()
- }
- // 处理布尔类型
- if value.Kind() == reflect.Bool {
- data[field.Name] = strconv.FormatBool(value.Bool())
- continue
- }
- // 其他类型直接转换为字符串
- data[field.Name] = fmt.Sprintf("%v", value.Interface())
- }
- txn := RDB.TxPipeline()
- txn.HSet(ctx, key, data)
- // 只有在 expiration 大于 0 时才设置过期时间
- if expiration > 0 {
- txn.Expire(ctx, key, expiration)
- }
- _, err := txn.Exec(ctx)
- if err != nil {
- return fmt.Errorf("failed to execute transaction: %w", err)
- }
- return nil
- }
- func RedisHGetObj(key string, obj interface{}) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
- }
- ctx := context.Background()
- result, err := RDB.HGetAll(ctx, key).Result()
- if err != nil {
- return fmt.Errorf("failed to load hash from Redis: %w", err)
- }
- if len(result) == 0 {
- return fmt.Errorf("key %s not found in Redis", key)
- }
- // Handle both pointer and non-pointer values
- val := reflect.ValueOf(obj)
- if val.Kind() != reflect.Ptr {
- return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
- }
- v := val.Elem()
- if v.Kind() != reflect.Struct {
- return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
- }
- t := v.Type()
- for i := 0; i < v.NumField(); i++ {
- field := t.Field(i)
- fieldName := field.Name
- if value, ok := result[fieldName]; ok {
- fieldValue := v.Field(i)
- // Handle pointer types
- if fieldValue.Kind() == reflect.Ptr {
- if value == "" {
- continue
- }
- if fieldValue.IsNil() {
- fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
- }
- fieldValue = fieldValue.Elem()
- }
- // Enhanced type handling for Token struct
- switch fieldValue.Kind() {
- case reflect.String:
- fieldValue.SetString(value)
- case reflect.Int, reflect.Int64:
- intValue, err := strconv.ParseInt(value, 10, 64)
- if err != nil {
- return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
- }
- fieldValue.SetInt(intValue)
- case reflect.Bool:
- boolValue, err := strconv.ParseBool(value)
- if err != nil {
- return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
- }
- fieldValue.SetBool(boolValue)
- case reflect.Struct:
- // Special handling for gorm.DeletedAt
- if fieldValue.Type().String() == "gorm.DeletedAt" {
- if value != "" {
- timeValue, err := time.Parse(time.RFC3339, value)
- if err != nil {
- return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
- }
- fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
- }
- }
- default:
- return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
- }
- }
- }
- return nil
- }
- // RedisIncr Add this function to handle atomic increments
- func RedisIncr(key string, delta int64) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
- }
- // 检查键的剩余生存时间
- ttlCmd := RDB.TTL(context.Background(), key)
- ttl, err := ttlCmd.Result()
- if err != nil && !errors.Is(err, redis.Nil) {
- return fmt.Errorf("failed to get TTL: %w", err)
- }
- // 只有在 key 存在且有 TTL 时才需要特殊处理
- if ttl > 0 {
- ctx := context.Background()
- // 开始一个Redis事务
- txn := RDB.TxPipeline()
- // 减少余额
- decrCmd := txn.IncrBy(ctx, key, delta)
- if err := decrCmd.Err(); err != nil {
- return err // 如果减少失败,则直接返回错误
- }
- // 重新设置过期时间,使用原来的过期时间
- txn.Expire(ctx, key, ttl)
- // 执行事务
- _, err = txn.Exec(ctx)
- return err
- }
- return nil
- }
- func RedisHIncrBy(key, field string, delta int64) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
- }
- ttlCmd := RDB.TTL(context.Background(), key)
- ttl, err := ttlCmd.Result()
- if err != nil && !errors.Is(err, redis.Nil) {
- return fmt.Errorf("failed to get TTL: %w", err)
- }
- if ttl > 0 {
- ctx := context.Background()
- txn := RDB.TxPipeline()
- incrCmd := txn.HIncrBy(ctx, key, field, delta)
- if err := incrCmd.Err(); err != nil {
- return err
- }
- txn.Expire(ctx, key, ttl)
- _, err = txn.Exec(ctx)
- return err
- }
- return nil
- }
- func RedisHSetField(key, field string, value interface{}) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
- }
- ttlCmd := RDB.TTL(context.Background(), key)
- ttl, err := ttlCmd.Result()
- if err != nil && !errors.Is(err, redis.Nil) {
- return fmt.Errorf("failed to get TTL: %w", err)
- }
- if ttl > 0 {
- ctx := context.Background()
- txn := RDB.TxPipeline()
- hsetCmd := txn.HSet(ctx, key, field, value)
- if err := hsetCmd.Err(); err != nil {
- return err
- }
- txn.Expire(ctx, key, ttl)
- _, err = txn.Exec(ctx)
- return err
- }
- return nil
- }
|