| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
- package model
- import (
- "errors"
- "fmt"
- "math/rand"
- "one-api/common"
- "one-api/constant"
- "one-api/setting"
- "sort"
- "strings"
- "sync"
- "time"
- "github.com/gin-gonic/gin"
- )
- var group2model2channels map[string]map[string][]int // enabled channel
- var channelsIDM map[int]*Channel // all channels include disabled
- var channelSyncLock sync.RWMutex
- func InitChannelCache() {
- if !common.MemoryCacheEnabled {
- return
- }
- newChannelId2channel := make(map[int]*Channel)
- var channels []*Channel
- DB.Find(&channels)
- for _, channel := range channels {
- newChannelId2channel[channel.Id] = channel
- }
- var abilities []*Ability
- DB.Find(&abilities)
- groups := make(map[string]bool)
- for _, ability := range abilities {
- groups[ability.Group] = true
- }
- newGroup2model2channels := make(map[string]map[string][]int)
- for group := range groups {
- newGroup2model2channels[group] = make(map[string][]int)
- }
- for _, channel := range channels {
- if channel.Status != common.ChannelStatusEnabled {
- continue // skip disabled channels
- }
- groups := strings.Split(channel.Group, ",")
- for _, group := range groups {
- models := strings.Split(channel.Models, ",")
- for _, model := range models {
- if _, ok := newGroup2model2channels[group][model]; !ok {
- newGroup2model2channels[group][model] = make([]int, 0)
- }
- newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
- }
- }
- }
- // sort by priority
- for group, model2channels := range newGroup2model2channels {
- for model, channels := range model2channels {
- sort.Slice(channels, func(i, j int) bool {
- return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
- })
- newGroup2model2channels[group][model] = channels
- }
- }
- channelSyncLock.Lock()
- group2model2channels = newGroup2model2channels
- channelsIDM = newChannelId2channel
- channelSyncLock.Unlock()
- common.SysLog("channels synced from database")
- }
- func SyncChannelCache(frequency int) {
- for {
- time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing channels from database")
- InitChannelCache()
- }
- }
- func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
- var channel *Channel
- var err error
- selectGroup := group
- // 获取令牌渠道标签
- tokenChannelTag := common.GetContextKeyString(c, constant.ContextKeyTokenChannelTag)
- var channelTag *string = nil
- if tokenChannelTag != "" {
- channelTag = &tokenChannelTag
- }
- if group == "auto" {
- if len(setting.AutoGroups) == 0 {
- return nil, selectGroup, errors.New("auto groups is not enabled")
- }
- for _, autoGroup := range setting.AutoGroups {
- if common.DebugEnabled {
- println("autoGroup:", autoGroup)
- }
- channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
- if channel == nil {
- continue
- } else {
- c.Set("auto_group", autoGroup)
- selectGroup = autoGroup
- if common.DebugEnabled {
- println("selectGroup:", selectGroup)
- }
- break
- }
- }
- } else {
- // 传递channelTag参数给getRandomSatisfiedChannel
- channel, err = getRandomSatisfiedChannelWithTag(group, model, retry, channelTag)
- if err != nil {
- return nil, group, err
- }
- }
- if channel == nil {
- return nil, group, errors.New("channel not found")
- }
- return channel, selectGroup, nil
- }
- // 新增带标签过滤的渠道选择函数
- func getRandomSatisfiedChannelWithTag(group string, model string, retry int, channelTag *string) (*Channel, error) {
- if strings.HasPrefix(model, "gpt-4-gizmo") {
- model = "gpt-4-gizmo-*"
- }
- if strings.HasPrefix(model, "gpt-4o-gizmo") {
- model = "gpt-4o-gizmo-*"
- }
- // if memory cache is disabled, get channel directly from database
- if !common.MemoryCacheEnabled {
- return GetRandomSatisfiedChannel(group, model, retry, channelTag)
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
- channels := group2model2channels[group][model]
- if len(channels) == 0 {
- return nil, errors.New("channel not found")
- }
- // 过滤符合标签要求的渠道
- var filteredChannels []int
- for _, channelId := range channels {
- if channel, ok := channelsIDM[channelId]; ok {
- // 如果没有指定标签要求,则所有渠道都符合
- if channelTag == nil || *channelTag == "" {
- filteredChannels = append(filteredChannels, channelId)
- } else {
- // 如果指定了标签要求,则只选择匹配标签的渠道
- channelTagStr := channel.GetTag()
- if channelTagStr == *channelTag {
- filteredChannels = append(filteredChannels, channelId)
- }
- }
- }
- }
- // 如果没有符合标签要求的渠道,返回错误
- if len(filteredChannels) == 0 {
- if channelTag != nil && *channelTag != "" {
- return nil, fmt.Errorf("没有找到标签为 '%s' 的可用渠道", *channelTag)
- }
- return nil, errors.New("channel not found")
- }
- if len(filteredChannels) == 1 {
- if channel, ok := channelsIDM[filteredChannels[0]]; ok {
- return channel, nil
- }
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", filteredChannels[0])
- }
- uniquePriorities := make(map[int]bool)
- for _, channelId := range filteredChannels {
- if channel, ok := channelsIDM[channelId]; ok {
- uniquePriorities[int(channel.GetPriority())] = true
- } else {
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
- }
- }
- var sortedUniquePriorities []int
- for priority := range uniquePriorities {
- sortedUniquePriorities = append(sortedUniquePriorities, priority)
- }
- sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
- if retry >= len(uniquePriorities) {
- retry = len(uniquePriorities) - 1
- }
- targetPriority := int64(sortedUniquePriorities[retry])
- // get the priority for the given retry number
- var targetChannels []*Channel
- for _, channelId := range filteredChannels {
- if channel, ok := channelsIDM[channelId]; ok {
- if channel.GetPriority() == targetPriority {
- targetChannels = append(targetChannels, channel)
- }
- } else {
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
- }
- }
- // 如果只有一个符合条件的渠道,直接返回
- if len(targetChannels) == 1 {
- return targetChannels[0], nil
- }
- // 平滑系数
- smoothingFactor := 10
- // Calculate the total weight of all channels up to endIdx
- totalWeight := 0
- for _, channel := range targetChannels {
- totalWeight += channel.GetWeight() + smoothingFactor
- }
- // 如果总权重为0,则平均分配权重
- if totalWeight == 0 {
- // 随机选择一个渠道
- randomIndex := common.GetRandomInt(len(targetChannels))
- return targetChannels[randomIndex], nil
- }
- // Generate a random value in the range [0, totalWeight)
- randomWeight := common.GetRandomInt(totalWeight)
- // Find a channel based on its weight
- for _, channel := range targetChannels {
- randomWeight -= channel.GetWeight() + smoothingFactor
- if randomWeight < 0 {
- return channel, nil
- }
- }
- // 如果循环结束还没有找到,则返回第一个渠道(兜底)
- return targetChannels[0], nil
- }
- func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
- if strings.HasPrefix(model, "gpt-4-gizmo") {
- model = "gpt-4-gizmo-*"
- }
- if strings.HasPrefix(model, "gpt-4o-gizmo") {
- model = "gpt-4o-gizmo-*"
- }
- // if memory cache is disabled, get channel directly from database
- if !common.MemoryCacheEnabled {
- return GetRandomSatisfiedChannel(group, model, retry, nil)
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
- channels := group2model2channels[group][model]
- if len(channels) == 0 {
- return nil, errors.New("channel not found")
- }
- if len(channels) == 1 {
- if channel, ok := channelsIDM[channels[0]]; ok {
- return channel, nil
- }
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
- }
- uniquePriorities := make(map[int]bool)
- for _, channelId := range channels {
- if channel, ok := channelsIDM[channelId]; ok {
- uniquePriorities[int(channel.GetPriority())] = true
- } else {
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
- }
- }
- var sortedUniquePriorities []int
- for priority := range uniquePriorities {
- sortedUniquePriorities = append(sortedUniquePriorities, priority)
- }
- sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
- if retry >= len(uniquePriorities) {
- retry = len(uniquePriorities) - 1
- }
- targetPriority := int64(sortedUniquePriorities[retry])
- // get the priority for the given retry number
- var targetChannels []*Channel
- for _, channelId := range channels {
- if channel, ok := channelsIDM[channelId]; ok {
- if channel.GetPriority() == targetPriority {
- targetChannels = append(targetChannels, channel)
- }
- } else {
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
- }
- }
- // 平滑系数
- smoothingFactor := 10
- // Calculate the total weight of all channels up to endIdx
- totalWeight := 0
- for _, channel := range targetChannels {
- totalWeight += channel.GetWeight() + smoothingFactor
- }
- // Generate a random value in the range [0, totalWeight)
- randomWeight := rand.Intn(totalWeight)
- // Find a channel based on its weight
- for _, channel := range targetChannels {
- randomWeight -= channel.GetWeight() + smoothingFactor
- if randomWeight < 0 {
- return channel, nil
- }
- }
- // return null if no channel is not found
- return nil, errors.New("channel not found")
- }
- func CacheGetChannel(id int) (*Channel, error) {
- if !common.MemoryCacheEnabled {
- return GetChannelById(id, true)
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
- c, ok := channelsIDM[id]
- if !ok {
- return nil, fmt.Errorf("渠道# %d,已不存在", id)
- }
- if c.Status != common.ChannelStatusEnabled {
- return nil, fmt.Errorf("渠道# %d,已被禁用", id)
- }
- return c, nil
- }
- func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
- if !common.MemoryCacheEnabled {
- channel, err := GetChannelById(id, true)
- if err != nil {
- return nil, err
- }
- return &channel.ChannelInfo, nil
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
- c, ok := channelsIDM[id]
- if !ok {
- return nil, fmt.Errorf("渠道# %d,已不存在", id)
- }
- if c.Status != common.ChannelStatusEnabled {
- return nil, fmt.Errorf("渠道# %d,已被禁用", id)
- }
- return &c.ChannelInfo, nil
- }
- func CacheUpdateChannelStatus(id int, status int) {
- if !common.MemoryCacheEnabled {
- return
- }
- channelSyncLock.Lock()
- defer channelSyncLock.Unlock()
- if channel, ok := channelsIDM[id]; ok {
- channel.Status = status
- }
- }
- func CacheUpdateChannel(channel *Channel) {
- if !common.MemoryCacheEnabled {
- return
- }
- channelSyncLock.Lock()
- defer channelSyncLock.Unlock()
- if channel == nil {
- return
- }
- println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
- println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
- channelsIDM[channel.Id] = channel
- println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
- }
|