2
0

cache.go 8.7 KB


  1. package cache
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "encoding/hex"
  7. "errors"
  8. "fmt"
  9. "maps"
  10. "net/http"
  11. "strconv"
  12. "sync"
  13. "time"
  14. "github.com/bytedance/sonic"
  15. "github.com/gin-gonic/gin"
  16. "github.com/labring/aiproxy/core/common"
  17. "github.com/labring/aiproxy/core/model"
  18. "github.com/labring/aiproxy/core/relay/adaptor"
  19. "github.com/labring/aiproxy/core/relay/meta"
  20. "github.com/labring/aiproxy/core/relay/plugin"
  21. "github.com/labring/aiproxy/core/relay/plugin/noop"
  22. gcache "github.com/patrickmn/go-cache"
  23. "github.com/redis/go-redis/v9"
  24. )
  25. // Constants for cache metadata keys
  26. const (
  27. cacheKey = "cache_key"
  28. cacheHit = "cache_hit"
  29. cacheValue = "cache_value"
  30. )
  31. // Constants for plugin configuration
  32. const (
  33. pluginConfigCacheKey = "cache-config"
  34. cacheHeader = "X-Aiproxy-Cache"
  35. redisCachePrefix = "cache:"
  36. )
  37. // Buffer size constants
  38. const (
  39. defaultBufferSize = 512 * 1024
  40. maxBufferSize = 4 * defaultBufferSize
  41. )
// Item represents a cached response. It is stored by value in the in-memory
// cache and as JSON in Redis.
type Item struct {
	// Body is the response body captured while it was written to the client.
	Body []byte `json:"body"`
	// Header holds the response headers captured together with the body.
	Header map[string][]string `json:"header"`
	// Usage is the usage reported for the original response; it is returned
	// again when the item is served from cache.
	Usage model.Usage `json:"usage"`
}
// Cache implements caching functionality for AI requests.
// It embeds noop.Noop and defines ConvertRequest, DoRequest and DoResponse
// itself.
type Cache struct {
	noop.Noop
	// rdb is the optional Redis client; when it is nil only the in-memory
	// cache is used.
	rdb *redis.Client
}
  53. var (
  54. _ plugin.Plugin = (*Cache)(nil)
  55. // Global cache instance with 5 minute default TTL and 10 minute cleanup interval
  56. cache = gcache.New(30*time.Second, 5*time.Minute)
  57. // Buffer pool for response writers
  58. bufferPool = sync.Pool{
  59. New: func() any {
  60. return bytes.NewBuffer(make([]byte, 0, defaultBufferSize))
  61. },
  62. }
  63. )
  64. // NewCachePlugin creates a new cache plugin
  65. func NewCachePlugin(rdb *redis.Client) plugin.Plugin {
  66. return &Cache{rdb: rdb}
  67. }
  68. // Cache metadata helpers
  69. func getCacheKey(meta *meta.Meta) string {
  70. return meta.GetString(cacheKey)
  71. }
