Browse Source

Merge pull request #2225 from feitianbubu/pr/add-hailuo-video

新增MiniMax海螺视频模型支持
IcedTangerine 1 month ago
parent
commit
e7bef097dd

+ 1 - 0
controller/task_video.go

@@ -52,6 +52,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
 	info.ChannelMeta = &relaycommon.ChannelMeta{
 		ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
 	}
+	info.ApiKey = cacheGetChannel.Key
 	adaptor.Init(info)
 	for _, taskId := range taskIds {
 		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {

+ 11 - 0
model/task.go

@@ -429,3 +429,14 @@ func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
 	_ = query.Count(&total).Error
 	return total
 }
+func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo {
+	openAIVideo := dto.NewOpenAIVideo()
+	openAIVideo.ID = t.TaskID
+	openAIVideo.Status = t.Status.ToVideoStatus()
+	openAIVideo.Model = t.Properties.OriginModelName
+	openAIVideo.SetProgressStr(t.Progress)
+	openAIVideo.CreatedAt = t.CreatedAt
+	openAIVideo.CompletedAt = t.UpdatedAt
+	openAIVideo.SetMetadata("url", t.FailReason)
+	return openAIVideo
+}

+ 297 - 0
relay/channel/task/hailuo/adaptor.go

@@ -0,0 +1,297 @@
+package hailuo
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/model"
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/relay/channel"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/service"
+)
+
+// https://platform.minimaxi.com/docs/api-reference/video-generation-intro
+type TaskAdaptor struct {
+	ChannelType int
+	apiKey      string
+	baseURL     string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
+	a.ChannelType = info.ChannelType
+	a.baseURL = info.ChannelBaseUrl
+	a.apiKey = info.ApiKey
+}
+
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
+	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
+}
+
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s%s", a.baseURL, TextToVideoEndpoint), nil
+}
+
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+a.apiKey)
+	return nil
+}
+
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
+	v, exists := c.Get("task_request")
+	if !exists {
+		return nil, fmt.Errorf("request not found in context")
+	}
+	req, ok := v.(relaycommon.TaskSubmitReq)
+	if !ok {
+		return nil, fmt.Errorf("invalid request type in context")
+	}
+
+	body, err := a.convertToRequestPayload(&req)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request payload failed")
+	}
+
+	data, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+
+	return bytes.NewReader(data), nil
+}
+
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	_ = resp.Body.Close()
+
+	var hResp VideoResponse
+	if err := json.Unmarshal(responseBody, &hResp); err != nil {
+		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	if hResp.BaseResp.StatusCode != StatusSuccess {
+		taskErr = service.TaskErrorWrapper(
+			fmt.Errorf("hailuo api error: %s", hResp.BaseResp.StatusMsg),
+			strconv.Itoa(hResp.BaseResp.StatusCode),
+			http.StatusBadRequest,
+		)
+		return
+	}
+
+	ov := dto.NewOpenAIVideo()
+	ov.ID = hResp.TaskID
+	ov.TaskID = hResp.TaskID
+	ov.CreatedAt = time.Now().Unix()
+	ov.Model = info.OriginModelName
+
+	c.JSON(http.StatusOK, ov)
+	return hResp.TaskID, responseBody, nil
+}
+
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	taskID, ok := body["task_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid task_id")
+	}
+
+	uri := fmt.Sprintf("%s%s?task_id=%s", baseUrl, QueryTaskEndpoint, taskID)
+
+	req, err := http.NewRequest(http.MethodGet, uri, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+key)
+
+	return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+	return ChannelName
+}
+
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
+	modelConfig := GetModelConfig(req.Model)
+	duration := DefaultDuration
+	if req.Duration > 0 {
+		duration = req.Duration
+	}
+	resolution := modelConfig.DefaultResolution
+	if req.Size != "" {
+		resolution = a.parseResolutionFromSize(req.Size, modelConfig)
+	}
+
+	videoRequest := &VideoRequest{
+		Model:      req.Model,
+		Prompt:     req.Prompt,
+		Duration:   &duration,
+		Resolution: resolution,
+	}
+	if err := req.UnmarshalMetadata(&videoRequest); err != nil {
+		return nil, errors.Wrap(err, "unmarshal metadata to video request failed")
+	}
+
+	return videoRequest, nil
+}
+
+func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConfig) string {
+	switch {
+	case strings.Contains(size, "1080"):
+		return Resolution1080P
+	case strings.Contains(size, "768"):
+		return Resolution768P
+	case strings.Contains(size, "720"):
+		return Resolution720P
+	case strings.Contains(size, "512"):
+		return Resolution512P
+	default:
+		return modelConfig.DefaultResolution
+	}
+}
+
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+	resTask := QueryTaskResponse{}
+	if err := json.Unmarshal(respBody, &resTask); err != nil {
+		return nil, errors.Wrap(err, "unmarshal task result failed")
+	}
+
+	taskResult := relaycommon.TaskInfo{}
+
+	if resTask.BaseResp.StatusCode == StatusSuccess {
+		taskResult.Code = 0
+	} else {
+		taskResult.Code = resTask.BaseResp.StatusCode
+		taskResult.Reason = resTask.BaseResp.StatusMsg
+		taskResult.Status = model.TaskStatusFailure
+		taskResult.Progress = "100%"
+	}
+
+	switch resTask.Status {
+	case TaskStatusPreparing, TaskStatusQueueing, TaskStatusProcessing:
+		taskResult.Status = model.TaskStatusInProgress
+		taskResult.Progress = "30%"
+		if resTask.Status == TaskStatusProcessing {
+			taskResult.Progress = "50%"
+		}
+	case TaskStatusSuccess:
+		taskResult.Status = model.TaskStatusSuccess
+		taskResult.Progress = "100%"
+		taskResult.Url = a.buildVideoURL(resTask.TaskID, resTask.FileID)
+	case TaskStatusFailed:
+		taskResult.Status = model.TaskStatusFailure
+		taskResult.Progress = "100%"
+		if taskResult.Reason == "" {
+			taskResult.Reason = "task failed"
+		}
+	default:
+		taskResult.Status = model.TaskStatusInProgress
+		taskResult.Progress = "30%"
+	}
+
+	return &taskResult, nil
+}
+
+func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
+	var hailuoResp QueryTaskResponse
+	if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
+		return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
+	}
+
+	openAIVideo := originTask.ToOpenAIVideo()
+	if hailuoResp.BaseResp.StatusCode != StatusSuccess {
+		openAIVideo.Error = &dto.OpenAIVideoError{
+			Message: hailuoResp.BaseResp.StatusMsg,
+			Code:    strconv.Itoa(hailuoResp.BaseResp.StatusCode),
+		}
+	}
+
+	jsonData, err := common.Marshal(openAIVideo)
+	if err != nil {
+		return nil, errors.Wrap(err, "marshal openai video failed")
+	}
+
+	return jsonData, nil
+}
+
+func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
+	if a.apiKey == "" || a.baseURL == "" {
+		return ""
+	}
+
+	url := fmt.Sprintf("%s/v1/files/retrieve?file_id=%s", a.baseURL, fileID)
+
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		return ""
+	}
+
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+a.apiKey)
+
+	resp, err := service.GetHttpClient().Do(req)
+	if err != nil {
+		return ""
+	}
+	defer resp.Body.Close()
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return ""
+	}
+
+	var retrieveResp RetrieveFileResponse
+	if err := json.Unmarshal(responseBody, &retrieveResp); err != nil {
+		return ""
+	}
+
+	if retrieveResp.BaseResp.StatusCode != StatusSuccess {
+		return ""
+	}
+
+	return retrieveResp.File.DownloadURL
+}
+
+func contains(slice []string, item string) bool {
+	for _, s := range slice {
+		if s == item {
+			return true
+		}
+	}
+	return false
+}
+
+func containsInt(slice []int, item int) bool {
+	for _, s := range slice {
+		if s == item {
+			return true
+		}
+	}
+	return false
+}

