Bläddra i källkod

feat: add openai video remix endpoint

creamlike1024 3 veckor sedan
förälder
incheckning
d732cdd259
4 ändrade filer med 85 tillägg och 28 borttagningar
  1. 4 0
      middleware/distributor.go
  2. 20 0
      relay/channel/task/sora/adaptor.go
  3. 60 28
      relay/relay_task.go
  4. 1 0
      router/video-router.go

+ 4 - 0
middleware/distributor.go

@@ -181,6 +181,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
+	} else if strings.Contains(c.Request.URL.Path, "/v1/videos/") && strings.HasSuffix(c.Request.URL.Path, "/remix") {
+		relayMode := relayconstant.RelayModeVideoSubmit
+		c.Set("relay_mode", relayMode)
+		shouldSelectChannel = false
 	} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
 		//curl https://api.openai.com/v1/videos \
 		//  -H "Authorization: Bearer $OPENAI_API_KEY" \

+ 20 - 0
relay/channel/task/sora/adaptor.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
@@ -67,11 +68,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 	a.apiKey = info.ApiKey
 }
 
+func validateRemixRequest(c *gin.Context) *dto.TaskError {
+	var req struct {
+		Prompt string `json:"prompt"`
+	}
+	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+		return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+	}
+	if strings.TrimSpace(req.Prompt) == "" {
+		return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
+	}
+	return nil
+}
+
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
+	if info.Action == "remix" {
+		return validateRemixRequest(c)
+	}
 	return relaycommon.ValidateMultipartDirect(c, info)
 }
 
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	if info.Action == "remix" {
+		return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
+	}
 	return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
 }
 

+ 60 - 28
relay/relay_task.go

@@ -32,7 +32,67 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 	if info.TaskRelayInfo == nil {
 		info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
 	}
+	path := c.Request.URL.Path
+	if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
+		info.Action = "remix"
+	}
+
+	// 提取 remix 任务的 video_id
+	if info.Action == "remix" {
+		videoID := c.Param("video_id")
+		if strings.TrimSpace(videoID) == "" {
+			return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
+		}
+		info.OriginTaskID = videoID
+	}
+
 	platform := constant.TaskPlatform(c.GetString("platform"))
+
+	// 获取原始任务信息
+	if info.OriginTaskID != "" {
+		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
+		if err != nil {
+			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+			return
+		}
+		if !exist {
+			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+			return
+		}
+		if info.OriginModelName == "" {
+			if originTask.Properties.OriginModelName != "" {
+				info.OriginModelName = originTask.Properties.OriginModelName
+			} else if originTask.Properties.UpstreamModelName != "" {
+				info.OriginModelName = originTask.Properties.UpstreamModelName
+			} else {
+				var taskData map[string]interface{}
+				_ = json.Unmarshal(originTask.Data, &taskData)
+				if m, ok := taskData["model"].(string); ok && m != "" {
+					info.OriginModelName = m
+					platform = originTask.Platform
+				}
+			}
+		}
+		if originTask.ChannelId != info.ChannelId {
+			channel, err := model.GetChannelById(originTask.ChannelId, true)
+			if err != nil {
+				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+				return
+			}
+			if channel.Status != common.ChannelStatusEnabled {
+				taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
+				return
+			}
+			c.Set("base_url", channel.GetBaseURL())
+			c.Set("channel_id", originTask.ChannelId)
+			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+			info.ChannelBaseUrl = channel.GetBaseURL()
+			info.ChannelId = originTask.ChannelId
+			platform = originTask.Platform
+		}
+
+	}
 	if platform == "" {
 		platform = GetTaskPlatform(c)
 	}
@@ -94,34 +154,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 		return
 	}
 
-	if info.OriginTaskID != "" {
-		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
-		if err != nil {
-			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
-			return
-		}
-		if !exist {
-			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
-			return
-		}
-		if originTask.ChannelId != info.ChannelId {
-			channel, err := model.GetChannelById(originTask.ChannelId, true)
-			if err != nil {
-				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
-				return
-			}
-			if channel.Status != common.ChannelStatusEnabled {
-				return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
-			}
-			c.Set("base_url", channel.GetBaseURL())
-			c.Set("channel_id", originTask.ChannelId)
-			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-
-			info.ChannelBaseUrl = channel.GetBaseURL()
-			info.ChannelId = originTask.ChannelId
-		}
-	}
-
 	// build body
 	requestBody, err := adaptor.BuildRequestBody(c, info)
 	if err != nil {

+ 1 - 0
router/video-router.go

@@ -14,6 +14,7 @@ func SetVideoRouter(router *gin.Engine) {
 		videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
 		videoV1Router.POST("/video/generations", controller.RelayTask)
 		videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+		videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
 	}
 	// openai compatible API video routes
 	// docs: https://platform.openai.com/docs/api-reference/videos/create