فهرست منبع

✨ feat(channel): implement thread-safe polling

CaIon 5 ماه پیش
والد
کامیت
23e4e25e9a
3فایلهای تغییر یافته به همراه52 افزوده شده و 6 حذف شده
  1. 1 1
      controller/channel-test.go
  2. 40 3
      model/channel.go
  3. 11 2
      model/channel_cache.go

+ 1 - 1
controller/channel-test.go

@@ -312,7 +312,7 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	channel, err := model.GetChannelById(channelId, true)
+	channel, err := model.CacheGetChannel(channelId)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,

+ 40 - 3
model/channel.go

@@ -117,15 +117,19 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
 		// Randomly pick one enabled key
 		return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
 	case constant.MultiKeyModePolling:
+		// Use channel-specific lock to ensure thread-safe polling
+		lock := getChannelPollingLock(channel.Id)
+		lock.Lock()
+		defer lock.Unlock()
+
 		defer func() {
 			if !common.MemoryCacheEnabled {
 				_ = channel.Save()
 			} else {
-				CacheUpdateChannel(channel)
+				// CacheUpdateChannel(channel)
 			}
 		}()
 		// Start from the saved polling index and look for the next enabled key
-		println(channel.ChannelInfo.MultiKeyPollingIndex)
 		start := channel.ChannelInfo.MultiKeyPollingIndex
 		if start < 0 || start >= len(keys) {
 			start = 0
@@ -135,7 +139,6 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
 			if getStatus(idx) == common.ChannelStatusEnabled {
 				// update polling index for next call (point to the next position)
 				channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
-				println(channel.ChannelInfo.MultiKeyPollingIndex)
 				return keys[idx], nil
 			}
 		}
@@ -421,6 +424,40 @@ func (channel *Channel) Delete() error {
 
 var channelStatusLock sync.Mutex
 
+// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
+var channelPollingLocks sync.Map
+
+// getChannelPollingLock returns or creates a mutex for the given channel ID
+func getChannelPollingLock(channelId int) *sync.Mutex {
+	if lock, exists := channelPollingLocks.Load(channelId); exists {
+		return lock.(*sync.Mutex)
+	}
+	// Create new lock for this channel
+	newLock := &sync.Mutex{}
+	actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
+	return actual.(*sync.Mutex)
+}
+
+// CleanupChannelPollingLocks removes locks for channels that no longer exist
+// This is optional and can be called periodically to prevent memory leaks
+func CleanupChannelPollingLocks() {
+	var activeChannelIds []int
+	DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
+
+	activeChannelSet := make(map[int]bool)
+	for _, id := range activeChannelIds {
+		activeChannelSet[id] = true
+	}
+
+	channelPollingLocks.Range(func(key, value interface{}) bool {
+		channelId := key.(int)
+		if !activeChannelSet[channelId] {
+			channelPollingLocks.Delete(channelId)
+		}
+		return true
+	})
+}
+
 func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
 	keys := channel.getKeys()
 	if len(keys) == 0 {

+ 11 - 2
model/channel_cache.go

@@ -128,13 +128,20 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
 	}
 
 	channelSyncLock.RLock()
+	defer channelSyncLock.RUnlock()
 	channels := group2model2channels[group][model]
-	channelSyncLock.RUnlock()
 
 	if len(channels) == 0 {
 		return nil, errors.New("channel not found")
 	}
 
+	if len(channels) == 1 {
+		if channel, ok := channelsIDM[channels[0]]; ok {
+			return channel, nil
+		}
+		return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
+	}
+
 	uniquePriorities := make(map[int]bool)
 	for _, channelId := range channels {
 		if channel, ok := channelsIDM[channelId]; ok {
@@ -196,7 +203,7 @@ func CacheGetChannel(id int) (*Channel, error) {
 
 	c, ok := channelsIDM[id]
 	if !ok {
-		return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
+		return nil, fmt.Errorf("当前渠道# %d,已不存在", id)
 	}
 	return c, nil
 }
@@ -224,5 +231,7 @@ func CacheUpdateChannel(channel *Channel) {
 
 	println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
 
+	println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
 	channelsIDM[channel.Id] = channel
+	println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
 }