+ 52 - 0
relay/channel/task/hailuo/constants.go

@@ -0,0 +1,52 @@
+package hailuo
+
+const (
+	ChannelName = "hailuo-video"
+)
+
+var ModelList = []string{
+	"MiniMax-Hailuo-2.3",
+	"MiniMax-Hailuo-2.3-Fast",
+	"MiniMax-Hailuo-02",
+	"T2V-01-Director",
+	"T2V-01",
+	"I2V-01-Director",
+	"I2V-01-live",
+	"I2V-01",
+	"S2V-01",
+}
+
+const (
+	TextToVideoEndpoint = "/v1/video_generation"
+	QueryTaskEndpoint   = "/v1/query/video_generation"
+)
+
+const (
+	StatusSuccess    = 0
+	StatusRateLimit  = 1002
+	StatusAuthFailed = 1004
+	StatusNoBalance  = 1008
+	StatusSensitive  = 1026
+	StatusParamError = 2013
+	StatusInvalidKey = 2049
+)
+
+const (
+	TaskStatusPreparing  = "Preparing"
+	TaskStatusQueueing   = "Queueing"
+	TaskStatusProcessing = "Processing"
+	TaskStatusSuccess    = "Success"
+	TaskStatusFailed     = "Fail"
+)
+
+const (
+	Resolution512P  = "512P"
+	Resolution720P  = "720P"
+	Resolution768P  = "768P"
+	Resolution1080P = "1080P"
+)
+
+const (
+	DefaultDuration   = 6
+	DefaultResolution = Resolution720P
+)

