|
|
@@ -2,8 +2,11 @@ package model
|
|
|
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
+ "math/rand"
|
|
|
"one-api/common"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
)
|
|
|
@@ -57,18 +60,15 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
|
|
return group, err
|
|
|
}
|
|
|
|
|
|
-var channelId2channel map[int]*Channel
|
|
|
-var channelSyncLock sync.RWMutex
|
|
|
var group2model2channels map[string]map[string][]*Channel
|
|
|
+var channelSyncLock sync.RWMutex
|
|
|
|
|
|
func InitChannelCache() {
|
|
|
- channelSyncLock.Lock()
|
|
|
- defer channelSyncLock.Unlock()
|
|
|
- channelId2channel = make(map[int]*Channel)
|
|
|
+ newChannelId2channel := make(map[int]*Channel)
|
|
|
var channels []*Channel
|
|
|
DB.Find(&channels)
|
|
|
for _, channel := range channels {
|
|
|
- channelId2channel[channel.Id] = channel
|
|
|
+ newChannelId2channel[channel.Id] = channel
|
|
|
}
|
|
|
var abilities []*Ability
|
|
|
DB.Find(&abilities)
|
|
|
@@ -76,11 +76,26 @@ func InitChannelCache() {
|
|
|
for _, ability := range abilities {
|
|
|
groups[ability.Group] = true
|
|
|
}
|
|
|
- group2model2channels = make(map[string]map[string][]*Channel)
|
|
|
+ newGroup2model2channels := make(map[string]map[string][]*Channel)
|
|
|
for group := range groups {
|
|
|
- group2model2channels[group] = make(map[string][]*Channel)
|
|
|
- // TODO: implement this
|
|
|
+ newGroup2model2channels[group] = make(map[string][]*Channel)
|
|
|
+ }
|
|
|
+ for _, channel := range channels {
|
|
|
+ groups := strings.Split(channel.Group, ",")
|
|
|
+ for _, group := range groups {
|
|
|
+ models := strings.Split(channel.Models, ",")
|
|
|
+ for _, model := range models {
|
|
|
+ if _, ok := newGroup2model2channels[group][model]; !ok {
|
|
|
+ newGroup2model2channels[group][model] = make([]*Channel, 0)
|
|
|
+ }
|
|
|
+ newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
+ channelSyncLock.Lock()
|
|
|
+ group2model2channels = newGroup2model2channels
|
|
|
+ channelSyncLock.Unlock()
|
|
|
+ common.SysLog("Channels synced from database")
|
|
|
}
|
|
|
|
|
|
func SyncChannelCache(frequency int) {
|
|
|
@@ -95,7 +110,12 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|
|
if !common.RedisEnabled {
|
|
|
return GetRandomSatisfiedChannel(group, model)
|
|
|
}
|
|
|
- return GetRandomSatisfiedChannel(group, model)
|
|
|
- // TODO: implement this
|
|
|
- return nil, nil
|
|
|
+ channelSyncLock.RLock()
|
|
|
+ defer channelSyncLock.RUnlock()
|
|
|
+ channels := group2model2channels[group][model]
|
|
|
+ if len(channels) == 0 {
|
|
|
+ return nil, errors.New("channel not found")
|
|
|
+ }
|
|
|
+ idx := rand.Intn(len(channels))
|
|
|
+ return channels[idx], nil
|
|
|
}
|