// setCacheKey records key in the request metadata under cacheKey so later
// phases can read it back with getCacheKey.
func setCacheKey(meta *meta.Meta, key string) {
	meta.Set(cacheKey, key)
}
  75. func isCacheHit(meta *meta.Meta) bool {
  76. return meta.GetBool(cacheHit)
  77. }
  78. func getCacheItem(meta *meta.Meta) *Item {
  79. v, ok := meta.Get(cacheValue)
  80. if !ok {
  81. return nil
  82. }
  83. item, ok := v.(*Item)
  84. if !ok {
  85. panic(fmt.Sprintf("cache item type not match: %T", v))
  86. }
  87. return item
  88. }
  89. func setCacheHit(meta *meta.Meta, item *Item) {
  90. meta.Set(cacheHit, true)
  91. meta.Set(cacheValue, item)
  92. }
  93. // Buffer pool helpers
  94. func getBuffer() *bytes.Buffer {
  95. v, ok := bufferPool.Get().(*bytes.Buffer)
  96. if !ok {
  97. panic(fmt.Sprintf("buffer type error: %T, %v", v, v))
  98. }
  99. return v
  100. }
  101. func putBuffer(buf *bytes.Buffer) {
  102. buf.Reset()
  103. if buf.Cap() > maxBufferSize {
  104. return
  105. }
  106. bufferPool.Put(buf)
  107. }
  108. // getPluginConfig retrieves the plugin configuration from metadata
  109. func getPluginConfig(meta *meta.Meta) (config *Config, err error) {
  110. v, ok := meta.Get(pluginConfigCacheKey)
  111. if ok {
  112. config, ok := v.(*Config)
  113. if !ok {
  114. panic(fmt.Sprintf("cache config type not match: %T", v))
  115. }
  116. return config, nil
  117. }
  118. pluginConfig := Config{}
  119. if err := meta.ModelConfig.LoadPluginConfig("cache", &pluginConfig); err != nil {
  120. return nil, err
  121. }
  122. meta.Set(pluginConfigCacheKey, &pluginConfig)
  123. return &pluginConfig, nil
  124. }
  125. // Redis cache operations
  126. func (c *Cache) getFromRedis(ctx context.Context, key string) (*Item, error) {
  127. if c.rdb == nil {
  128. return nil, nil
  129. }
  130. data, err := c.rdb.Get(ctx, common.RedisKey(redisCachePrefix, key)).Bytes()
  131. if err != nil {
  132. if errors.Is(err, redis.Nil) {
  133. return nil, nil
  134. }
  135. return nil, err
  136. }
  137. var item Item
  138. if err := sonic.Unmarshal(data, &item); err != nil {
  139. return nil, err
  140. }
  141. return &item, nil
  142. }
  143. func (c *Cache) setToRedis(ctx context.Context, key string, item *Item, ttl time.Duration) error {
  144. if c.rdb == nil {
  145. return nil
  146. }
  147. data, err := sonic.Marshal(item)
  148. if err != nil {
  149. return err
  150. }
  151. return c.rdb.Set(ctx, common.RedisKey(redisCachePrefix, key), data, ttl).Err()
  152. }
  153. // getFromCache retrieves item from cache (Redis or memory)
  154. func (c *Cache) getFromCache(ctx context.Context, key string) (*Item, bool) {
  155. // Try Redis first if available
  156. if c.rdb != nil {
  157. item, err := c.getFromRedis(ctx, key)
  158. if err == nil && item != nil {
  159. return item, true
  160. }
  161. // If Redis fails, fallback to memory cache
  162. }
  163. // Try memory cache
  164. if v, ok := cache.Get(key); ok {
  165. if item, ok := v.(Item); ok {
  166. return &item, true
  167. }
  168. }
  169. return nil, false
  170. }
  171. // setToCache stores item in cache (Redis and/or memory)
  172. func (c *Cache) setToCache(ctx context.Context, key string, item Item, ttl time.Duration) {
  173. // Set to Redis if available
  174. if c.rdb != nil {
  175. if err := c.setToRedis(ctx, key, &item, ttl); err == nil {
  176. // If Redis succeeds, also set to memory cache for faster access
  177. cache.Set(key, item, ttl)
  178. return
  179. }
  180. // If Redis fails, fallback to memory cache only
  181. }
  182. // Set to memory cache
  183. cache.Set(key, item, ttl)
  184. }
  185. // ConvertRequest handles the request conversion phase
  186. func (c *Cache) ConvertRequest(
  187. meta *meta.Meta,
  188. store adaptor.Store,
  189. req *http.Request,
  190. do adaptor.ConvertRequest,
  191. ) (adaptor.ConvertResult, error) {
  192. pluginConfig, err := getPluginConfig(meta)
  193. if err != nil {
  194. return do.ConvertRequest(meta, store, req)
  195. }
  196. if !pluginConfig.Enable {
  197. return do.ConvertRequest(meta, store, req)
  198. }
  199. body, err := common.GetRequestBodyReusable(req)
  200. if err != nil {
  201. return adaptor.ConvertResult{}, err
  202. }
  203. if len(body) == 0 {
  204. return do.ConvertRequest(meta, store, req)
  205. }
  206. // Generate hash as cache key
  207. hash := sha256.Sum256(body)
  208. cacheKey := fmt.Sprintf("%d:%s", meta.Mode, hex.EncodeToString(hash[:]))
  209. setCacheKey(meta, cacheKey)
  210. // Check cache
  211. ctx := req.Context()
  212. if item, ok := c.getFromCache(ctx, cacheKey); ok {
  213. setCacheHit(meta, item)
  214. return adaptor.ConvertResult{}, nil
  215. }
  216. return do.ConvertRequest(meta, store, req)
  217. }
  218. // DoRequest handles the request execution phase
  219. func (c *Cache) DoRequest(
  220. meta *meta.Meta,
  221. store adaptor.Store,
  222. ctx *gin.Context,
  223. req *http.Request,
  224. do adaptor.DoRequest,
  225. ) (*http.Response, error) {
  226. if isCacheHit(meta) {
  227. return &http.Response{}, nil
  228. }
  229. return do.DoRequest(meta, store, ctx, req)
  230. }
