|
|
@@ -2,11 +2,15 @@ package common
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"os"
|
|
|
+ "reflect"
|
|
|
+ "strconv"
|
|
|
"time"
|
|
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
|
+ "gorm.io/gorm"
|
|
|
)
|
|
|
|
|
|
var RDB *redis.Client
|
|
|
@@ -58,39 +62,167 @@ func RedisGet(key string) (string, error) {
|
|
|
return RDB.Get(ctx, key).Result()
|
|
|
}
|
|
|
|
|
|
-func RedisExpire(key string, expiration time.Duration) error {
|
|
|
+//func RedisExpire(key string, expiration time.Duration) error {
|
|
|
+// ctx := context.Background()
|
|
|
+// return RDB.Expire(ctx, key, expiration).Err()
|
|
|
+//}
|
|
|
+//
|
|
|
+//func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
|
|
+// ctx := context.Background()
|
|
|
+// return RDB.GetSet(ctx, key, expiration).Result()
|
|
|
+//}
|
|
|
+
|
|
|
+func RedisDel(key string) error {
|
|
|
ctx := context.Background()
|
|
|
- return RDB.Expire(ctx, key, expiration).Err()
|
|
|
+ return RDB.Del(ctx, key).Err()
|
|
|
}
|
|
|
|
|
|
-func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
|
|
+func RedisHDelObj(key string) error {
|
|
|
ctx := context.Background()
|
|
|
- return RDB.GetSet(ctx, key, expiration).Result()
|
|
|
+ return RDB.HDel(ctx, key).Err()
|
|
|
}
|
|
|
|
|
|
-func RedisDel(key string) error {
|
|
|
+func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
|
|
ctx := context.Background()
|
|
|
- return RDB.Del(ctx, key).Err()
|
|
|
+
|
|
|
+ data := make(map[string]interface{})
|
|
|
+
|
|
|
+ // 使用反射遍历结构体字段
|
|
|
+ v := reflect.ValueOf(obj).Elem()
|
|
|
+ t := v.Type()
|
|
|
+ for i := 0; i < v.NumField(); i++ {
|
|
|
+ field := t.Field(i)
|
|
|
+ value := v.Field(i)
|
|
|
+
|
|
|
+ // Skip DeletedAt field
|
|
|
+ if field.Type.String() == "gorm.DeletedAt" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // 处理指针类型
|
|
|
+ if value.Kind() == reflect.Ptr {
|
|
|
+ if value.IsNil() {
|
|
|
+ data[field.Name] = ""
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ value = value.Elem()
|
|
|
+ }
|
|
|
+
|
|
|
+ // 处理布尔类型
|
|
|
+ if value.Kind() == reflect.Bool {
|
|
|
+ data[field.Name] = strconv.FormatBool(value.Bool())
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // 其他类型直接转换为字符串
|
|
|
+ data[field.Name] = fmt.Sprintf("%v", value.Interface())
|
|
|
+ }
|
|
|
+
|
|
|
+ txn := RDB.TxPipeline()
|
|
|
+ txn.HSet(ctx, key, data)
|
|
|
+ txn.Expire(ctx, key, expiration)
|
|
|
+
|
|
|
+ _, err := txn.Exec(ctx)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to execute transaction: %w", err)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
-func RedisDecrease(key string, value int64) error {
|
|
|
+func RedisHGetObj(key string, obj interface{}) error {
|
|
|
+ ctx := context.Background()
|
|
|
+
|
|
|
+ result, err := RDB.HGetAll(ctx, key).Result()
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to load hash from Redis: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(result) == 0 {
|
|
|
+ return fmt.Errorf("key %s not found in Redis", key)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Handle both pointer and non-pointer values
|
|
|
+ val := reflect.ValueOf(obj)
|
|
|
+ if val.Kind() != reflect.Ptr {
|
|
|
+ return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
|
|
|
+ }
|
|
|
+
|
|
|
+ v := val.Elem()
|
|
|
+ if v.Kind() != reflect.Struct {
|
|
|
+ return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
|
|
|
+ }
|
|
|
+
|
|
|
+ t := v.Type()
|
|
|
+ for i := 0; i < v.NumField(); i++ {
|
|
|
+ field := t.Field(i)
|
|
|
+ fieldName := field.Name
|
|
|
+ if value, ok := result[fieldName]; ok {
|
|
|
+ fieldValue := v.Field(i)
|
|
|
+
|
|
|
+ // Handle pointer types
|
|
|
+ if fieldValue.Kind() == reflect.Ptr {
|
|
|
+ if value == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if fieldValue.IsNil() {
|
|
|
+ fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
|
|
+ }
|
|
|
+ fieldValue = fieldValue.Elem()
|
|
|
+ }
|
|
|
+
|
|
|
+ // Enhanced type handling for Token struct
|
|
|
+ switch fieldValue.Kind() {
|
|
|
+ case reflect.String:
|
|
|
+ fieldValue.SetString(value)
|
|
|
+ case reflect.Int, reflect.Int64:
|
|
|
+ intValue, err := strconv.ParseInt(value, 10, 64)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
|
|
|
+ }
|
|
|
+ fieldValue.SetInt(intValue)
|
|
|
+ case reflect.Bool:
|
|
|
+ boolValue, err := strconv.ParseBool(value)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
|
|
|
+ }
|
|
|
+ fieldValue.SetBool(boolValue)
|
|
|
+ case reflect.Struct:
|
|
|
+ // Special handling for gorm.DeletedAt
|
|
|
+ if fieldValue.Type().String() == "gorm.DeletedAt" {
|
|
|
+ if value != "" {
|
|
|
+ timeValue, err := time.Parse(time.RFC3339, value)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
|
|
|
+ }
|
|
|
+ fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
|
|
|
+// RedisIncr Add this function to handle atomic increments
|
|
|
+func RedisIncr(key string, delta int64) error {
|
|
|
// 检查键的剩余生存时间
|
|
|
ttlCmd := RDB.TTL(context.Background(), key)
|
|
|
ttl, err := ttlCmd.Result()
|
|
|
- if err != nil {
|
|
|
- // 失败则尝试直接减少
|
|
|
- return RDB.DecrBy(context.Background(), key, value).Err()
|
|
|
+ if err != nil && !errors.Is(err, redis.Nil) {
|
|
|
+ return fmt.Errorf("failed to get TTL: %w", err)
|
|
|
}
|
|
|
|
|
|
- // 如果剩余生存时间大于0,则进行减少操作
|
|
|
+ // 只有在 key 存在且有 TTL 时才需要特殊处理
|
|
|
if ttl > 0 {
|
|
|
ctx := context.Background()
|
|
|
// 开始一个Redis事务
|
|
|
txn := RDB.TxPipeline()
|
|
|
|
|
|
// 减少余额
|
|
|
- decrCmd := txn.DecrBy(ctx, key, value)
|
|
|
+ decrCmd := txn.IncrBy(ctx, key, delta)
|
|
|
if err := decrCmd.Err(); err != nil {
|
|
|
return err // 如果减少失败,则直接返回错误
|
|
|
}
|
|
|
@@ -101,26 +233,54 @@ func RedisDecrease(key string, value int64) error {
|
|
|
// 执行事务
|
|
|
_, err = txn.Exec(ctx)
|
|
|
return err
|
|
|
- } else {
|
|
|
- _ = RedisDel(key)
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-// RedisIncr Add this function to handle atomic increments
|
|
|
-func RedisIncr(key string, delta int) error {
|
|
|
- ctx := context.Background()
|
|
|
+func RedisHIncrBy(key, field string, delta int64) error {
|
|
|
+ ttlCmd := RDB.TTL(context.Background(), key)
|
|
|
+ ttl, err := ttlCmd.Result()
|
|
|
+ if err != nil && !errors.Is(err, redis.Nil) {
|
|
|
+ return fmt.Errorf("failed to get TTL: %w", err)
|
|
|
+ }
|
|
|
|
|
|
- // 检查键是否存在
|
|
|
- exists, err := RDB.Exists(ctx, key).Result()
|
|
|
- if err != nil {
|
|
|
+ if ttl > 0 {
|
|
|
+ ctx := context.Background()
|
|
|
+ txn := RDB.TxPipeline()
|
|
|
+
|
|
|
+ incrCmd := txn.HIncrBy(ctx, key, field, delta)
|
|
|
+ if err := incrCmd.Err(); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ txn.Expire(ctx, key, ttl)
|
|
|
+
|
|
|
+ _, err = txn.Exec(ctx)
|
|
|
return err
|
|
|
}
|
|
|
- if exists == 0 {
|
|
|
- return fmt.Errorf("key does not exist") // 键不存在,返回错误
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func RedisHSetField(key, field string, value interface{}) error {
|
|
|
+ ttlCmd := RDB.TTL(context.Background(), key)
|
|
|
+ ttl, err := ttlCmd.Result()
|
|
|
+ if err != nil && !errors.Is(err, redis.Nil) {
|
|
|
+ return fmt.Errorf("failed to get TTL: %w", err)
|
|
|
}
|
|
|
|
|
|
- // 键存在,执行INCRBY操作
|
|
|
- result := RDB.IncrBy(ctx, key, int64(delta))
|
|
|
- return result.Err()
|
|
|
+ if ttl > 0 {
|
|
|
+ ctx := context.Background()
|
|
|
+ txn := RDB.TxPipeline()
|
|
|
+
|
|
|
+ hsetCmd := txn.HSet(ctx, key, field, value)
|
|
|
+ if err := hsetCmd.Err(); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ txn.Expire(ctx, key, ttl)
|
|
|
+
|
|
|
+ _, err = txn.Exec(ctx)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
}
|