|
@@ -16,7 +16,6 @@ import (
|
|
|
"github.com/golang-jwt/jwt"
|
|
"github.com/golang-jwt/jwt"
|
|
|
"github.com/pkg/errors"
|
|
"github.com/pkg/errors"
|
|
|
|
|
|
|
|
- "one-api/common"
|
|
|
|
|
"one-api/constant"
|
|
"one-api/constant"
|
|
|
"one-api/dto"
|
|
"one-api/dto"
|
|
|
"one-api/relay/channel"
|
|
"one-api/relay/channel"
|
|
@@ -28,16 +27,6 @@ import (
|
|
|
// Request / Response structures
|
|
// Request / Response structures
|
|
|
// ============================
|
|
// ============================
|
|
|
|
|
|
|
|
-type SubmitReq struct {
|
|
|
|
|
- Prompt string `json:"prompt"`
|
|
|
|
|
- Model string `json:"model,omitempty"`
|
|
|
|
|
- Mode string `json:"mode,omitempty"`
|
|
|
|
|
- Image string `json:"image,omitempty"`
|
|
|
|
|
- Size string `json:"size,omitempty"`
|
|
|
|
|
- Duration int `json:"duration,omitempty"`
|
|
|
|
|
- Metadata map[string]interface{} `json:"metadata,omitempty"`
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
type TrajectoryPoint struct {
|
|
type TrajectoryPoint struct {
|
|
|
X int `json:"x"`
|
|
X int `json:"x"`
|
|
|
Y int `json:"y"`
|
|
Y int `json:"y"`
|
|
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|
|
|
|
|
|
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
|
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
|
|
- // Accept only POST /v1/video/generations as "generate" action.
|
|
|
|
|
- action := constant.TaskActionGenerate
|
|
|
|
|
- info.Action = action
|
|
|
|
|
-
|
|
|
|
|
- var req SubmitReq
|
|
|
|
|
- if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
|
|
|
|
- taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
- if strings.TrimSpace(req.Prompt) == "" {
|
|
|
|
|
- taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Store into context for later usage
|
|
|
|
|
- c.Set("task_request", req)
|
|
|
|
|
- return nil
|
|
|
|
|
|
|
+ // Use the standard validation method for TaskSubmitReq
|
|
|
|
|
+ return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// BuildRequestURL constructs the upstream URL.
|
|
// BuildRequestURL constructs the upstream URL.
|
|
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
|
|
if !exists {
|
|
if !exists {
|
|
|
return nil, fmt.Errorf("request not found in context")
|
|
return nil, fmt.Errorf("request not found in context")
|
|
|
}
|
|
}
|
|
|
- req := v.(SubmitReq)
|
|
|
|
|
|
|
+ req := v.(relaycommon.TaskSubmitReq)
|
|
|
|
|
|
|
|
body, err := a.convertToRequestPayload(&req)
|
|
body, err := a.convertToRequestPayload(&req)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
|
|
// helpers
|
|
// helpers
|
|
|
// ============================
|
|
// ============================
|
|
|
|
|
|
|
|
-func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
|
|
|
|
|
|
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
|
|
r := requestPayload{
|
|
r := requestPayload{
|
|
|
Prompt: req.Prompt,
|
|
Prompt: req.Prompt,
|
|
|
Image: req.Image,
|
|
Image: req.Image,
|