Browse Source

Merge remote-tracking branch 'origin/main'

CaIon 3 tháng trước cách đây
mục cha
commit
406a3e4dca

+ 5 - 21
relay/channel/task/jimeng/adaptor.go

@@ -18,7 +18,6 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
 
-	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
@@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 // ValidateRequestAndSetAction parses body, validates fields and sets default action.
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
 	// Accept only POST /v1/video/generations as "generate" action.
-	action := constant.TaskActionGenerate
-	info.Action = action
-
-	req := relaycommon.TaskSubmitReq{}
-	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
-		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
-		return
-	}
-	if strings.TrimSpace(req.Prompt) == "" {
-		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
-		return
-	}
-
-	// Store into context for later usage
-	c.Set("task_request", req)
-	return nil
+	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
 }
 
 // BuildRequestURL constructs the upstream URL.
@@ -334,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 
 	// Handle one-of image_urls or binary_data_base64
-	if req.Image != "" {
-		if strings.HasPrefix(req.Image, "http") {
-			r.ImageUrls = []string{req.Image}
+	if req.HasImage() {
+		if strings.HasPrefix(req.Images[0], "http") {
+			r.ImageUrls = req.Images
 		} else {
-			r.BinaryDataBase64 = []string{req.Image}
+			r.BinaryDataBase64 = req.Images
 		}
 	}
 	metadata := req.Metadata

+ 4 - 30
relay/channel/task/kling/adaptor.go

@@ -16,7 +16,6 @@ import (
 	"github.com/golang-jwt/jwt"
 	"github.com/pkg/errors"
 
-	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
@@ -28,16 +27,6 @@ import (
 // Request / Response structures
 // ============================
 
-type SubmitReq struct {
-	Prompt   string                 `json:"prompt"`
-	Model    string                 `json:"model,omitempty"`
-	Mode     string                 `json:"mode,omitempty"`
-	Image    string                 `json:"image,omitempty"`
-	Size     string                 `json:"size,omitempty"`
-	Duration int                    `json:"duration,omitempty"`
-	Metadata map[string]interface{} `json:"metadata,omitempty"`
-}
-
 type TrajectoryPoint struct {
 	X int `json:"x"`
 	Y int `json:"y"`
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 
 // ValidateRequestAndSetAction parses body, validates fields and sets default action.
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	// Accept only POST /v1/video/generations as "generate" action.
-	action := constant.TaskActionGenerate
-	info.Action = action
-
-	var req SubmitReq
-	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
-		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
-		return
-	}
-	if strings.TrimSpace(req.Prompt) == "" {
-		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
-		return
-	}
-
-	// Store into context for later usage
-	c.Set("task_request", req)
-	return nil
+	// Use the standard validation method for TaskSubmitReq
+	return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
 }
 
 // BuildRequestURL constructs the upstream URL.
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if !exists {
 		return nil, fmt.Errorf("request not found in context")
 	}
-	req := v.(SubmitReq)
+	req := v.(relaycommon.TaskSubmitReq)
 
 	body, err := a.convertToRequestPayload(&req)
 	if err != nil {
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
 	r := requestPayload{
 		Prompt:         req.Prompt,
 		Image:          req.Image,

+ 4 - 29
relay/channel/task/vidu/adaptor.go

@@ -23,16 +23,6 @@ import (
 // Request / Response structures
 // ============================
 
-type SubmitReq struct {
-	Prompt   string                 `json:"prompt"`
-	Model    string                 `json:"model,omitempty"`
-	Mode     string                 `json:"mode,omitempty"`
-	Image    string                 `json:"image,omitempty"`
-	Size     string                 `json:"size,omitempty"`
-	Duration int                    `json:"duration,omitempty"`
-	Metadata map[string]interface{} `json:"metadata,omitempty"`
-}
-
 type requestPayload struct {
 	Model             string   `json:"model"`
 	Images            []string `json:"images"`
@@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
-	var req SubmitReq
-	if err := c.ShouldBindJSON(&req); err != nil {
-		return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
-	}
-
-	if req.Prompt == "" {
-		return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
-	}
-
-	if req.Image != "" {
-		info.Action = constant.TaskActionGenerate
-	} else {
-		info.Action = constant.TaskActionTextGenerate
-	}
-
-	c.Set("task_request", req)
-	return nil
+	// Use the unified validation method for TaskSubmitReq with image-based action determination
+	return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
 }
 
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
 	if !exists {
 		return nil, fmt.Errorf("request not found in context")
 	}
-	req := v.(SubmitReq)
+	req := v.(relaycommon.TaskSubmitReq)
 
 	body, err := a.convertToRequestPayload(&req)
 	if err != nil {
@@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
 	var images []string
 	if req.Image != "" {
 		images = []string{req.Image}

+ 9 - 0
relay/common/relay_info.go

@@ -481,11 +481,20 @@ type TaskSubmitReq struct {
 	Model    string                 `json:"model,omitempty"`
 	Mode     string                 `json:"mode,omitempty"`
 	Image    string                 `json:"image,omitempty"`
+	Images   []string               `json:"images,omitempty"`
 	Size     string                 `json:"size,omitempty"`
 	Duration int                    `json:"duration,omitempty"`
 	Metadata map[string]interface{} `json:"metadata,omitempty"`
 }
 
+func (t TaskSubmitReq) GetPrompt() string {
+	return t.Prompt
+}
+
+func (t TaskSubmitReq) HasImage() bool {
+	return len(t.Images) > 0
+}
+
 type TaskInfo struct {
 	Code     int    `json:"code"`
 	TaskID   string `json:"task_id"`

+ 80 - 0
relay/common/relay_utils.go

@@ -2,12 +2,23 @@ package common
 
 import (
 	"fmt"
+	"net/http"
+	"one-api/common"
 	"one-api/constant"
+	"one-api/dto"
 	"strings"
 
 	"github.com/gin-gonic/gin"
 )
 
+type HasPrompt interface {
+	GetPrompt() string
+}
+
+type HasImage interface {
+	HasImage() bool
+}
+
 func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 
@@ -30,3 +41,72 @@ func GetAPIVersion(c *gin.Context) string {
 	}
 	return apiVersion
 }
+
+func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
+	return &dto.TaskError{
+		Code:       code,
+		Message:    err.Error(),
+		StatusCode: statusCode,
+		LocalError: localError,
+		Error:      err,
+	}
+}
+
+func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
+	info.Action = action
+	c.Set("task_request", requestObj)
+}
+
+func validatePrompt(prompt string) *dto.TaskError {
+	if strings.TrimSpace(prompt) == "" {
+		return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
+	}
+	return nil
+}
+
+func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
+	var req TaskSubmitReq
+	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+		return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
+	}
+
+	if taskErr := validatePrompt(req.Prompt); taskErr != nil {
+		return taskErr
+	}
+
+	if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
+		// 兼容单图上传
+		req.Images = []string{req.Image}
+	}
+
+	storeTaskRequest(c, info, action, req)
+	return nil
+}
+
+func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
+	hasPrompt, ok := requestObj.(HasPrompt)
+	if !ok {
+		return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
+	}
+
+	if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
+		return taskErr
+	}
+
+	action := constant.TaskActionTextGenerate
+	if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
+		action = constant.TaskActionGenerate
+	}
+
+	storeTaskRequest(c, info, action, requestObj)
+	return nil
+}
+
+func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
+	var req TaskSubmitReq
+	if err := c.ShouldBindJSON(&req); err != nil {
+		return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
+	}
+
+	return ValidateTaskRequestWithImage(c, info, req)
+}