Selaa lähdekoodia

feat: add origin kling api

skynono 6 kuukautta sitten
vanhempi
sitoutus
cd2870aebc

+ 7 - 6
controller/task_video.go

@@ -56,8 +56,15 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 	if channel.GetBaseURL() != "" {
 		baseURL = channel.GetBaseURL()
 	}
+
+	task := taskM[taskId]
+	if task == nil {
+		common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+		return fmt.Errorf("task %s not found", taskId)
+	}
 	resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
 		"task_id": taskId,
+		"action":  task.Action,
 	})
 	if err != nil {
 		return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
@@ -89,12 +96,6 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 		return fmt.Errorf("video task data format error for task %s", taskId)
 	}
 
-	task := taskM[taskId]
-	if task == nil {
-		common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
-		return fmt.Errorf("task %s not found", taskId)
-	}
-
 	if status, ok := data["task_status"].(string); ok {
 		switch status {
 		case "submitted", "queued":

+ 45 - 0
middleware/kling_adapter.go

@@ -0,0 +1,45 @@
+package middleware
+
+import (
+	"bytes"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"one-api/common"
+)
+
+func KlingRequestConvert() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		var originalReq map[string]interface{}
+		if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
+			c.Next()
+			return
+		}
+
+		model, _ := originalReq["model"].(string)
+		prompt, _ := originalReq["prompt"].(string)
+
+		unifiedReq := map[string]interface{}{
+			"model":    model,
+			"prompt":   prompt,
+			"metadata": originalReq,
+		}
+
+		jsonData, err := json.Marshal(unifiedReq)
+		if err != nil {
+			c.Next()
+			return
+		}
+
+		// Rewrite request body and path
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
+		c.Request.URL.Path = "/v1/video/generations"
+		if image := originalReq["image"]; image == "" {
+			c.Set("action", "textGenerate")
+		}
+
+		// We have to reset the request body for the next handlers
+		c.Set(common.KeyRequestBody, jsonData)
+		c.Next()
+	}
+}

+ 29 - 10
relay/channel/task/kling/adaptor.go

@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"github.com/samber/lo"
 	"io"
 	"net/http"
 	"strings"
@@ -41,7 +42,6 @@ type requestPayload struct {
 	Mode        string  `json:"mode,omitempty"`
 	Duration    string  `json:"duration,omitempty"`
 	AspectRatio string  `json:"aspect_ratio,omitempty"`
-	Model       string  `json:"model,omitempty"`
 	ModelName   string  `json:"model_name,omitempty"`
 	CfgScale    float64 `json:"cfg_scale,omitempty"`
 }
@@ -100,7 +100,8 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 
 // BuildRequestURL constructs the upstream URL.
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
+	path := lo.Ternary(info.Action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
+	return fmt.Sprintf("%s%s", a.baseURL, path), nil
 }
 
 // BuildRequestHeader sets required headers.
@@ -125,7 +126,10 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
 	}
 	req := v.(SubmitReq)
 
-	body := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req)
+	if err != nil {
+		return nil, err
+	}
 	data, err := json.Marshal(body)
 	if err != nil {
 		return nil, err
@@ -135,6 +139,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
 
 // DoRequest delegates to common helper.
 func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	if action := c.GetString("action"); action != "" {
+		info.Action = action
+	}
 	return channel.DoTaskApiRequest(a, c, info, requestBody)
 }
 
@@ -175,7 +182,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
 	}
-	url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
+	action, ok := body["action"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid action")
+	}
+	path := lo.Ternary(action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
+	url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
 
 	req, err := http.NewRequest(http.MethodGet, url, nil)
 	if err != nil {
@@ -210,22 +222,29 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
-	r := &requestPayload{
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+	r := requestPayload{
 		Prompt:      req.Prompt,
 		Image:       req.Image,
 		Mode:        defaultString(req.Mode, "std"),
 		Duration:    fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
 		AspectRatio: a.getAspectRatio(req.Size),
-		Model:       req.Model,
 		ModelName:   req.Model,
 		CfgScale:    0.5,
 	}
-	if r.Model == "" {
-		r.Model = "kling-v1"
+	if r.ModelName == "" {
 		r.ModelName = "kling-v1"
 	}
-	return r
+	metadata := req.Metadata
+	medaBytes, err := json.Marshal(metadata)
+	if err != nil {
+		return nil, errors.Wrap(err, "metadata marshal metadata failed")
+	}
+	err = json.Unmarshal(medaBytes, &r)
+	if err != nil {
+		return nil, errors.Wrap(err, "unmarshal metadata failed")
+	}
+	return &r, nil
 }
 
 func (a *TaskAdaptor) getAspectRatio(size string) string {

+ 7 - 0
router/video-router.go

@@ -14,4 +14,11 @@ func SetVideoRouter(router *gin.Engine) {
 		videoV1Router.POST("/video/generations", controller.RelayTask)
 		videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
 	}
+
+	klingV1Router := router.Group("/kling/v1")
+	klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
+	{
+		klingV1Router.POST("/videos/text2video", controller.RelayTask)
+		klingV1Router.POST("/videos/image2video", controller.RelayTask)
+	}
 }

+ 8 - 2
web/src/components/table/TaskLogsTable.js

@@ -212,7 +212,13 @@ const LogsTable = () => {
       case 'generate':
         return (
           <Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
-            {t('生成视频')}
+            {t('图生视频')}
+          </Tag>
+        );
+      case 'textGenerate':
+        return (
+          <Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
+            {t('文生视频')}
           </Tag>
         );
       default:
@@ -434,7 +440,7 @@ const LogsTable = () => {
       fixed: 'right',
       render: (text, record, index) => {
         // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
-        const isVideoTask = record.action === 'generate';
+        const isVideoTask = record.action === 'generate' || record.action === 'textGenerate';
         const isSuccess = record.status === 'SUCCESS';
         const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
         if (isSuccess && isVideoTask && isUrl) {