package common

import (
	"fmt"
	"net/http"
	"strconv"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"

	"github.com/gin-gonic/gin"
	"github.com/samber/lo"
)

// HasPrompt is satisfied by any type that exposes its prompt text via GetPrompt.
type HasPrompt interface {
	GetPrompt() string
}

// HasImage is satisfied by any type that can report whether it carries image input.
type HasImage interface {
	HasImage() bool
}

// GetFullRequestURL concatenates baseURL and requestURL. When baseURL is a
// Cloudflare AI Gateway URL ("https://gateway.ai.cloudflare.com..."), a
// channel-specific prefix is first stripped from requestURL: "/v1" for OpenAI
// channels and "/openai/deployments" for Azure channels. Other channel types
// are concatenated unchanged.
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		switch channelType {
		case constant.ChannelTypeOpenAI:
			requestURL = strings.TrimPrefix(requestURL, "/v1")
		case constant.ChannelTypeAzure:
			requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
		}
	}
	return baseURL + requestURL
}

// GetAPIVersion returns the "api-version" query parameter of the current
// request. When that parameter is absent or empty, it falls back to the string
// stored in the gin context under "api_version" ("" if unset).
func GetAPIVersion(c *gin.Context) string {
	if apiVersion := c.Request.URL.Query().Get("api-version"); apiVersion != "" {
		return apiVersion
	}
	return c.GetString("api_version")
}

// createTaskError wraps err into a dto.TaskError. The error text becomes the
// Message, and the original error is kept in the Error field.
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
	return &dto.TaskError{
		Code:       code,
		Message:    err.Error(),
		StatusCode: statusCode,
		LocalError: localError,
		Error:      err,
	}
}

// storeTaskRequest records action on info and saves requestObj in the gin
// context under the "task_request" key.
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
	info.Action = action
	c.Set("task_request", requestObj)
}

// validatePrompt returns a 400 "invalid_request" task error when prompt is
// empty or whitespace-only, and nil otherwise.
func validatePrompt(prompt string) *dto.TaskError {
	if strings.TrimSpace(prompt) == "" {
		return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
	}
	return nil
}

// validateMultipartTaskRequest parses a multipart/form-data body into a
// TaskSubmitReq.
//
// - Recognised text fields (prompt, model, mode, image, size, images) are
//   copied into their struct fields.
// - The clip length is taken from "seconds". If that is missing or not an
//   integer, "duration" is used instead.
// - Every other form field that is not listed in isKnownTaskField is stored in
//   Metadata. Its first value is stored as an int, else a float64, else the
//   raw string.
//
// NOTE(review): "seconds" is not listed in isKnownTaskField, so it is also
// copied into Metadata.
//
// info and action are currently unused.
func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
	if _, err := c.MultipartForm(); err != nil {
		return TaskSubmitReq{}, err
	}
	formData := c.Request.PostForm

	req := TaskSubmitReq{
		Prompt:   formData.Get("prompt"),
		Model:    formData.Get("model"),
		Mode:     formData.Get("mode"),
		Image:    formData.Get("image"),
		Size:     formData.Get("size"),
		Metadata: make(map[string]interface{}),
	}

	// "seconds" wins. "duration" is a known field and therefore never reaches
	// Metadata, so it has to be read here or it would be silently dropped.
	for _, key := range []string{"seconds", "duration"} {
		if duration, err := strconv.Atoi(formData.Get(key)); err == nil {
			req.Duration = duration
			break
		}
	}

	if images := formData["images"]; len(images) > 0 {
		req.Images = images
	}

	for key, values := range formData {
		if len(values) == 0 || isKnownTaskField(key) {
			continue
		}
		if intVal, err := strconv.Atoi(values[0]); err == nil {
			req.Metadata[key] = intVal
		} else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
			req.Metadata[key] = floatVal
		} else {
			req.Metadata[key] = values[0]
		}
	}
	return req, nil
}