+ 170 - 0
relay/channel/task/hailuo/models.go

@@ -0,0 +1,170 @@
+package hailuo
+
+type SubjectReference struct {
+	Type  string   `json:"type"`  // Subject type, currently only supports "character"
+	Image []string `json:"image"` // Array of subject reference images (currently only supports single image)
+}
+
+type VideoRequest struct {
+	Model            string             `json:"model"`
+	Prompt           string             `json:"prompt,omitempty"`
+	PromptOptimizer  *bool              `json:"prompt_optimizer,omitempty"`
+	FastPretreatment *bool              `json:"fast_pretreatment,omitempty"`
+	Duration         *int               `json:"duration,omitempty"`
+	Resolution       string             `json:"resolution,omitempty"`
+	CallbackURL      string             `json:"callback_url,omitempty"`
+	AigcWatermark    *bool              `json:"aigc_watermark,omitempty"`
+	FirstFrameImage  string             `json:"first_frame_image,omitempty"` // For image-to-video and start-end-to-video
+	LastFrameImage   string             `json:"last_frame_image,omitempty"`  // For start-end-to-video
+	SubjectReference []SubjectReference `json:"subject_reference,omitempty"` // For subject-reference-to-video
+}
+
+type VideoResponse struct {
+	TaskID   string   `json:"task_id"`
+	BaseResp BaseResp `json:"base_resp"`
+}
+
+type BaseResp struct {
+	StatusCode int    `json:"status_code"`
+	StatusMsg  string `json:"status_msg"`
+}
+
+type QueryTaskRequest struct {
+	TaskID string `json:"task_id"`
+}
+
+type QueryTaskResponse struct {
+	TaskID      string   `json:"task_id"`
+	Status      string   `json:"status"`
+	FileID      string   `json:"file_id,omitempty"`
+	VideoWidth  int      `json:"video_width,omitempty"`
+	VideoHeight int      `json:"video_height,omitempty"`
+	BaseResp    BaseResp `json:"base_resp"`
+}
+
+type ErrorInfo struct {
+	StatusCode int    `json:"status_code"`
+	StatusMsg  string `json:"status_msg"`
+}
+
+type TaskStatusInfo struct {
+	TaskID    string `json:"task_id"`
+	Status    string `json:"status"`
+	FileID    string `json:"file_id,omitempty"`
+	VideoURL  string `json:"video_url,omitempty"`
+	ErrorCode int    `json:"error_code,omitempty"`
+	ErrorMsg  string `json:"error_msg,omitempty"`
+}
+
+type ModelConfig struct {
+	Name                 string
+	DefaultResolution    string
+	SupportedDurations   []int
+	SupportedResolutions []string
+	HasPromptOptimizer   bool
+	HasFastPretreatment  bool
+}
+
+type RetrieveFileResponse struct {
+	File     FileObject `json:"file"`
+	BaseResp BaseResp   `json:"base_resp"`
+}
+
+type FileObject struct {
+	FileID      int64  `json:"file_id"`
+	Bytes       int64  `json:"bytes"`
+	CreatedAt   int64  `json:"created_at"`
+	Filename    string `json:"filename"`
+	Purpose     string `json:"purpose"`
+	DownloadURL string `json:"download_url"`
+}
+
+func GetModelConfig(model string) ModelConfig {
+	configs := map[string]ModelConfig{
+		"MiniMax-Hailuo-2.3": {
+			Name:                 "MiniMax-Hailuo-2.3",
+			DefaultResolution:    Resolution768P,
+			SupportedDurations:   []int{6, 10},
+			SupportedResolutions: []string{Resolution768P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  true,
+		},
+		"MiniMax-Hailuo-2.3-Fast": {
+			Name:                 "MiniMax-Hailuo-2.3-Fast",
+			DefaultResolution:    Resolution768P,
+			SupportedDurations:   []int{6, 10},
+			SupportedResolutions: []string{Resolution768P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  true,
+		},
+		"MiniMax-Hailuo-02": {
+			Name:                 "MiniMax-Hailuo-02",
+			DefaultResolution:    Resolution768P,
+			SupportedDurations:   []int{6, 10},
+			SupportedResolutions: []string{Resolution512P, Resolution768P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  true,
+		},
+		"T2V-01-Director": {
+			Name:                 "T2V-01-Director",
+			DefaultResolution:    Resolution768P,
+			SupportedDurations:   []int{6},
+			SupportedResolutions: []string{Resolution768P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  false,
+		},
+		"T2V-01": {
+			Name:                 "T2V-01",
+			DefaultResolution:    Resolution720P,
+			SupportedDurations:   []int{6},
+			SupportedResolutions: []string{Resolution720P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  false,
+		},
+		"I2V-01-Director": {
+			Name:                 "I2V-01-Director",
+			DefaultResolution:    Resolution720P,
+			SupportedDurations:   []int{6},
+			SupportedResolutions: []string{Resolution720P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  false,
+		},
+		"I2V-01-live": {
+			Name:                 "I2V-01-live",
+			DefaultResolution:    Resolution720P,
+			SupportedDurations:   []int{6},
+			SupportedResolutions: []string{Resolution720P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  false,
+		},
+		"I2V-01": {
+			Name:                 "I2V-01",
+			DefaultResolution:    Resolution720P,
+			SupportedDurations:   []int{6},
+			SupportedResolutions: []string{Resolution720P, Resolution1080P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  false,
+		},
+		"S2V-01": {
+			Name:                 "S2V-01",
+			DefaultResolution:    Resolution720P,
+			SupportedDurations:   []int{6},
+			SupportedResolutions: []string{Resolution720P},
+			HasPromptOptimizer:   true,
+			HasFastPretreatment:  false,
+		},
+	}
+
+	if config, exists := configs[model]; exists {
+		return config
+	}
+
+	return ModelConfig{
+		Name:                 model,
+		DefaultResolution:    DefaultResolution,
+		SupportedDurations:   []int{6},
+		SupportedResolutions: []string{DefaultResolution},
+		HasPromptOptimizer:   true,
+		HasFastPretreatment:  false,
+	}
+}

+ 16 - 2
relay/common/relay_info.go

@@ -498,11 +498,11 @@ type TaskSubmitReq struct {
 	Metadata       map[string]interface{} `json:"metadata,omitempty"`
 }
 
-func (t TaskSubmitReq) GetPrompt() string {
+func (t *TaskSubmitReq) GetPrompt() string {
 	return t.Prompt
 }
 
-func (t TaskSubmitReq) HasImage() bool {
+func (t *TaskSubmitReq) HasImage() bool {
 	return len(t.Images) > 0
 }
 
@@ -537,6 +537,20 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
 
 	return nil
 }
+func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
+	metadata := t.Metadata
+	if metadata != nil {
+		metadataBytes, err := json.Marshal(metadata)
+		if err != nil {
+			return fmt.Errorf("marshal metadata failed: %w", err)
+		}
+		err = json.Unmarshal(metadataBytes, v)
+		if err != nil {
+			return fmt.Errorf("unmarshal metadata to target failed: %w", err)
+		}
+	}
+	return nil
+}
 
 type TaskInfo struct {
 	Code             int    `json:"code"`

+ 3 - 0
relay/relay_adaptor.go

@@ -32,6 +32,7 @@ import (
 	taskali "github.com/QuantumNous/new-api/relay/channel/task/ali"
 	taskdoubao "github.com/QuantumNous/new-api/relay/channel/task/doubao"
 	taskGemini "github.com/QuantumNous/new-api/relay/channel/task/gemini"
+	"github.com/QuantumNous/new-api/relay/channel/task/hailuo"
 	taskjimeng "github.com/QuantumNous/new-api/relay/channel/task/jimeng"
 	"github.com/QuantumNous/new-api/relay/channel/task/kling"
 	tasksora "github.com/QuantumNous/new-api/relay/channel/task/sora"
@@ -153,6 +154,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
 			return &tasksora.TaskAdaptor{}
 		case constant.ChannelTypeGemini:
 			return &taskGemini.TaskAdaptor{}
+		case constant.ChannelTypeMiniMax:
+			return &hailuo.TaskAdaptor{}
 		}
 	}
 	return nil