// Custom response writer to capture response for caching.
// Writes are forwarded to the embedded gin.ResponseWriter and, until the size
// limit is exceeded, copied into cacheBody as well.
type responseWriter struct {
	gin.ResponseWriter
	// cacheBody accumulates the bytes written so far.
	cacheBody *bytes.Buffer
	// maxSize is the capture limit in bytes; values <= 0 disable the limit.
	maxSize int
	// overflow is set once a write would exceed maxSize; capturing stops and
	// cacheBody is emptied.
	overflow bool
}
  238. func (rw *responseWriter) Write(b []byte) (int, error) {
  239. if rw.overflow {
  240. return rw.ResponseWriter.Write(b)
  241. }
  242. if rw.maxSize > 0 && rw.cacheBody.Len()+len(b) > rw.maxSize {
  243. rw.overflow = true
  244. rw.cacheBody.Reset()
  245. return rw.ResponseWriter.Write(b)
  246. }
  247. rw.cacheBody.Write(b)
  248. return rw.ResponseWriter.Write(b)
  249. }
  250. func (rw *responseWriter) WriteString(s string) (int, error) {
  251. if rw.overflow {
  252. return rw.ResponseWriter.WriteString(s)
  253. }
  254. if rw.maxSize > 0 && rw.cacheBody.Len()+len(s) > rw.maxSize {
  255. rw.overflow = true
  256. rw.cacheBody.Reset()
  257. return rw.ResponseWriter.WriteString(s)
  258. }
  259. rw.cacheBody.WriteString(s)
  260. return rw.ResponseWriter.WriteString(s)
  261. }
  262. func (c *Cache) writeCacheHeader(ctx *gin.Context, pluginConfig *Config, value string) {
  263. if pluginConfig.AddCacheHitHeader {
  264. header := pluginConfig.CacheHitHeader
  265. if header == "" {
  266. header = cacheHeader
  267. }
  268. ctx.Header(header, value)
  269. }
  270. }
  271. // DoResponse handles the response processing phase
  272. func (c *Cache) DoResponse(
  273. meta *meta.Meta,
  274. store adaptor.Store,
  275. ctx *gin.Context,
  276. resp *http.Response,
  277. do adaptor.DoResponse,
  278. ) (usage model.Usage, adapterErr adaptor.Error) {
  279. pluginConfig, err := getPluginConfig(meta)
  280. if err != nil {
  281. return do.DoResponse(meta, store, ctx, resp)
  282. }
  283. // Handle cache hit
  284. if isCacheHit(meta) {
  285. item := getCacheItem(meta)
  286. if item == nil {
  287. return do.DoResponse(meta, store, ctx, resp)
  288. }
  289. // Restore headers from cache
  290. for k, v := range item.Header {
  291. for _, val := range v {
  292. ctx.Header(k, val)
  293. }
  294. }
  295. // Override specific headers
  296. ctx.Header("Content-Type", item.Header["Content-Type"][0])
  297. ctx.Header("Content-Length", strconv.Itoa(len(item.Body)))
  298. c.writeCacheHeader(ctx, pluginConfig, "hit")
  299. _, _ = ctx.Writer.Write(item.Body)
  300. return item.Usage, nil
  301. }
  302. if !pluginConfig.Enable {
  303. return do.DoResponse(meta, store, ctx, resp)
  304. }
  305. c.writeCacheHeader(ctx, pluginConfig, "miss")
  306. // Set up response capture for caching
  307. buf := getBuffer()
  308. defer putBuffer(buf)
  309. rw := &responseWriter{
  310. ResponseWriter: ctx.Writer,
  311. maxSize: pluginConfig.ItemMaxSize,
  312. cacheBody: buf,
  313. }
  314. ctx.Writer = rw
  315. defer func() {
  316. ctx.Writer = rw.ResponseWriter
  317. if adapterErr != nil ||
  318. rw.overflow ||
  319. rw.cacheBody.Len() == 0 {
  320. return
  321. }
  322. // Convert http.Header to map[string][]string for JSON serialization
  323. headerMap := maps.Clone(rw.Header())
  324. // Store in cache
  325. item := Item{
  326. Body: bytes.Clone(rw.cacheBody.Bytes()),
  327. Header: headerMap,
  328. Usage: usage,
  329. }
  330. ttl := time.Duration(pluginConfig.TTL) * time.Second
  331. c.setToCache(ctx.Request.Context(), getCacheKey(meta), item, ttl)
  332. }()
  333. return do.DoResponse(meta, store, ctx, resp)
  334. }