mcp-mpsc.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "runtime"
  7. "sync"
  8. "time"
  9. "github.com/labring/aiproxy/core/common"
  10. "github.com/redis/go-redis/v9"
  11. )
  // mpsc is a multi-producer, single-consumer mailbox keyed by session id:
// any number of callers may send to an id, and one consumer drains it.
type mpsc interface {
	// send enqueues data for the session identified by id. Implementations
	// may reject the message (for example when a bounded queue is full).
	send(ctx context.Context, id string, data []byte) error
	// recv returns the next message for the session identified by id. The
	// implementations in this file wait until a message arrives, ctx is done,
	// or an error occurs.
	recv(ctx context.Context, id string) ([]byte, error)
}
  17. // Global MPSC instances
  18. var (
  19. memMCPMpsc mpsc = newChannelMCPMpsc()
  20. redisMCPMpsc mpsc
  21. redisMCPMpscOnce = &sync.Once{}
  22. )
  23. func getMCPMpsc() mpsc {
  24. if common.RedisEnabled {
  25. redisMCPMpscOnce.Do(func() {
  26. redisMCPMpsc = newRedisMCPMPSC(common.RDB)
  27. })
  28. return redisMCPMpsc
  29. }
  30. return memMCPMpsc
  31. }
  // channelMCPMpsc is an in-memory, channel-based mpsc implementation. Each
// session id maps to a buffered channel (capacity 10). A background sweeper
// closes and forgets channels that send/recv have not touched for more than
// 15 seconds.
type channelMCPMpsc struct {
	channels     map[string]chan []byte
	lastAccess   map[string]time.Time
	channelMutex sync.RWMutex
}

// newChannelMCPMpsc creates a channel-based mpsc and starts its sweeper.
// The sweeper goroutine has no stop hook and runs for the life of the
// process, so this is meant for long-lived instances.
func newChannelMCPMpsc() *channelMCPMpsc {
	c := &channelMCPMpsc{
		channels:   make(map[string]chan []byte),
		lastAccess: make(map[string]time.Time),
	}
	go c.cleanupExpiredChannels()
	return c
}

// cleanupExpiredChannels runs removeExpired every 15 seconds, forever. An idle
// session is therefore reclaimed roughly 15-30 seconds after its last access.
func (c *channelMCPMpsc) cleanupExpiredChannels() {
	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		c.removeExpired(time.Now())
	}
}

// removeExpired closes and deletes every session whose last access is more
// than 15 seconds before now.
//
// A receiver blocked in recv on a removed channel observes the close and gets
// an error. Closing is safe with respect to senders because send performs its
// channel write while holding channelMutex. Sending on a closed channel would
// panic.
func (c *channelMCPMpsc) removeExpired(now time.Time) {
	c.channelMutex.Lock()
	defer c.channelMutex.Unlock()
	for id, last := range c.lastAccess {
		if now.Sub(last) <= 15*time.Second {
			continue
		}
		if ch, exists := c.channels[id]; exists {
			close(ch)
			delete(c.channels, id)
		}
		delete(c.lastAccess, id)
	}
}

// channelLocked returns the channel for id, creating it if absent, and
// refreshes its last-access time. The caller must hold channelMutex for
// writing.
func (c *channelMCPMpsc) channelLocked(id string) chan []byte {
	ch, exists := c.channels[id]
	if !exists {
		ch = make(chan []byte, 10)
		c.channels[id] = ch
	}
	c.lastAccess[id] = time.Now()
	return ch
}

// getOrCreateChannel returns the channel for id, creating it if needed, and
// marks the session as recently used.
func (c *channelMCPMpsc) getOrCreateChannel(id string) chan []byte {
	// Every call writes lastAccess, so take the write lock once instead of
	// doing an RLock lookup followed by a second Lock.
	c.channelMutex.Lock()
	defer c.channelMutex.Unlock()
	return c.channelLocked(id)
}

// recv waits for the next message for session id.
//
// It returns ctx.Err() if ctx is done first. It returns an error if the
// sweeper closes the session's channel while recv is waiting. Waiting does not
// refresh the last-access time, so a wait with no traffic can end this way.
func (c *channelMCPMpsc) recv(ctx context.Context, id string) ([]byte, error) {
	ch := c.getOrCreateChannel(id)
	select {
	case data, ok := <-ch:
		if !ok {
			return nil, fmt.Errorf("channel closed for session %s", id)
		}
		return data, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// send enqueues data for session id without blocking. It fails if the
// session's buffer is full. It returns ctx.Err() only if ctx is already done
// and the buffer has no room.
func (c *channelMCPMpsc) send(ctx context.Context, id string, data []byte) error {
	// Hold the lock across the channel write so removeExpired cannot close ch
	// between lookup and send. The select has a default case, so it never
	// blocks while the lock is held.
	c.channelMutex.Lock()
	defer c.channelMutex.Unlock()
	ch := c.channelLocked(id)
	select {
	case ch <- data:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	default:
		return fmt.Errorf("channel buffer full for session %s", id)
	}
}
  114. // Redis-based MPSC implementation
  115. type redisMCPMPSC struct {
  116. rdb *redis.Client
  117. }
  118. // newRedisMCPMPSC creates a new Redis MPSC instance
  119. func newRedisMCPMPSC(rdb *redis.Client) *redisMCPMPSC {
  120. return &redisMCPMPSC{rdb: rdb}
  121. }
  122. func (r *redisMCPMPSC) send(ctx context.Context, id string, data []byte) error {
  123. // Set expiration to 15 seconds when sending data
  124. id = common.RedisKey("mcp:mpsc", id)
  125. pipe := r.rdb.Pipeline()
  126. pipe.LPush(ctx, id, data)
  127. pipe.Expire(ctx, id, 15*time.Second)
  128. _, err := pipe.Exec(ctx)
  129. return err
  130. }
  131. func (r *redisMCPMPSC) recv(ctx context.Context, id string) ([]byte, error) {
  132. id = common.RedisKey("mcp:mpsc", id)
  133. for {
  134. select {
  135. case <-ctx.Done():
  136. return nil, ctx.Err()
  137. default:
  138. result, err := r.rdb.BRPop(ctx, time.Second, id).Result()
  139. if err != nil {
  140. if errors.Is(err, redis.Nil) {
  141. runtime.Gosched()
  142. continue
  143. }
  144. return nil, err
  145. }
  146. if len(result) != 2 {
  147. return nil, errors.New("invalid BRPop result")
  148. }
  149. return []byte(result[1]), nil
  150. }
  151. }
  152. }