|
|
@@ -5,6 +5,7 @@ import (
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
+ "github.com/samber/lo"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"strings"
|
|
|
@@ -41,7 +42,6 @@ type requestPayload struct {
|
|
|
Mode string `json:"mode,omitempty"`
|
|
|
Duration string `json:"duration,omitempty"`
|
|
|
AspectRatio string `json:"aspect_ratio,omitempty"`
|
|
|
- Model string `json:"model,omitempty"`
|
|
|
ModelName string `json:"model_name,omitempty"`
|
|
|
CfgScale float64 `json:"cfg_scale,omitempty"`
|
|
|
}
|
|
|
@@ -100,7 +100,8 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
|
|
|
|
|
// BuildRequestURL constructs the upstream URL.
|
|
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
|
|
- return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
|
|
|
+ path := lo.Ternary(info.Action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
|
|
|
+ return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
|
|
}
|
|
|
|
|
|
// BuildRequestHeader sets required headers.
|
|
|
@@ -125,7 +126,10 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
|
|
}
|
|
|
req := v.(SubmitReq)
|
|
|
|
|
|
- body := a.convertToRequestPayload(&req)
|
|
|
+ body, err := a.convertToRequestPayload(&req)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
data, err := json.Marshal(body)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
@@ -135,6 +139,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
|
|
|
|
|
// DoRequest delegates to common helper.
|
|
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
|
+ if action := c.GetString("action"); action != "" {
|
|
|
+ info.Action = action
|
|
|
+ }
|
|
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
|
|
}
|
|
|
|
|
|
@@ -175,7 +182,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
|
|
if !ok {
|
|
|
return nil, fmt.Errorf("invalid task_id")
|
|
|
}
|
|
|
- url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
|
|
|
+ action, ok := body["action"].(string)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("invalid action")
|
|
|
+ }
|
|
|
+ path := lo.Ternary(action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
|
|
|
+ url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
|
|
|
|
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
|
|
if err != nil {
|
|
|
@@ -210,22 +222,29 @@ func (a *TaskAdaptor) GetChannelName() string {
|
|
|
// helpers
|
|
|
// ============================
|
|
|
|
|
|
-func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
|
|
|
- r := &requestPayload{
|
|
|
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
|
|
+ r := requestPayload{
|
|
|
Prompt: req.Prompt,
|
|
|
Image: req.Image,
|
|
|
Mode: defaultString(req.Mode, "std"),
|
|
|
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
|
|
AspectRatio: a.getAspectRatio(req.Size),
|
|
|
- Model: req.Model,
|
|
|
ModelName: req.Model,
|
|
|
CfgScale: 0.5,
|
|
|
}
|
|
|
- if r.Model == "" {
|
|
|
- r.Model = "kling-v1"
|
|
|
+ if r.ModelName == "" {
|
|
|
r.ModelName = "kling-v1"
|
|
|
}
|
|
|
- return r
|
|
|
+ metadata := req.Metadata
|
|
|
+ medaBytes, err := json.Marshal(metadata)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
|
|
+ }
|
|
|
+ err = json.Unmarshal(medaBytes, &r)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Wrap(err, "unmarshal metadata failed")
|
|
|
+ }
|
|
|
+ return &r, nil
|
|
|
}
|
|
|
|
|
|
func (a *TaskAdaptor) getAspectRatio(size string) string {
|