Просмотр исходного кода

Merge pull request #2632 from feitianbubu/pr/add-doubao-video-1.5

Seefs 1 неделя назад
Родитель
Сommit
ca11fcbabd
4 измененных файлов с 132 добавлено и 18 удалено
  1. 55 0
      dto/values.go
  2. 75 17
      relay/channel/task/doubao/adaptor.go
  3. 1 0
      relay/channel/task/doubao/constants.go
  4. 1 1
      relay/relay_adaptor.go

+ 55 - 0
dto/values.go

@@ -0,0 +1,55 @@
+package dto
+
+import (
+	"encoding/json"
+	"strconv"
+)
+
+type IntValue int
+
+func (i *IntValue) UnmarshalJSON(b []byte) error {
+	var n int
+	if err := json.Unmarshal(b, &n); err == nil {
+		*i = IntValue(n)
+		return nil
+	}
+	var s string
+	if err := json.Unmarshal(b, &s); err != nil {
+		return err
+	}
+	v, err := strconv.Atoi(s)
+	if err != nil {
+		return err
+	}
+	*i = IntValue(v)
+	return nil
+}
+
+func (i IntValue) MarshalJSON() ([]byte, error) {
+	return json.Marshal(int(i))
+}
+
+type BoolValue bool
+
+func (b *BoolValue) UnmarshalJSON(data []byte) error {
+	var boolean bool
+	if err := json.Unmarshal(data, &boolean); err == nil {
+		*b = BoolValue(boolean)
+		return nil
+	}
+	var str string
+	if err := json.Unmarshal(data, &str); err != nil {
+		return err
+	}
+	if str == "true" {
+		*b = BoolValue(true)
+	} else if str == "false" {
+		*b = BoolValue(false)
+	} else {
+		return json.Unmarshal(data, &boolean)
+	}
+	return nil
+}
+func (b BoolValue) MarshalJSON() ([]byte, error) {
+	return json.Marshal(bool(b))
+}

+ 75 - 17
relay/channel/task/doubao/adaptor.go

@@ -6,6 +6,9 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
 
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
@@ -23,18 +26,36 @@ import (
 // ============================
 
 type ContentItem struct {
-	Type     string    `json:"type"`                // "text" or "image_url"
-	Text     string    `json:"text,omitempty"`      // for text type
-	ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
+	Type     string          `json:"type"`                // "text", "image_url" or "video"
+	Text     string          `json:"text,omitempty"`      // for text type
+	ImageURL *ImageURL       `json:"image_url,omitempty"` // for image_url type
+	Video    *VideoReference `json:"video,omitempty"`     // for video (sample) type
 }
 
 type ImageURL struct {
 	URL string `json:"url"`
 }
 
+type VideoReference struct {
+	URL string `json:"url"` // Draft video URL
+}
+
 type requestPayload struct {
-	Model   string        `json:"model"`
-	Content []ContentItem `json:"content"`
+	Model                 string         `json:"model"`
+	Content               []ContentItem  `json:"content"`
+	CallbackURL           string         `json:"callback_url,omitempty"`
+	ReturnLastFrame       *dto.BoolValue `json:"return_last_frame,omitempty"`
+	ServiceTier           string         `json:"service_tier,omitempty"`
+	ExecutionExpiresAfter dto.IntValue   `json:"execution_expires_after,omitempty"`
+	GenerateAudio         *dto.BoolValue `json:"generate_audio,omitempty"`
+	Draft                 *dto.BoolValue `json:"draft,omitempty"`
+	Resolution            string         `json:"resolution,omitempty"`
+	Ratio                 string         `json:"ratio,omitempty"`
+	Duration              dto.IntValue   `json:"duration,omitempty"`
+	Frames                dto.IntValue   `json:"frames,omitempty"`
+	Seed                  dto.IntValue   `json:"seed,omitempty"`
+	CameraFixed           *dto.BoolValue `json:"camera_fixed,omitempty"`
+	Watermark             *dto.BoolValue `json:"watermark,omitempty"`
 }
 
 type responsePayload struct {
@@ -53,6 +74,7 @@ type responseTask struct {
 	Duration        int    `json:"duration"`
 	Ratio           string `json:"ratio"`
 	FramesPerSecond int    `json:"framespersecond"`
+	ServiceTier     string `json:"service_tier"`
 	Usage           struct {
 		CompletionTokens int `json:"completion_tokens"`
 		TotalTokens      int `json:"total_tokens"`
@@ -98,16 +120,16 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 
 // BuildRequestBody converts request into Doubao specific format.
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
-	v, exists := c.Get("task_request")
-	if !exists {
-		return nil, fmt.Errorf("request not found in context")
+	req, err := relaycommon.GetTaskRequest(c)
+	if err != nil {
+		return nil, err
 	}
-	req := v.(relaycommon.TaskSubmitReq)
 
 	body, err := a.convertToRequestPayload(&req)
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
+	info.UpstreamModelName = body.Model
 	data, err := json.Marshal(body)
 	if err != nil {
 		return nil, err
@@ -141,7 +163,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 
-	c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID})
+	ov := dto.NewOpenAIVideo()
+	ov.ID = dResp.ID
+	ov.TaskID = dResp.ID
+	ov.CreatedAt = time.Now().Unix()
+	ov.Model = info.OriginModelName
+
+	c.JSON(http.StatusOK, ov)
 	return dResp.ID, responseBody, nil
 }
 
@@ -204,12 +232,15 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 		}
 	}
 
-	// TODO: Add support for additional parameters from metadata
-	// such as ratio, duration, seed, etc.
-	// metadata := req.Metadata
-	// if metadata != nil {
-	//     // Parse and apply metadata parameters
-	// }
+	metadata := req.Metadata
+	medaBytes, err := json.Marshal(metadata)
+	if err != nil {
+		return nil, errors.Wrap(err, "metadata marshal metadata failed")
+	}
+	err = json.Unmarshal(medaBytes, &r)
+	if err != nil {
+		return nil, errors.Wrap(err, "unmarshal metadata failed")
+	}
 
 	return &r, nil
 }
