| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- package cache
- import (
- "bytes"
- "context"
- "crypto/sha256"
- "encoding/hex"
- "errors"
- "fmt"
- "maps"
- "net/http"
- "strconv"
- "sync"
- "time"
- "github.com/bytedance/sonic"
- "github.com/gin-gonic/gin"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/model"
- "github.com/labring/aiproxy/core/relay/adaptor"
- "github.com/labring/aiproxy/core/relay/meta"
- "github.com/labring/aiproxy/core/relay/plugin"
- "github.com/labring/aiproxy/core/relay/plugin/noop"
- gcache "github.com/patrickmn/go-cache"
- "github.com/redis/go-redis/v9"
- )
// Keys used to stash per-request cache state on the meta object.
const (
	cacheKey   = "cache_key"   // meta key: the computed cache key string
	cacheHit   = "cache_hit"   // meta key: true when served from cache
	cacheValue = "cache_value" // meta key: the cached *Item on a hit
)
// Constants for plugin configuration and wire-level names.
const (
	pluginConfigCacheKey = "cache-config"    // meta key memoizing the parsed plugin Config
	cacheHeader          = "X-Aiproxy-Cache" // default hit/miss response header name
	redisCachePrefix     = "cache:"          // key prefix for entries stored in Redis
)
// Sizing for the pooled response-capture buffers.
const (
	defaultBufferSize = 512 * 1024            // initial capacity of pooled buffers (512 KiB)
	maxBufferSize     = 4 * defaultBufferSize // buffers grown past this are dropped, not pooled
)
// Item represents a cached response: the raw body bytes, the response
// headers to replay, and the usage accounting recorded when the response
// was first produced.
type Item struct {
	Body   []byte              `json:"body"`
	Header map[string][]string `json:"header"`
	Usage  model.Usage         `json:"usage"`
}
// Cache implements caching functionality for AI requests. It embeds
// noop.Noop so unimplemented plugin hooks fall through unchanged.
type Cache struct {
	noop.Noop

	rdb *redis.Client // optional shared backend; nil means memory cache only
}
var (
	_ plugin.Plugin = (*Cache)(nil)
	// In-process fallback cache: 30 second default expiration with a
	// 5 minute cleanup interval (per-item TTLs passed to Set override
	// the default).
	cache = gcache.New(30*time.Second, 5*time.Minute)
	// Buffer pool for response writers capturing bodies for caching.
	bufferPool = sync.Pool{
		New: func() any {
			return bytes.NewBuffer(make([]byte, 0, defaultBufferSize))
		},
	}
)
- // NewCachePlugin creates a new cache plugin
- func NewCachePlugin(rdb *redis.Client) plugin.Plugin {
- return &Cache{rdb: rdb}
- }
- // Cache metadata helpers
- func getCacheKey(meta *meta.Meta) string {
- return meta.GetString(cacheKey)
- }
- func setCacheKey(meta *meta.Meta, key string) {
- meta.Set(cacheKey, key)
- }
- func isCacheHit(meta *meta.Meta) bool {
- return meta.GetBool(cacheHit)
- }
- func getCacheItem(meta *meta.Meta) *Item {
- v, ok := meta.Get(cacheValue)
- if !ok {
- return nil
- }
- item, ok := v.(*Item)
- if !ok {
- panic(fmt.Sprintf("cache item type not match: %T", v))
- }
- return item
- }
- func setCacheHit(meta *meta.Meta, item *Item) {
- meta.Set(cacheHit, true)
- meta.Set(cacheValue, item)
- }
- // Buffer pool helpers
- func getBuffer() *bytes.Buffer {
- v, ok := bufferPool.Get().(*bytes.Buffer)
- if !ok {
- panic(fmt.Sprintf("buffer type error: %T, %v", v, v))
- }
- return v
- }
- func putBuffer(buf *bytes.Buffer) {
- buf.Reset()
- if buf.Cap() > maxBufferSize {
- return
- }
- bufferPool.Put(buf)
- }
- // getPluginConfig retrieves the plugin configuration from metadata
- func getPluginConfig(meta *meta.Meta) (config *Config, err error) {
- v, ok := meta.Get(pluginConfigCacheKey)
- if ok {
- config, ok := v.(*Config)
- if !ok {
- panic(fmt.Sprintf("cache config type not match: %T", v))
- }
- return config, nil
- }
- pluginConfig := Config{}
- if err := meta.ModelConfig.LoadPluginConfig("cache", &pluginConfig); err != nil {
- return nil, err
- }
- meta.Set(pluginConfigCacheKey, &pluginConfig)
- return &pluginConfig, nil
- }
- // Redis cache operations
- func (c *Cache) getFromRedis(ctx context.Context, key string) (*Item, error) {
- if c.rdb == nil {
- return nil, nil
- }
- data, err := c.rdb.Get(ctx, common.RedisKey(redisCachePrefix, key)).Bytes()
- if err != nil {
- if errors.Is(err, redis.Nil) {
- return nil, nil
- }
- return nil, err
- }
- var item Item
- if err := sonic.Unmarshal(data, &item); err != nil {
- return nil, err
- }
- return &item, nil
- }
- func (c *Cache) setToRedis(ctx context.Context, key string, item *Item, ttl time.Duration) error {
- if c.rdb == nil {
- return nil
- }
- data, err := sonic.Marshal(item)
- if err != nil {
- return err
- }
- return c.rdb.Set(ctx, common.RedisKey(redisCachePrefix, key), data, ttl).Err()
- }
- // getFromCache retrieves item from cache (Redis or memory)
- func (c *Cache) getFromCache(ctx context.Context, key string) (*Item, bool) {
- // Try Redis first if available
- if c.rdb != nil {
- item, err := c.getFromRedis(ctx, key)
- if err == nil && item != nil {
- return item, true
- }
- // If Redis fails, fallback to memory cache
- }
- // Try memory cache
- if v, ok := cache.Get(key); ok {
- if item, ok := v.(Item); ok {
- return &item, true
- }
- }
- return nil, false
- }
- // setToCache stores item in cache (Redis and/or memory)
- func (c *Cache) setToCache(ctx context.Context, key string, item Item, ttl time.Duration) {
- // Set to Redis if available
- if c.rdb != nil {
- if err := c.setToRedis(ctx, key, &item, ttl); err == nil {
- // If Redis succeeds, also set to memory cache for faster access
- cache.Set(key, item, ttl)
- return
- }
- // If Redis fails, fallback to memory cache only
- }
- // Set to memory cache
- cache.Set(key, item, ttl)
- }
- // ConvertRequest handles the request conversion phase
- func (c *Cache) ConvertRequest(
- meta *meta.Meta,
- store adaptor.Store,
- req *http.Request,
- do adaptor.ConvertRequest,
- ) (adaptor.ConvertResult, error) {
- pluginConfig, err := getPluginConfig(meta)
- if err != nil {
- return do.ConvertRequest(meta, store, req)
- }
- if !pluginConfig.Enable {
- return do.ConvertRequest(meta, store, req)
- }
- body, err := common.GetRequestBodyReusable(req)
- if err != nil {
- return adaptor.ConvertResult{}, err
- }
- if len(body) == 0 {
- return do.ConvertRequest(meta, store, req)
- }
- // Generate hash as cache key
- hash := sha256.Sum256(body)
- cacheKey := fmt.Sprintf("%d:%s", meta.Mode, hex.EncodeToString(hash[:]))
- setCacheKey(meta, cacheKey)
- // Check cache
- ctx := req.Context()
- if item, ok := c.getFromCache(ctx, cacheKey); ok {
- setCacheHit(meta, item)
- return adaptor.ConvertResult{}, nil
- }
- return do.ConvertRequest(meta, store, req)
- }
- // DoRequest handles the request execution phase
- func (c *Cache) DoRequest(
- meta *meta.Meta,
- store adaptor.Store,
- ctx *gin.Context,
- req *http.Request,
- do adaptor.DoRequest,
- ) (*http.Response, error) {
- if isCacheHit(meta) {
- return &http.Response{}, nil
- }
- return do.DoRequest(meta, store, ctx, req)
- }
// responseWriter wraps gin's ResponseWriter to mirror everything written
// to the client into cacheBody so the response can be cached afterwards.
// Once the captured size would exceed maxSize, capture is abandoned and
// subsequent writes pass straight through.
type responseWriter struct {
	gin.ResponseWriter

	cacheBody *bytes.Buffer // mirror of the bytes sent to the client
	maxSize   int           // capture limit in bytes; <=0 means unlimited
	overflow  bool          // set once the limit is exceeded; capture stops
}
- func (rw *responseWriter) Write(b []byte) (int, error) {
- if rw.overflow {
- return rw.ResponseWriter.Write(b)
- }
- if rw.maxSize > 0 && rw.cacheBody.Len()+len(b) > rw.maxSize {
- rw.overflow = true
- rw.cacheBody.Reset()
- return rw.ResponseWriter.Write(b)
- }
- rw.cacheBody.Write(b)
- return rw.ResponseWriter.Write(b)
- }
- func (rw *responseWriter) WriteString(s string) (int, error) {
- if rw.overflow {
- return rw.ResponseWriter.WriteString(s)
- }
- if rw.maxSize > 0 && rw.cacheBody.Len()+len(s) > rw.maxSize {
- rw.overflow = true
- rw.cacheBody.Reset()
- return rw.ResponseWriter.WriteString(s)
- }
- rw.cacheBody.WriteString(s)
- return rw.ResponseWriter.WriteString(s)
- }
- func (c *Cache) writeCacheHeader(ctx *gin.Context, pluginConfig *Config, value string) {
- if pluginConfig.AddCacheHitHeader {
- header := pluginConfig.CacheHitHeader
- if header == "" {
- header = cacheHeader
- }
- ctx.Header(header, value)
- }
- }
- // DoResponse handles the response processing phase
- func (c *Cache) DoResponse(
- meta *meta.Meta,
- store adaptor.Store,
- ctx *gin.Context,
- resp *http.Response,
- do adaptor.DoResponse,
- ) (usage model.Usage, adapterErr adaptor.Error) {
- pluginConfig, err := getPluginConfig(meta)
- if err != nil {
- return do.DoResponse(meta, store, ctx, resp)
- }
- // Handle cache hit
- if isCacheHit(meta) {
- item := getCacheItem(meta)
- if item == nil {
- return do.DoResponse(meta, store, ctx, resp)
- }
- // Restore headers from cache
- for k, v := range item.Header {
- for _, val := range v {
- ctx.Header(k, val)
- }
- }
- // Override specific headers
- ctx.Header("Content-Type", item.Header["Content-Type"][0])
- ctx.Header("Content-Length", strconv.Itoa(len(item.Body)))
- c.writeCacheHeader(ctx, pluginConfig, "hit")
- _, _ = ctx.Writer.Write(item.Body)
- return item.Usage, nil
- }
- if !pluginConfig.Enable {
- return do.DoResponse(meta, store, ctx, resp)
- }
- c.writeCacheHeader(ctx, pluginConfig, "miss")
- // Set up response capture for caching
- buf := getBuffer()
- defer putBuffer(buf)
- rw := &responseWriter{
- ResponseWriter: ctx.Writer,
- maxSize: pluginConfig.ItemMaxSize,
- cacheBody: buf,
- }
- ctx.Writer = rw
- defer func() {
- ctx.Writer = rw.ResponseWriter
- if adapterErr != nil ||
- rw.overflow ||
- rw.cacheBody.Len() == 0 {
- return
- }
- // Convert http.Header to map[string][]string for JSON serialization
- headerMap := maps.Clone(rw.Header())
- // Store in cache
- item := Item{
- Body: bytes.Clone(rw.cacheBody.Bytes()),
- Header: headerMap,
- Usage: usage,
- }
- ttl := time.Duration(pluginConfig.TTL) * time.Second
- c.setToCache(ctx.Request.Context(), getCacheKey(meta), item, ttl)
- }()
- return do.DoResponse(meta, store, ctx, resp)
- }
|