|
|
@@ -64,8 +64,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA
|
|
|
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|
|
|
|
|
requestId := c.GetString(common.RequestIdKey)
|
|
|
- group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
|
|
- originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
|
|
+ //group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
|
|
+ //originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
|
|
|
|
|
var (
|
|
|
newAPIError *types.NewAPIError
|
|
|
@@ -158,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|
|
}()
|
|
|
|
|
|
for i := 0; i <= common.RetryTimes; i++ {
|
|
|
- channel, err := getChannel(c, group, originalModel, i)
|
|
|
+ channel, err := getChannel(c, relayInfo, i)
|
|
|
if err != nil {
|
|
|
logger.LogError(c, err.Error())
|
|
|
newAPIError = err
|
|
|
@@ -211,7 +211,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
|
|
c.Set("use_channel", useChannel)
|
|
|
}
|
|
|
|
|
|
-func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
|
|
+func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) {
|
|
|
if retryCount == 0 {
|
|
|
autoBan := c.GetBool("auto_ban")
|
|
|
autoBanInt := 1
|
|
|
@@ -225,14 +225,18 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|
|
AutoBan: &autoBanInt,
|
|
|
}, nil
|
|
|
}
|
|
|
- channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
|
|
+ channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount)
|
|
|
+
|
|
|
+ info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
|
|
|
+
|
|
|
if err != nil {
|
|
|
- return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
|
|
+ return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, info.OriginModelName, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
|
|
}
|
|
|
if channel == nil {
|
|
|
- return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
|
|
+ return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
|
|
}
|
|
|
- newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
|
+
|
|
|
+ newAPIError := middleware.SetupContextForSelectedChannel(c, channel, info.OriginModelName)
|
|
|
if newAPIError != nil {
|
|
|
return nil, newAPIError
|
|
|
}
|
|
|
@@ -392,8 +396,6 @@ func RelayNotFound(c *gin.Context) {
|
|
|
func RelayTask(c *gin.Context) {
|
|
|
retryTimes := common.RetryTimes
|
|
|
channelId := c.GetInt("channel_id")
|
|
|
- group := c.GetString("group")
|
|
|
- originalModel := c.GetString("original_model")
|
|
|
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
|
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
|
|
if err != nil {
|
|
|
@@ -404,7 +406,7 @@ func RelayTask(c *gin.Context) {
|
|
|
retryTimes = 0
|
|
|
}
|
|
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
|
|
- channel, newAPIError := getChannel(c, group, originalModel, i)
|
|
|
+ channel, newAPIError := getChannel(c, relayInfo, i)
|
|
|
if newAPIError != nil {
|
|
|
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
|
|
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|