// ValidateMultipartDirect decodes the request body with
// common.UnmarshalBodyReusable and validates it. It requires a non-empty model
// and prompt, then records the resolved action on info.Action. The action is
// TaskActionGenerate when the request carries an image or input_reference, and
// TaskActionTextGenerate otherwise.
//
// For models whose name starts with "sora-2":
// - size defaults to "720x1280" and seconds defaults to 4.
// - size is validated for "sora-2" and "sora-2-pro".
// - info.PriceData.OtherRatios is set to the seconds value and a size ratio.
//   The ratio is 1.666667 for 1792x1024 / 1024x1792 and 1 otherwise.
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
	var req TaskSubmitReq
	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
		return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
	}

	// Seconds is a string field. An empty or non-numeric value deliberately
	// yields 0, which falls back to the integer Duration field.
	seconds, _ := strconv.Atoi(req.Seconds)
	if seconds == 0 {
		seconds = req.Duration
	}
	// Treat an input_reference as the request's image so HasImage sees it.
	if req.InputReference != "" {
		req.Images = []string{req.InputReference}
	}
	if strings.TrimSpace(req.Model) == "" {
		return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
	}
	hasInputReference := req.HasImage()
	if taskErr := validatePrompt(req.Prompt); taskErr != nil {
		return taskErr
	}

	action := constant.TaskActionTextGenerate
	if hasInputReference {
		action = constant.TaskActionGenerate
	}

	if strings.HasPrefix(req.Model, "sora-2") {
		// Previously the requested size was never read, so validation and
		// high-resolution pricing could not take effect.
		size := req.Size
		if size == "" {
			size = "720x1280"
		}
		if seconds <= 0 {
			seconds = 4
		}
		switch req.Model {
		case "sora-2":
			if !lo.Contains([]string{"720x1280", "1280x720"}, size) {
				return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
			}
		case "sora-2-pro":
			if !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
				return createTaskError(fmt.Errorf("sora-2-pro size is invalid"), "invalid_size", http.StatusBadRequest, true)
			}
		}
		info.PriceData.OtherRatios = map[string]float64{
			"seconds": float64(seconds),
			"size":    1,
		}
		// The two high-resolution sizes carry a larger size ratio.
		if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
			info.PriceData.OtherRatios["size"] = 1.666667
		}
	}

	info.Action = action
	return nil
}

// isKnownTaskField reports whether a multipart form field is one of the
// standard task fields that validateMultipartTaskRequest must not copy into
// Metadata.
func isKnownTaskField(field string) bool {
	switch field {
	case "prompt", "model", "mode", "image", "images", "size", "duration",
		"input_reference": // input_reference is Sora-specific
		return true
	default:
		return false
	}
}

// ValidateBasicTaskRequest decodes a task submission and validates it.
//
// - multipart/form-data bodies go through validateMultipartTaskRequest; any
//   other content type goes through common.UnmarshalBodyReusable.
// - The prompt must be non-empty.
// - action is the default used when the request carries no image. With
//   images, the action becomes TaskActionGenerate, refined further for Vidu
//   channels.
// - When the model is mapped, the upstream model name replaces req.Model.
// - Finally the request is stored in the gin context under "task_request" and
//   the action is recorded on info.
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
	var req TaskSubmitReq
	if strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
		var err error
		req, err = validateMultipartTaskRequest(c, info, action)
		if err != nil {
			return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
		}
	} else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
		return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
	}

	if taskErr := validatePrompt(req.Prompt); taskErr != nil {
		return taskErr
	}

	// Compatibility with single-image uploads: promote "image" into Images.
	if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
		req.Images = []string{req.Image}
	}

	if req.HasImage() {
		action = constant.TaskActionGenerate
		// Vidu also supports first/last-frame generation (exactly two images)
		// and reference-image generation (more than two images).
		if info.ChannelType == constant.ChannelTypeVidu {
			switch {
			case len(req.Images) == 2:
				action = constant.TaskActionFirstTailGenerate
			case len(req.Images) > 2:
				action = constant.TaskActionReferenceGenerate
			}
		}
	}

	// Model mapping: send the configured upstream model name.
	if info.IsModelMapped {
		req.Model = info.UpstreamModelName
	}

	storeTaskRequest(c, info, action, req)
	return nil
}