Browse Source

Merge branch 'feitianbubu-pr/vidu-add-first-end-reference-video'

creamlike1024 3 months ago
parent
commit
41be436c04

+ 4 - 2
constant/task.go

@@ -11,8 +11,10 @@ const (
 	SunoActionMusic  = "MUSIC"
 	SunoActionLyrics = "LYRICS"
 
-	TaskActionGenerate     = "generate"
-	TaskActionTextGenerate = "textGenerate"
+	TaskActionGenerate          = "generate"
+	TaskActionTextGenerate      = "textGenerate"
+	TaskActionFirstTailGenerate = "firstTailGenerate"
+	TaskActionReferenceGenerate = "referenceGenerate"
 )
 
 var SunoModel2Action = map[string]string{

+ 6 - 8
relay/channel/task/vidu/adaptor.go

@@ -80,8 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
-	// Use the unified validation method for TaskSubmitReq with image-based action determination
-	return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
+	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
 }
 
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -112,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
 	switch info.Action {
 	case constant.TaskActionGenerate:
 		path = "/img2video"
+	case constant.TaskActionFirstTailGenerate:
+		path = "/start-end2video"
+	case constant.TaskActionReferenceGenerate:
+		path = "/reference2video"
 	default:
 		path = "/text2video"
 	}
@@ -187,14 +190,9 @@ func (a *TaskAdaptor) GetChannelName() string {
 // ============================
 
 func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
-	var images []string
-	if req.Image != "" {
-		images = []string{req.Image}
-	}
-
 	r := requestPayload{
 		Model:             defaultString(req.Model, "viduq1"),
-		Images:            images,
+		Images:            req.Images,
 		Prompt:            req.Prompt,
 		Duration:          defaultInt(req.Duration, 5),
 		Resolution:        defaultString(req.Size, "1080p"),

+ 10 - 26
relay/common/relay_utils.go

@@ -79,34 +79,18 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
 		req.Images = []string{req.Image}
 	}
 
-	storeTaskRequest(c, info, action, req)
-	return nil
-}
-
-func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
-	hasPrompt, ok := requestObj.(HasPrompt)
-	if !ok {
-		return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
-	}
-
-	if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
-		return taskErr
-	}
-
-	action := constant.TaskActionTextGenerate
-	if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
+	if req.HasImage() {
 		action = constant.TaskActionGenerate
+		if info.ChannelType == constant.ChannelTypeVidu {
+			// vidu 增加 首尾帧生视频和参考图生视频
+			if len(req.Images) == 2 {
+				action = constant.TaskActionFirstTailGenerate
+			} else if len(req.Images) > 2 {
+				action = constant.TaskActionReferenceGenerate
+			}
+		}
 	}
 
-	storeTaskRequest(c, info, action, requestObj)
+	storeTaskRequest(c, info, action, req)
 	return nil
 }
-
-func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
-	var req TaskSubmitReq
-	if err := c.ShouldBindJSON(&req); err != nil {
-		return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
-	}
-
-	return ValidateTaskRequestWithImage(c, info, req)
-}

+ 18 - 3
web/src/components/table/task-logs/TaskLogsColumnDefs.jsx

@@ -35,8 +35,9 @@ import {
   Sparkles,
 } from 'lucide-react';
 import {
-  TASK_ACTION_GENERATE,
-  TASK_ACTION_TEXT_GENERATE,
+  TASK_ACTION_FIRST_TAIL_GENERATE,
+  TASK_ACTION_GENERATE, TASK_ACTION_REFERENCE_GENERATE,
+  TASK_ACTION_TEXT_GENERATE
 } from '../../../constants/common.constant';
 import { CHANNEL_OPTIONS } from '../../../constants/channel.constants';
 
@@ -111,6 +112,18 @@ const renderType = (type, t) => {
           {t('文生视频')}
         </Tag>
       );
+    case TASK_ACTION_FIRST_TAIL_GENERATE:
+      return (
+        <Tag color='blue' shape='circle' prefixIcon={<Sparkles size={14} />}>
+          {t('首尾生视频')}
+        </Tag>
+      );
+    case TASK_ACTION_REFERENCE_GENERATE:
+      return (
+        <Tag color='blue' shape='circle' prefixIcon={<Sparkles size={14} />}>
+          {t('参照生视频')}
+        </Tag>
+      );
     default:
       return (
         <Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
@@ -343,7 +356,9 @@ export const getTaskLogsColumns = ({
         // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
         const isVideoTask =
           record.action === TASK_ACTION_GENERATE ||
-          record.action === TASK_ACTION_TEXT_GENERATE;
+          record.action === TASK_ACTION_TEXT_GENERATE ||
+          record.action === TASK_ACTION_FIRST_TAIL_GENERATE ||
+          record.action === TASK_ACTION_REFERENCE_GENERATE;
         const isSuccess = record.status === 'SUCCESS';
         const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
         if (isSuccess && isVideoTask && isUrl) {

+ 2 - 0
web/src/constants/common.constant.js

@@ -40,3 +40,5 @@ export const API_ENDPOINTS = [
 
 export const TASK_ACTION_GENERATE = 'generate';
 export const TASK_ACTION_TEXT_GENERATE = 'textGenerate';
+export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate';
+export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate';