|
|
@@ -0,0 +1,312 @@
|
|
|
+package kling
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "context"
|
|
|
+ "encoding/json"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "strings"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
+ "github.com/golang-jwt/jwt"
|
|
|
+ "github.com/pkg/errors"
|
|
|
+
|
|
|
+ "one-api/common"
|
|
|
+ "one-api/dto"
|
|
|
+ "one-api/relay/channel"
|
|
|
+ relaycommon "one-api/relay/common"
|
|
|
+ "one-api/service"
|
|
|
+)
|
|
|
+
|
|
|
+// ============================
|
|
|
+// Request / Response structures
|
|
|
+// ============================
|
|
|
+
|
|
|
+type SubmitReq struct {
|
|
|
+ Prompt string `json:"prompt"`
|
|
|
+ Model string `json:"model,omitempty"`
|
|
|
+ Mode string `json:"mode,omitempty"`
|
|
|
+ Image string `json:"image,omitempty"`
|
|
|
+ Size string `json:"size,omitempty"`
|
|
|
+ Duration int `json:"duration,omitempty"`
|
|
|
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
|
|
|
+}
|
|
|
+
|
|
|
+type requestPayload struct {
|
|
|
+ Prompt string `json:"prompt,omitempty"`
|
|
|
+ Image string `json:"image,omitempty"`
|
|
|
+ Mode string `json:"mode,omitempty"`
|
|
|
+ Duration string `json:"duration,omitempty"`
|
|
|
+ AspectRatio string `json:"aspect_ratio,omitempty"`
|
|
|
+ Model string `json:"model,omitempty"`
|
|
|
+ ModelName string `json:"model_name,omitempty"`
|
|
|
+ CfgScale float64 `json:"cfg_scale,omitempty"`
|
|
|
+}
|
|
|
+
|
|
|
+type responsePayload struct {
|
|
|
+ Code int `json:"code"`
|
|
|
+ Message string `json:"message"`
|
|
|
+ Data struct {
|
|
|
+ TaskID string `json:"task_id"`
|
|
|
+ } `json:"data"`
|
|
|
+}
|
|
|
+
|
|
|
+// ============================
|
|
|
+// Adaptor implementation
|
|
|
+// ============================
|
|
|
+
|
|
|
+type TaskAdaptor struct {
|
|
|
+ ChannelType int
|
|
|
+ accessKey string
|
|
|
+ secretKey string
|
|
|
+ baseURL string
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
|
|
+ a.ChannelType = info.ChannelType
|
|
|
+ a.baseURL = info.BaseUrl
|
|
|
+
|
|
|
+ // apiKey format: "access_key,secret_key"
|
|
|
+ keyParts := strings.Split(info.ApiKey, ",")
|
|
|
+ if len(keyParts) == 2 {
|
|
|
+ a.accessKey = strings.TrimSpace(keyParts[0])
|
|
|
+ a.secretKey = strings.TrimSpace(keyParts[1])
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
|
|
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
|
|
+ // Accept only POST /v1/video/generations as "generate" action.
|
|
|
+ action := "generate"
|
|
|
+ info.Action = action
|
|
|
+
|
|
|
+ var req SubmitReq
|
|
|
+ if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if strings.TrimSpace(req.Prompt) == "" {
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Store into context for later usage
|
|
|
+ c.Set("kling_request", req)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// BuildRequestURL constructs the upstream URL.
|
|
|
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
|
|
+ return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
|
|
|
+}
|
|
|
+
|
|
|
+// BuildRequestHeader sets required headers.
|
|
|
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
|
|
+ token, err := a.createJWTToken()
|
|
|
+ if err != nil {
|
|
|
+ token = info.ApiKey // fallback
|
|
|
+ }
|
|
|
+
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
+ req.Header.Set("Accept", "application/json")
|
|
|
+ req.Header.Set("Authorization", "Bearer "+token)
|
|
|
+ req.Header.Set("User-Agent", "kling-sdk/1.0")
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// BuildRequestBody converts request into Kling specific format.
|
|
|
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
|
|
+ v, exists := c.Get("kling_request")
|
|
|
+ if !exists {
|
|
|
+ return nil, fmt.Errorf("request not found in context")
|
|
|
+ }
|
|
|
+ req := v.(SubmitReq)
|
|
|
+
|
|
|
+ body := a.convertToRequestPayload(&req)
|
|
|
+ data, err := json.Marshal(body)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return bytes.NewReader(data), nil
|
|
|
+}
|
|
|
+
|
|
|
+// DoRequest delegates to common helper.
|
|
|
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
|
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
|
|
|
+}
|
|
|
+
|
|
|
+// DoResponse handles upstream response, returns taskID etc.
|
|
|
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
|
|
+ responseBody, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Attempt Kling response parse first.
|
|
|
+ var kResp responsePayload
|
|
|
+ if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
|
|
+ c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
|
|
|
+ return kResp.Data.TaskID, responseBody, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // Fallback generic task response.
|
|
|
+ var generic dto.TaskResponse[string]
|
|
|
+ if err := json.Unmarshal(responseBody, &generic); err != nil {
|
|
|
+ taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if !generic.IsSuccess() {
|
|
|
+ taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
|
|
|
+ return generic.Data, responseBody, nil
|
|
|
+}
|
|
|
+
|
|
|
+// FetchTask fetch task status
|
|
|
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
|
|
+ taskID, ok := body["task_id"].(string)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("invalid task_id")
|
|
|
+ }
|
|
|
+ url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
|
|
|
+
|
|
|
+ req, err := http.NewRequest(http.MethodGet, url, nil)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ token, err := a.createJWTTokenWithKey(key)
|
|
|
+ if err != nil {
|
|
|
+ token = key
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ req = req.WithContext(ctx)
|
|
|
+ req.Header.Set("Accept", "application/json")
|
|
|
+ req.Header.Set("Authorization", "Bearer "+token)
|
|
|
+ req.Header.Set("User-Agent", "kling-sdk/1.0")
|
|
|
+
|
|
|
+ return service.GetHttpClient().Do(req)
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) GetModelList() []string {
|
|
|
+ return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) GetChannelName() string {
|
|
|
+ return "kling"
|
|
|
+}
|
|
|
+
|
|
|
+// ============================
|
|
|
+// helpers
|
|
|
+// ============================
|
|
|
+
|
|
|
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
|
|
|
+ r := &requestPayload{
|
|
|
+ Prompt: req.Prompt,
|
|
|
+ Image: req.Image,
|
|
|
+ Mode: defaultString(req.Mode, "std"),
|
|
|
+ Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
|
|
+ AspectRatio: a.getAspectRatio(req.Size),
|
|
|
+ Model: req.Model,
|
|
|
+ ModelName: req.Model,
|
|
|
+ CfgScale: 0.5,
|
|
|
+ }
|
|
|
+ if r.Model == "" {
|
|
|
+ r.Model = "kling-v1"
|
|
|
+ r.ModelName = "kling-v1"
|
|
|
+ }
|
|
|
+ return r
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) getAspectRatio(size string) string {
|
|
|
+ switch size {
|
|
|
+ case "1024x1024", "512x512":
|
|
|
+ return "1:1"
|
|
|
+ case "1280x720", "1920x1080":
|
|
|
+ return "16:9"
|
|
|
+ case "720x1280", "1080x1920":
|
|
|
+ return "9:16"
|
|
|
+ default:
|
|
|
+ return "1:1"
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func defaultString(s, def string) string {
|
|
|
+ if strings.TrimSpace(s) == "" {
|
|
|
+ return def
|
|
|
+ }
|
|
|
+ return s
|
|
|
+}
|
|
|
+
|
|
|
+func defaultInt(v int, def int) int {
|
|
|
+ if v == 0 {
|
|
|
+ return def
|
|
|
+ }
|
|
|
+ return v
|
|
|
+}
|
|
|
+
|
|
|
+// ============================
|
|
|
+// JWT helpers
|
|
|
+// ============================
|
|
|
+
|
|
|
+func (a *TaskAdaptor) createJWTToken() (string, error) {
|
|
|
+ return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
|
|
+ parts := strings.Split(apiKey, ",")
|
|
|
+ if len(parts) != 2 {
|
|
|
+ return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
|
|
+ }
|
|
|
+ return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
|
|
|
+ if accessKey == "" || secretKey == "" {
|
|
|
+ return "", fmt.Errorf("access key and secret key are required")
|
|
|
+ }
|
|
|
+ now := time.Now().Unix()
|
|
|
+ claims := jwt.MapClaims{
|
|
|
+ "iss": accessKey,
|
|
|
+ "exp": now + 1800, // 30 minutes
|
|
|
+ "nbf": now - 5,
|
|
|
+ }
|
|
|
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
|
+ token.Header["typ"] = "JWT"
|
|
|
+ return token.SignedString([]byte(secretKey))
|
|
|
+}
|
|
|
+
|
|
|
+// ParseResultUrl 提取视频任务结果的 url
|
|
|
+func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
|
|
+ data, ok := resp["data"].(map[string]any)
|
|
|
+ if !ok {
|
|
|
+ return "", fmt.Errorf("data field not found or invalid")
|
|
|
+ }
|
|
|
+ taskResult, ok := data["task_result"].(map[string]any)
|
|
|
+ if !ok {
|
|
|
+ return "", fmt.Errorf("task_result field not found or invalid")
|
|
|
+ }
|
|
|
+ videos, ok := taskResult["videos"].([]interface{})
|
|
|
+ if !ok || len(videos) == 0 {
|
|
|
+ return "", fmt.Errorf("videos field not found or empty")
|
|
|
+ }
|
|
|
+ video, ok := videos[0].(map[string]interface{})
|
|
|
+ if !ok {
|
|
|
+ return "", fmt.Errorf("video item invalid")
|
|
|
+ }
|
|
|
+ url, ok := video["url"].(string)
|
|
|
+ if !ok || url == "" {
|
|
|
+ return "", fmt.Errorf("url field not found or invalid")
|
|
|
+ }
|
|
|
+ return url, nil
|
|
|
+}
|