redis.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package common
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "reflect"
  8. "strconv"
  9. "time"
  10. "github.com/go-redis/redis/v8"
  11. "gorm.io/gorm"
  12. )
  13. var RDB *redis.Client
  14. var RedisEnabled = true
  15. func RedisKeyCacheSeconds() int {
  16. return SyncFrequency
  17. }
  18. // InitRedisClient This function is called after init()
  19. func InitRedisClient() (err error) {
  20. if os.Getenv("REDIS_CONN_STRING") == "" {
  21. RedisEnabled = false
  22. SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
  23. return nil
  24. }
  25. if os.Getenv("SYNC_FREQUENCY") == "" {
  26. SysLog("SYNC_FREQUENCY not set, use default value 60")
  27. SyncFrequency = 60
  28. }
  29. SysLog("Redis is enabled")
  30. opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
  31. if err != nil {
  32. FatalLog("failed to parse Redis connection string: " + err.Error())
  33. }
  34. opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
  35. RDB = redis.NewClient(opt)
  36. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  37. defer cancel()
  38. _, err = RDB.Ping(ctx).Result()
  39. if err != nil {
  40. FatalLog("Redis ping test failed: " + err.Error())
  41. }
  42. if DebugEnabled {
  43. SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
  44. SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
  45. }
  46. return err
  47. }
  48. func ParseRedisOption() *redis.Options {
  49. opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
  50. if err != nil {
  51. FatalLog("failed to parse Redis connection string: " + err.Error())
  52. }
  53. return opt
  54. }
  55. func RedisSet(key string, value string, expiration time.Duration) error {
  56. if DebugEnabled {
  57. SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
  58. }
  59. ctx := context.Background()
  60. return RDB.Set(ctx, key, value, expiration).Err()
  61. }
  62. func RedisGet(key string) (string, error) {
  63. if DebugEnabled {
  64. SysLog(fmt.Sprintf("Redis GET: key=%s", key))
  65. }
  66. ctx := context.Background()
  67. val, err := RDB.Get(ctx, key).Result()
  68. return val, err
  69. }
  70. //func RedisExpire(key string, expiration time.Duration) error {
  71. // ctx := context.Background()
  72. // return RDB.Expire(ctx, key, expiration).Err()
  73. //}
  74. //
  75. //func RedisGetEx(key string, expiration time.Duration) (string, error) {
  76. // ctx := context.Background()
  77. // return RDB.GetSet(ctx, key, expiration).Result()
  78. //}
  79. func RedisDel(key string) error {
  80. if DebugEnabled {
  81. SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
  82. }
  83. ctx := context.Background()
  84. return RDB.Del(ctx, key).Err()
  85. }
  86. func RedisDelKey(key string) error {
  87. if DebugEnabled {
  88. SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
  89. }
  90. ctx := context.Background()
  91. return RDB.Del(ctx, key).Err()
  92. }
  93. func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
  94. if DebugEnabled {
  95. SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
  96. }
  97. ctx := context.Background()
  98. data := make(map[string]interface{})
  99. // 使用反射遍历结构体字段
  100. v := reflect.ValueOf(obj).Elem()
  101. t := v.Type()
  102. for i := 0; i < v.NumField(); i++ {
  103. field := t.Field(i)
  104. value := v.Field(i)
  105. // Skip DeletedAt field
  106. if field.Type.String() == "gorm.DeletedAt" {
  107. continue
  108. }
  109. // 处理指针类型
  110. if value.Kind() == reflect.Ptr {
  111. if value.IsNil() {
  112. data[field.Name] = ""
  113. continue
  114. }
  115. value = value.Elem()
  116. }
  117. // 处理布尔类型
  118. if value.Kind() == reflect.Bool {
  119. data[field.Name] = strconv.FormatBool(value.Bool())
  120. continue
  121. }
  122. // 其他类型直接转换为字符串
  123. data[field.Name] = fmt.Sprintf("%v", value.Interface())
  124. }
  125. txn := RDB.TxPipeline()
  126. txn.HSet(ctx, key, data)
  127. // 只有在 expiration 大于 0 时才设置过期时间
  128. if expiration > 0 {
  129. txn.Expire(ctx, key, expiration)
  130. }
  131. _, err := txn.Exec(ctx)
  132. if err != nil {
  133. return fmt.Errorf("failed to execute transaction: %w", err)
  134. }
  135. return nil
  136. }
  137. func RedisHGetObj(key string, obj interface{}) error {
  138. if DebugEnabled {
  139. SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
  140. }
  141. ctx := context.Background()
  142. result, err := RDB.HGetAll(ctx, key).Result()
  143. if err != nil {
  144. return fmt.Errorf("failed to load hash from Redis: %w", err)
  145. }
  146. if len(result) == 0 {
  147. return fmt.Errorf("key %s not found in Redis", key)
  148. }
  149. // Handle both pointer and non-pointer values
  150. val := reflect.ValueOf(obj)
  151. if val.Kind() != reflect.Ptr {
  152. return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
  153. }
  154. v := val.Elem()
  155. if v.Kind() != reflect.Struct {
  156. return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
  157. }
  158. t := v.Type()
  159. for i := 0; i < v.NumField(); i++ {
  160. field := t.Field(i)
  161. fieldName := field.Name
  162. if value, ok := result[fieldName]; ok {
  163. fieldValue := v.Field(i)
  164. // Handle pointer types
  165. if fieldValue.Kind() == reflect.Ptr {
  166. if value == "" {
  167. continue
  168. }
  169. if fieldValue.IsNil() {
  170. fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
  171. }
  172. fieldValue = fieldValue.Elem()
  173. }
  174. // Enhanced type handling for Token struct
  175. switch fieldValue.Kind() {
  176. case reflect.String:
  177. fieldValue.SetString(value)
  178. case reflect.Int, reflect.Int64:
  179. intValue, err := strconv.ParseInt(value, 10, 64)
  180. if err != nil {
  181. return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
  182. }
  183. fieldValue.SetInt(intValue)
  184. case reflect.Bool:
  185. boolValue, err := strconv.ParseBool(value)
  186. if err != nil {
  187. return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
  188. }
  189. fieldValue.SetBool(boolValue)
  190. case reflect.Struct:
  191. // Special handling for gorm.DeletedAt
  192. if fieldValue.Type().String() == "gorm.DeletedAt" {
  193. if value != "" {
  194. timeValue, err := time.Parse(time.RFC3339, value)
  195. if err != nil {
  196. return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
  197. }
  198. fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
  199. }
  200. }
  201. default:
  202. return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
  203. }
  204. }
  205. }
  206. return nil
  207. }
  208. // RedisIncr Add this function to handle atomic increments
  209. func RedisIncr(key string, delta int64) error {
  210. if DebugEnabled {
  211. SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
  212. }
  213. // 检查键的剩余生存时间
  214. ttlCmd := RDB.TTL(context.Background(), key)
  215. ttl, err := ttlCmd.Result()
  216. if err != nil && !errors.Is(err, redis.Nil) {
  217. return fmt.Errorf("failed to get TTL: %w", err)
  218. }
  219. // 只有在 key 存在且有 TTL 时才需要特殊处理
  220. if ttl > 0 {
  221. ctx := context.Background()
  222. // 开始一个Redis事务
  223. txn := RDB.TxPipeline()
  224. // 减少余额
  225. decrCmd := txn.IncrBy(ctx, key, delta)
  226. if err := decrCmd.Err(); err != nil {
  227. return err // 如果减少失败,则直接返回错误
  228. }
  229. // 重新设置过期时间,使用原来的过期时间
  230. txn.Expire(ctx, key, ttl)
  231. // 执行事务
  232. _, err = txn.Exec(ctx)
  233. return err
  234. }
  235. return nil
  236. }
  237. func RedisHIncrBy(key, field string, delta int64) error {
  238. if DebugEnabled {
  239. SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
  240. }
  241. ttlCmd := RDB.TTL(context.Background(), key)
  242. ttl, err := ttlCmd.Result()
  243. if err != nil && !errors.Is(err, redis.Nil) {
  244. return fmt.Errorf("failed to get TTL: %w", err)
  245. }
  246. if ttl > 0 {
  247. ctx := context.Background()
  248. txn := RDB.TxPipeline()
  249. incrCmd := txn.HIncrBy(ctx, key, field, delta)
  250. if err := incrCmd.Err(); err != nil {
  251. return err
  252. }
  253. txn.Expire(ctx, key, ttl)
  254. _, err = txn.Exec(ctx)
  255. return err
  256. }
  257. return nil
  258. }
  259. func RedisHSetField(key, field string, value interface{}) error {
  260. if DebugEnabled {
  261. SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
  262. }
  263. ttlCmd := RDB.TTL(context.Background(), key)
  264. ttl, err := ttlCmd.Result()
  265. if err != nil && !errors.Is(err, redis.Nil) {
  266. return fmt.Errorf("failed to get TTL: %w", err)
  267. }
  268. if ttl > 0 {
  269. ctx := context.Background()
  270. txn := RDB.TxPipeline()
  271. hsetCmd := txn.HSet(ctx, key, field, value)
  272. if err := hsetCmd.Err(); err != nil {
  273. return err
  274. }
  275. txn.Expire(ctx, key, ttl)
  276. _, err = txn.Exec(ctx)
  277. return err
  278. }
  279. return nil
  280. }