|
|
@@ -268,6 +268,7 @@ func GetChannelFromHeader(
|
|
|
mc *model.ModelCaches,
|
|
|
availableSet []string,
|
|
|
model string,
|
|
|
+ m mode.Mode,
|
|
|
) (*model.Channel, error) {
|
|
|
channelIDInt, err := strconv.ParseInt(header, 10, 64)
|
|
|
if err != nil {
|
|
|
@@ -279,6 +280,13 @@ func GetChannelFromHeader(
|
|
|
if len(enabledChannels) > 0 {
|
|
|
for _, channel := range enabledChannels {
|
|
|
if int64(channel.ID) == channelIDInt {
|
|
|
+ a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
|
|
|
+ }
|
|
|
+ if !a.SupportMode(m) {
|
|
|
+ return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
|
|
|
+ }
|
|
|
return channel, nil
|
|
|
}
|
|
|
}
|
|
|
@@ -288,6 +296,13 @@ func GetChannelFromHeader(
|
|
|
if len(disabledChannels) > 0 {
|
|
|
for _, channel := range disabledChannels {
|
|
|
if int64(channel.ID) == channelIDInt {
|
|
|
+ a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
|
|
|
+ }
|
|
|
+ if !a.SupportMode(m) {
|
|
|
+ return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
|
|
|
+ }
|
|
|
return channel, nil
|
|
|
}
|
|
|
}
|
|
|
@@ -316,6 +331,16 @@ func GetChannelFromRequest(
|
|
|
if len(enabledChannels) > 0 {
|
|
|
for _, channel := range enabledChannels {
|
|
|
if channel.ID == channelID {
|
|
|
+ a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
|
|
|
+ }
|
|
|
+ if !a.SupportMode(m) {
|
|
|
+ return nil, fmt.Errorf(
|
|
|
+ "channel %d not supported by adaptor",
|
|
|
+ channel.ID,
|
|
|
+ )
|
|
|
+ }
|
|
|
return channel, nil
|
|
|
}
|
|
|
}
|
|
|
@@ -332,6 +357,16 @@ func GetChannelFromRequest(
|
|
|
if len(enabledChannels) > 0 {
|
|
|
for _, channel := range enabledChannels {
|
|
|
if channel.ID == channelID {
|
|
|
+ a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
|
|
|
+ }
|
|
|
+ if !a.SupportMode(m) {
|
|
|
+ return nil, fmt.Errorf(
|
|
|
+ "channel %d not supported by adaptor",
|
|
|
+ channel.ID,
|
|
|
+ )
|
|
|
+ }
|
|
|
return channel, nil
|
|
|
}
|
|
|
}
|
|
|
@@ -454,12 +489,23 @@ func notifyChannelIssue(
|
|
|
)
|
|
|
}
|
|
|
|
|
|
-func filterChannels(channels []*model.Channel, ignoreChannel ...int64) []*model.Channel {
|
|
|
+func filterChannels(
|
|
|
+ channels []*model.Channel,
|
|
|
+ mode mode.Mode,
|
|
|
+ ignoreChannel ...int64,
|
|
|
+) []*model.Channel {
|
|
|
filtered := make([]*model.Channel, 0)
|
|
|
for _, channel := range channels {
|
|
|
if channel.Status != model.ChannelStatusEnabled {
|
|
|
continue
|
|
|
}
|
|
|
+ a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if !a.SupportMode(mode) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
if slices.Contains(ignoreChannel, int64(channel.ID)) {
|
|
|
continue
|
|
|
}
|
|
|
@@ -477,6 +523,7 @@ func GetRandomChannel(
|
|
|
mc *model.ModelCaches,
|
|
|
availableSet []string,
|
|
|
modelName string,
|
|
|
+ mode mode.Mode,
|
|
|
errorRates map[int64]float64,
|
|
|
ignoreChannel ...int64,
|
|
|
) (*model.Channel, []*model.Channel, error) {
|
|
|
@@ -484,12 +531,26 @@ func GetRandomChannel(
|
|
|
if len(availableSet) != 0 {
|
|
|
for _, set := range availableSet {
|
|
|
for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
|
|
|
+ a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if !a.SupportMode(mode) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
channelMap[channel.ID] = channel
|
|
|
}
|
|
|
}
|
|
|
} else {
|
|
|
for _, sets := range mc.EnabledModel2ChannelsBySet {
|
|
|
for _, channel := range sets[modelName] {
|
|
|
+ a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if !a.SupportMode(mode) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
channelMap[channel.ID] = channel
|
|
|
}
|
|
|
}
|
|
|
@@ -498,7 +559,7 @@ func GetRandomChannel(
|
|
|
for _, channel := range channelMap {
|
|
|
migratedChannels = append(migratedChannels, channel)
|
|
|
}
|
|
|
- channel, err := getRandomChannel(migratedChannels, errorRates, ignoreChannel...)
|
|
|
+ channel, err := getRandomChannel(migratedChannels, mode, errorRates, ignoreChannel...)
|
|
|
return channel, migratedChannels, err
|
|
|
}
|
|
|
|
|
|
@@ -512,10 +573,9 @@ func getPriority(channel *model.Channel, errorRate float64) int32 {
|
|
|
return int32(float64(priority) / errorRate)
|
|
|
}
|
|
|
|
|
|
-//
|
|
|
-
|
|
|
func getRandomChannel(
|
|
|
channels []*model.Channel,
|
|
|
+ mode mode.Mode,
|
|
|
errorRates map[int64]float64,
|
|
|
ignoreChannel ...int64,
|
|
|
) (*model.Channel, error) {
|
|
|
@@ -523,7 +583,7 @@ func getRandomChannel(
|
|
|
return nil, ErrChannelsNotFound
|
|
|
}
|
|
|
|
|
|
- channels = filterChannels(channels, ignoreChannel...)
|
|
|
+ channels = filterChannels(channels, mode, ignoreChannel...)
|
|
|
if len(channels) == 0 {
|
|
|
return nil, ErrChannelsExhausted
|
|
|
}
|
|
|
@@ -559,6 +619,7 @@ func getChannelWithFallback(
|
|
|
cache *model.ModelCaches,
|
|
|
availableSet []string,
|
|
|
modelName string,
|
|
|
+ mode mode.Mode,
|
|
|
errorRates map[int64]float64,
|
|
|
ignoreChannelIDs ...int64,
|
|
|
) (*model.Channel, []*model.Channel, error) {
|
|
|
@@ -566,6 +627,7 @@ func getChannelWithFallback(
|
|
|
cache,
|
|
|
availableSet,
|
|
|
modelName,
|
|
|
+ mode,
|
|
|
errorRates,
|
|
|
ignoreChannelIDs...)
|
|
|
if err == nil {
|
|
|
@@ -574,7 +636,13 @@ func getChannelWithFallback(
|
|
|
if !errors.Is(err, ErrChannelsExhausted) {
|
|
|
return nil, migratedChannels, err
|
|
|
}
|
|
|
- channel, migratedChannels, err = GetRandomChannel(cache, availableSet, modelName, errorRates)
|
|
|
+ channel, migratedChannels, err = GetRandomChannel(
|
|
|
+ cache,
|
|
|
+ availableSet,
|
|
|
+ modelName,
|
|
|
+ mode,
|
|
|
+ errorRates,
|
|
|
+ )
|
|
|
return channel, migratedChannels, err
|
|
|
}
|
|
|
|
|
|
@@ -776,6 +844,7 @@ func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialC
|
|
|
middleware.GetModelCaches(c),
|
|
|
availableSet,
|
|
|
modelName,
|
|
|
+ m,
|
|
|
)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
@@ -815,6 +884,7 @@ func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialC
|
|
|
mc,
|
|
|
availableSet,
|
|
|
modelName,
|
|
|
+ m,
|
|
|
errorRates,
|
|
|
ids...)
|
|
|
if err != nil {
|
|
|
@@ -844,7 +914,13 @@ func getWebSearchChannel(c *gin.Context, modelName string) (*model.Channel, erro
|
|
|
log.Errorf("get channel model error rates failed: %+v", err)
|
|
|
}
|
|
|
|
|
|
- channel, _, err := getChannelWithFallback(mc, nil, modelName, errorRates, ids...)
|
|
|
+ channel, _, err := getChannelWithFallback(
|
|
|
+ mc,
|
|
|
+ nil,
|
|
|
+ modelName,
|
|
|
+ mode.ChatCompletions,
|
|
|
+ errorRates,
|
|
|
+ ids...)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -1006,6 +1082,7 @@ func getRetryChannel(state *retryState) (*model.Channel, error) {
|
|
|
|
|
|
newChannel, err := getRandomChannel(
|
|
|
state.migratedChannels,
|
|
|
+ state.meta.Mode,
|
|
|
state.errorRates,
|
|
|
state.ignoreChannelIDs...)
|
|
|
if err != nil {
|