cache.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. package cache
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "encoding/hex"
  7. "errors"
  8. "fmt"
  9. "net/http"
  10. "strconv"
  11. "sync"
  12. "time"
  13. "github.com/bytedance/sonic"
  14. "github.com/gin-gonic/gin"
  15. "github.com/labring/aiproxy/core/common"
  16. "github.com/labring/aiproxy/core/model"
  17. "github.com/labring/aiproxy/core/relay/adaptor"
  18. "github.com/labring/aiproxy/core/relay/meta"
  19. "github.com/labring/aiproxy/core/relay/plugin"
  20. "github.com/labring/aiproxy/core/relay/plugin/noop"
  21. gcache "github.com/patrickmn/go-cache"
  22. "github.com/redis/go-redis/v9"
  23. )
  const (
  	// Keys under which per-request cache state is stored on meta.Meta.
  	cacheKey   = "cache_key"   // computed cache key for this request
  	cacheHit   = "cache_hit"   // set to true when a cached item was found
  	cacheValue = "cache_value" // the *Item found on a hit

  	// Names used for configuration lookup and on the wire.
  	pluginConfigCacheKey = "cache-config"    // meta key memoizing the loaded *Config
  	cacheHeader          = "X-Aiproxy-Cache" // default hit/miss header when none is configured
  	redisCachePrefix     = "cache:"          // namespace prepended to every Redis key

  	// Sizing for the pooled response-capture buffers.
  	defaultBufferSize = 512 * 1024            // initial capacity of a pooled buffer (512 KiB)
  	maxBufferSize     = 4 * defaultBufferSize // buffers grown beyond this are not returned to the pool
  )
  41. // Item represents a cached response
  42. type Item struct {
  43. Body []byte `json:"body"`
  44. Header map[string][]string `json:"header"`
  45. Usage *model.Usage `json:"usage"`
  46. }
  47. // Cache implements caching functionality for AI requests
  48. type Cache struct {
  49. noop.Noop
  50. rdb *redis.Client
  51. }
  52. var (
  53. _ plugin.Plugin = (*Cache)(nil)
  54. // Global cache instance with 5 minute default TTL and 10 minute cleanup interval
  55. cache = gcache.New(30*time.Second, 5*time.Minute)
  56. // Buffer pool for response writers
  57. bufferPool = sync.Pool{
  58. New: func() any {
  59. return bytes.NewBuffer(make([]byte, 0, defaultBufferSize))
  60. },
  61. }
  62. )
  63. // NewCachePlugin creates a new cache plugin
  64. func NewCachePlugin(rdb *redis.Client) plugin.Plugin {
  65. return &Cache{rdb: rdb}
  66. }
  67. // Cache metadata helpers
  68. func getCacheKey(meta *meta.Meta) string {
  69. return meta.GetString(cacheKey)
  70. }
  71. func setCacheKey(meta *meta.Meta, key string) {
  72. meta.Set(cacheKey, key)
  73. }
  74. func isCacheHit(meta *meta.Meta) bool {
  75. return meta.GetBool(cacheHit)
  76. }
  77. func getCacheItem(meta *meta.Meta) *Item {
  78. v, ok := meta.Get(cacheValue)
  79. if !ok {
  80. return nil
  81. }
  82. item, ok := v.(*Item)
  83. if !ok {
  84. panic(fmt.Sprintf("cache item type not match: %T", v))
  85. }
  86. return item
  87. }
  88. func setCacheHit(meta *meta.Meta, item *Item) {
  89. meta.Set(cacheHit, true)
  90. meta.Set(cacheValue, item)
  91. }
  92. // Buffer pool helpers
  93. func getBuffer() *bytes.Buffer {
  94. return bufferPool.Get().(*bytes.Buffer)
  95. }
  96. func putBuffer(buf *bytes.Buffer) {
  97. buf.Reset()
  98. if buf.Cap() > maxBufferSize {
  99. return
  100. }
  101. bufferPool.Put(buf)
  102. }
  103. // getPluginConfig retrieves the plugin configuration from metadata
  104. func getPluginConfig(meta *meta.Meta) (config *Config, err error) {
  105. v, ok := meta.Get(pluginConfigCacheKey)
  106. if ok {
  107. config, ok := v.(*Config)
  108. if !ok {
  109. panic(fmt.Sprintf("cache config type not match: %T", v))
  110. }
  111. return config, nil
  112. }
  113. pluginConfig := Config{}
  114. if err := meta.ModelConfig.LoadPluginConfig("cache", &pluginConfig); err != nil {
  115. return nil, err
  116. }
  117. meta.Set(pluginConfigCacheKey, &pluginConfig)
  118. return &pluginConfig, nil
  119. }
  120. // Redis cache operations
  121. func (c *Cache) getFromRedis(ctx context.Context, key string) (*Item, error) {
  122. if c.rdb == nil {
  123. return nil, nil
  124. }
  125. data, err := c.rdb.Get(ctx, redisCachePrefix+key).Bytes()
  126. if err != nil {
  127. if errors.Is(err, redis.Nil) {
  128. return nil, nil
  129. }
  130. return nil, err
  131. }
  132. var item Item
  133. if err := sonic.Unmarshal(data, &item); err != nil {
  134. return nil, err
  135. }
  136. return &item, nil
  137. }
  138. func (c *Cache) setToRedis(ctx context.Context, key string, item *Item, ttl time.Duration) error {
  139. if c.rdb == nil {
  140. return nil
  141. }
  142. data, err := sonic.Marshal(item)
  143. if err != nil {
  144. return err
  145. }
  146. return c.rdb.Set(ctx, redisCachePrefix+key, data, ttl).Err()
  147. }
  148. // getFromCache retrieves item from cache (Redis or memory)
  149. func (c *Cache) getFromCache(ctx context.Context, key string) (*Item, bool) {
  150. // Try Redis first if available
  151. if c.rdb != nil {
  152. item, err := c.getFromRedis(ctx, key)
  153. if err == nil && item != nil {
  154. return item, true
  155. }
  156. // If Redis fails, fallback to memory cache
  157. }
  158. // Try memory cache
  159. if v, ok := cache.Get(key); ok {
  160. if item, ok := v.(Item); ok {
  161. return &item, true
  162. }
  163. }
  164. return nil, false
  165. }
  166. // setToCache stores item in cache (Redis and/or memory)
  167. func (c *Cache) setToCache(ctx context.Context, key string, item Item, ttl time.Duration) {
  168. // Set to Redis if available
  169. if c.rdb != nil {
  170. if err := c.setToRedis(ctx, key, &item, ttl); err == nil {
  171. // If Redis succeeds, also set to memory cache for faster access
  172. cache.Set(key, item, ttl)
  173. return
  174. }
  175. // If Redis fails, fallback to memory cache only
  176. }
  177. // Set to memory cache
  178. cache.Set(key, item, ttl)
  179. }
  180. // ConvertRequest handles the request conversion phase
  181. func (c *Cache) ConvertRequest(meta *meta.Meta, req *http.Request, do adaptor.ConvertRequest) (*adaptor.ConvertRequestResult, error) {
  182. pluginConfig, err := getPluginConfig(meta)
  183. if err != nil {
  184. return do.ConvertRequest(meta, req)
  185. }
  186. if !pluginConfig.EnablePlugin {
  187. return do.ConvertRequest(meta, req)
  188. }
  189. body, err := common.GetRequestBody(req)
  190. if err != nil {
  191. return nil, err
  192. }
  193. if len(body) == 0 {
  194. return do.ConvertRequest(meta, req)
  195. }
  196. // Generate hash as cache key
  197. hash := sha256.Sum256(body)
  198. cacheKey := fmt.Sprintf("%d:%s", meta.Mode, hex.EncodeToString(hash[:]))
  199. setCacheKey(meta, cacheKey)
  200. // Check cache
  201. ctx := req.Context()
  202. if item, ok := c.getFromCache(ctx, cacheKey); ok {
  203. setCacheHit(meta, item)
  204. return &adaptor.ConvertRequestResult{}, nil
  205. }
  206. return do.ConvertRequest(meta, req)
  207. }
  208. // DoRequest handles the request execution phase
  209. func (c *Cache) DoRequest(meta *meta.Meta, ctx *gin.Context, req *http.Request, do adaptor.DoRequest) (*http.Response, error) {
  210. if isCacheHit(meta) {
  211. return &http.Response{}, nil
  212. }
  213. return do.DoRequest(meta, ctx, req)
  214. }
  215. // Custom response writer to capture response for caching
  216. type responseWriter struct {
  217. gin.ResponseWriter
  218. cacheBody *bytes.Buffer
  219. maxSize int
  220. overflow bool
  221. }
  222. func (rw *responseWriter) Write(b []byte) (int, error) {
  223. if rw.overflow {
  224. return rw.ResponseWriter.Write(b)
  225. }
  226. if rw.maxSize > 0 && rw.cacheBody.Len()+len(b) > rw.maxSize {
  227. rw.overflow = true
  228. rw.cacheBody.Reset()
  229. return rw.ResponseWriter.Write(b)
  230. }
  231. rw.cacheBody.Write(b)
  232. return rw.ResponseWriter.Write(b)
  233. }
  234. func (rw *responseWriter) WriteString(s string) (int, error) {
  235. if rw.overflow {
  236. return rw.ResponseWriter.WriteString(s)
  237. }
  238. if rw.maxSize > 0 && rw.cacheBody.Len()+len(s) > rw.maxSize {
  239. rw.overflow = true
  240. rw.cacheBody.Reset()
  241. return rw.ResponseWriter.WriteString(s)
  242. }
  243. rw.cacheBody.WriteString(s)
  244. return rw.ResponseWriter.WriteString(s)
  245. }
  246. func (c *Cache) writeCacheHeader(ctx *gin.Context, pluginConfig *Config, value string) {
  247. if pluginConfig.AddCacheHitHeader {
  248. header := pluginConfig.CacheHitHeader
  249. if header == "" {
  250. header = cacheHeader
  251. }
  252. ctx.Header(header, value)
  253. }
  254. }
  255. // DoResponse handles the response processing phase
  256. func (c *Cache) DoResponse(meta *meta.Meta, ctx *gin.Context, resp *http.Response, do adaptor.DoResponse) (usage *model.Usage, adapterErr adaptor.Error) {
  257. pluginConfig, err := getPluginConfig(meta)
  258. if err != nil {
  259. return do.DoResponse(meta, ctx, resp)
  260. }
  261. // Handle cache hit
  262. if isCacheHit(meta) {
  263. item := getCacheItem(meta)
  264. if item == nil {
  265. return do.DoResponse(meta, ctx, resp)
  266. }
  267. // Restore headers from cache
  268. for k, v := range item.Header {
  269. for _, val := range v {
  270. ctx.Header(k, val)
  271. }
  272. }
  273. // Override specific headers
  274. ctx.Header("Content-Type", item.Header["Content-Type"][0])
  275. ctx.Header("Content-Length", strconv.Itoa(len(item.Body)))
  276. c.writeCacheHeader(ctx, pluginConfig, "hit")
  277. ctx.Status(http.StatusOK)
  278. _, _ = ctx.Writer.Write(item.Body)
  279. return item.Usage, nil
  280. }
  281. if !pluginConfig.EnablePlugin {
  282. return do.DoResponse(meta, ctx, resp)
  283. }
  284. c.writeCacheHeader(ctx, pluginConfig, "miss")
  285. // Set up response capture for caching
  286. buf := getBuffer()
  287. defer putBuffer(buf)
  288. rw := &responseWriter{
  289. ResponseWriter: ctx.Writer,
  290. maxSize: pluginConfig.ItemMaxSize,
  291. cacheBody: buf,
  292. }
  293. ctx.Writer = rw
  294. defer func() {
  295. ctx.Writer = rw.ResponseWriter
  296. if adapterErr != nil ||
  297. rw.overflow ||
  298. rw.cacheBody.Len() == 0 {
  299. return
  300. }
  301. // Convert http.Header to map[string][]string for JSON serialization
  302. headerMap := make(map[string][]string)
  303. for k, v := range rw.Header() {
  304. headerMap[k] = v
  305. }
  306. // Store in cache
  307. item := Item{
  308. Body: bytes.Clone(rw.cacheBody.Bytes()),
  309. Header: headerMap,
  310. Usage: usage,
  311. }
  312. ttl := time.Duration(pluginConfig.TTL) * time.Second
  313. c.setToCache(ctx.Request.Context(), getCacheKey(meta), item, ttl)
  314. }()
  315. return do.DoResponse(meta, ctx, resp)
  316. }