Просмотр исходного кода

fix: check pinned channel and ignore some record mode (#329)

zijiren 4 месяцев назад
Родитель
Сommit
78be20edca
2 измененных файлов с 44 добавлено и 54 удалено
  1. 5 1
      core/common/consume/consume.go
  2. 39 53
      core/controller/relay-channel.go

+ 5 - 1
core/common/consume/consume.go

@@ -117,7 +117,11 @@ func Consume(
 func checkNeedRecordConsume(code int, meta *meta.Meta) bool {
 	switch meta.Mode {
 	case mode.VideoGenerationsGetJobs,
-		mode.VideoGenerationsContent:
+		mode.VideoGenerationsContent,
+		mode.ResponsesGet,
+		mode.ResponsesDelete,
+		mode.ResponsesCancel,
+		mode.ResponsesInputItems:
 		return code != http.StatusOK
 	default:
 		return true

+ 39 - 53
core/controller/relay-channel.go

@@ -74,6 +74,20 @@ func GetChannelFromHeader(
 	return nil, fmt.Errorf("channel %d not found for model `%s`", channelIDInt, model)
 }
 
+func needPinChannel(m mode.Mode) bool {
+	switch m {
+	case mode.VideoGenerationsGetJobs,
+		mode.VideoGenerationsContent,
+		mode.ResponsesGet,
+		mode.ResponsesDelete,
+		mode.ResponsesCancel,
+		mode.ResponsesInputItems:
+		return true
+	default:
+		return false
+	}
+}
+
 func GetChannelFromRequest(
 	c *gin.Context,
 	mc *model.ModelCaches,
@@ -81,69 +95,41 @@ func GetChannelFromRequest(
 	modelName string,
 	m mode.Mode,
 ) (*model.Channel, error) {
-	switch m {
-	case mode.VideoGenerationsGetJobs,
-		mode.VideoGenerationsContent:
-		channelID := middleware.GetChannelID(c)
-		if channelID == 0 {
-			return nil, errors.New("channel id is required")
+	channelID := middleware.GetChannelID(c)
+	if channelID == 0 {
+		if needPinChannel(m) {
+			return nil, fmt.Errorf("%s need pinned channel", m)
 		}
+		return nil, nil
+	}
 
-		for _, set := range availableSet {
-			enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
-			if len(enabledChannels) > 0 {
-				for _, channel := range enabledChannels {
-					if channel.ID == channelID {
-						a, ok := adaptors.GetAdaptor(channel.Type)
-						if !ok {
-							return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
-						}
-
-						if !a.SupportMode(m) {
-							return nil, fmt.Errorf(
-								"channel %d not supported by adaptor",
-								channel.ID,
-							)
-						}
-
-						return channel, nil
+	for _, set := range availableSet {
+		enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
+		if len(enabledChannels) > 0 {
+			for _, channel := range enabledChannels {
+				if channel.ID == channelID {
+					a, ok := adaptors.GetAdaptor(channel.Type)
+					if !ok {
+						return nil, fmt.Errorf(
+							"adaptor not found for pinned channel %d",
+							channel.ID,
+						)
 					}
-				}
-			}
-		}
-
-		return nil, fmt.Errorf("channel %d not found for model `%s`", channelID, modelName)
-	default:
-		channelID := middleware.GetChannelID(c)
-		if channelID == 0 {
-			return nil, nil
-		}
 
-		for _, set := range availableSet {
-			enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
-			if len(enabledChannels) > 0 {
-				for _, channel := range enabledChannels {
-					if channel.ID == channelID {
-						a, ok := adaptors.GetAdaptor(channel.Type)
-						if !ok {
-							return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
-						}
-
-						if !a.SupportMode(m) {
-							return nil, fmt.Errorf(
-								"channel %d not supported by adaptor",
-								channel.ID,
-							)
-						}
-
-						return channel, nil
+					if !a.SupportMode(m) {
+						return nil, fmt.Errorf(
+							"pinned channel %d not supported by adaptor",
+							channel.ID,
+						)
 					}
+
+					return channel, nil
 				}
 			}
 		}
 	}
 
-	return nil, nil
+	return nil, fmt.Errorf("pinned channel %d not found for model `%s`", channelID, modelName)
 }
 
 var (