@@ -229,7 +260,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 	case "pending", "queued":
 		taskResult.Status = model.TaskStatusQueued
 		taskResult.Progress = "10%"
-	case "processing":
+	case "processing", "running":
 		taskResult.Status = model.TaskStatusInProgress
 		taskResult.Progress = "50%"
 	case "succeeded":
@@ -251,3 +282,30 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 	return &taskResult, nil
 }
+
+func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
+	var dResp responseTask
+	if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
+		return nil, errors.Wrap(err, "unmarshal doubao task data failed")
+	}
+
+	openAIVideo := dto.NewOpenAIVideo()
+	openAIVideo.ID = originTask.TaskID
+	openAIVideo.TaskID = originTask.TaskID
+	openAIVideo.Status = originTask.Status.ToVideoStatus()
+	openAIVideo.SetProgressStr(originTask.Progress)
+	openAIVideo.SetMetadata("url", dResp.Content.VideoURL)
+	openAIVideo.CreatedAt = originTask.CreatedAt
+	openAIVideo.CompletedAt = originTask.UpdatedAt
+	openAIVideo.Model = originTask.Properties.OriginModelName
+
+	if dResp.Status == "failed" {
+		openAIVideo.Error = &dto.OpenAIVideoError{
+			Message: "task failed",
+			Code:    "failed",
+		}
+	}
+
+	jsonData, _ := common.Marshal(openAIVideo)
+	return jsonData, nil
+}

+ 1 - 0
relay/channel/task/doubao/constants.go

@@ -4,6 +4,7 @@ var ModelList = []string{
 	"doubao-seedance-1-0-pro-250528",
 	"doubao-seedance-1-0-lite-t2v",
 	"doubao-seedance-1-0-lite-i2v",
+	"doubao-seedance-1-5-pro-251215",
 }
 
 var ChannelName = "doubao-video"

+ 1 - 1
relay/relay_adaptor.go

@@ -148,7 +148,7 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
 			return &taskvertex.TaskAdaptor{}
 		case constant.ChannelTypeVidu:
 			return &taskVidu.TaskAdaptor{}
-		case constant.ChannelTypeDoubaoVideo:
+		case constant.ChannelTypeDoubaoVideo, constant.ChannelTypeVolcEngine:
 			return &taskdoubao.TaskAdaptor{}
 		case constant.ChannelTypeSora, constant.ChannelTypeOpenAI:
 			return &tasksora.TaskAdaptor{}