Просмотр исходного кода

feat: select channel without database (#158)

JustSong 2 лет назад
Родитель
Сommit
ba54c71948
1 измененных файлов с 32 добавлено и 12 удалено
  1. 32 12
      model/cache.go

+ 32 - 12
model/cache.go

@@ -2,8 +2,11 @@ package model
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
+	"math/rand"
 	"one-api/common"
+	"strings"
 	"sync"
 	"time"
 )
@@ -57,18 +60,15 @@ func CacheGetUserGroup(id int) (group string, err error) {
 	return group, err
 }
 
-var channelId2channel map[int]*Channel
-var channelSyncLock sync.RWMutex
 var group2model2channels map[string]map[string][]*Channel
+var channelSyncLock sync.RWMutex
 
 func InitChannelCache() {
-	channelSyncLock.Lock()
-	defer channelSyncLock.Unlock()
-	channelId2channel = make(map[int]*Channel)
+	newChannelId2channel := make(map[int]*Channel)
 	var channels []*Channel
 	DB.Find(&channels)
 	for _, channel := range channels {
-		channelId2channel[channel.Id] = channel
+		newChannelId2channel[channel.Id] = channel
 	}
 	var abilities []*Ability
 	DB.Find(&abilities)
@@ -76,11 +76,26 @@ func InitChannelCache() {
 	for _, ability := range abilities {
 		groups[ability.Group] = true
 	}
-	group2model2channels = make(map[string]map[string][]*Channel)
+	newGroup2model2channels := make(map[string]map[string][]*Channel)
 	for group := range groups {
-		group2model2channels[group] = make(map[string][]*Channel)
-		// TODO: implement this
+		newGroup2model2channels[group] = make(map[string][]*Channel)
+	}
+	for _, channel := range channels {
+		groups := strings.Split(channel.Group, ",")
+		for _, group := range groups {
+			models := strings.Split(channel.Models, ",")
+			for _, model := range models {
+				if _, ok := newGroup2model2channels[group][model]; !ok {
+					newGroup2model2channels[group][model] = make([]*Channel, 0)
+				}
+				newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
+			}
+		}
 	}
+	channelSyncLock.Lock()
+	group2model2channels = newGroup2model2channels
+	channelSyncLock.Unlock()
+	common.SysLog("Channels synced from database")
 }
 
 func SyncChannelCache(frequency int) {
@@ -95,7 +110,12 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	if !common.RedisEnabled {
 		return GetRandomSatisfiedChannel(group, model)
 	}
-	return GetRandomSatisfiedChannel(group, model)
-	// TODO: implement this
-	return nil, nil
+	channelSyncLock.RLock()
+	defer channelSyncLock.RUnlock()
+	channels := group2model2channels[group][model]
+	if len(channels) == 0 {
+		return nil, errors.New("channel not found")
+	}
+	idx := rand.Intn(len(channels))
+	return channels[idx], nil
 }