| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- package common
- import (
- "fmt"
- "net/http"
- "strconv"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/gin-gonic/gin"
- "github.com/samber/lo"
- )
type (
	// HasPrompt is implemented by values that expose a text prompt.
	HasPrompt interface {
		// GetPrompt returns the prompt text.
		GetPrompt() string
	}

	// HasImage is implemented by values that can report whether they
	// carry image input.
	HasImage interface {
		// HasImage reports whether image input is present.
		HasImage() bool
	}
)
- func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
- switch channelType {
- case constant.ChannelTypeOpenAI:
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
- case constant.ChannelTypeAzure:
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
- }
- }
- return fullRequestURL
- }
- func GetAPIVersion(c *gin.Context) string {
- query := c.Request.URL.Query()
- apiVersion := query.Get("api-version")
- if apiVersion == "" {
- apiVersion = c.GetString("api_version")
- }
- return apiVersion
- }
- func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
- return &dto.TaskError{
- Code: code,
- Message: err.Error(),
- StatusCode: statusCode,
- LocalError: localError,
- Error: err,
- }
- }
- func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
- info.Action = action
- c.Set("task_request", requestObj)
- }
- func validatePrompt(prompt string) *dto.TaskError {
- if strings.TrimSpace(prompt) == "" {
- return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
- }
- return nil
- }
- func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
- var req TaskSubmitReq
- if _, err := c.MultipartForm(); err != nil {
- return req, err
- }
- formData := c.Request.PostForm
- req = TaskSubmitReq{
- Prompt: formData.Get("prompt"),
- Model: formData.Get("model"),
- Mode: formData.Get("mode"),
- Image: formData.Get("image"),
- Size: formData.Get("size"),
- Metadata: make(map[string]interface{}),
- }
- if durationStr := formData.Get("seconds"); durationStr != "" {
- if duration, err := strconv.Atoi(durationStr); err == nil {
- req.Duration = duration
- }
- }
- if images := formData["images"]; len(images) > 0 {
- req.Images = images
- }
- for key, values := range formData {
- if len(values) > 0 && !isKnownTaskField(key) {
- if intVal, err := strconv.Atoi(values[0]); err == nil {
- req.Metadata[key] = intVal
- } else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
- req.Metadata[key] = floatVal
- } else {
- req.Metadata[key] = values[0]
- }
- }
- }
- return req, nil
- }
- func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
- contentType := c.GetHeader("Content-Type")
- var prompt string
- var model string
- var seconds int
- var size string
- var hasInputReference bool
- if strings.HasPrefix(contentType, "multipart/form-data") {
- form, err := common.ParseMultipartFormReusable(c)
- if err != nil {
- return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
- }
- defer form.RemoveAll()
- prompts, ok := form.Value["prompt"]
- if !ok || len(prompts) == 0 {
- return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
- }
- prompt = prompts[0]
- if _, ok := form.Value["model"]; !ok {
- return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
- }
- model = form.Value["model"][0]
- if _, ok := form.File["input_reference"]; ok {
- hasInputReference = true
- }
- if ss, ok := form.Value["seconds"]; ok {
- sInt := common.String2Int(ss[0])
- if sInt > seconds {
- seconds = common.String2Int(ss[0])
- }
- }
- if sz, ok := form.Value["size"]; ok {
- size = sz[0]
- }
- } else {
- var req TaskSubmitReq
- if err := common.UnmarshalBodyReusable(c, &req); err != nil {
- return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
- }
- prompt = req.Prompt
- model = req.Model
- seconds = req.Duration
- if strings.TrimSpace(req.Model) == "" {
- return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
- }
- if req.HasImage() {
- hasInputReference = true
- }
- }
- if taskErr := validatePrompt(prompt); taskErr != nil {
- return taskErr
- }
- action := constant.TaskActionTextGenerate
- if hasInputReference {
- action = constant.TaskActionGenerate
- }
- if strings.HasPrefix(model, "sora-2") {
- if size == "" {
- size = "720x1280"
- }
- if seconds <= 0 {
- seconds = 4
- }
- if model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) {
- return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
- }
- if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
- return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
- }
- info.PriceData.OtherRatios = map[string]float64{
- "seconds": float64(seconds),
- "size": 1,
- }
- if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
- info.PriceData.OtherRatios["size"] = 1.666667
- }
- }
- info.Action = action
- return nil
- }
// isKnownTaskField reports whether field is one of the fixed task-request
// form field names. Callers use it to keep those names out of the
// free-form Metadata map.
func isKnownTaskField(field string) bool {
	// A switch avoids building a throwaway lookup map on every call; this
	// runs once per submitted form key.
	switch field {
	case "prompt", "model", "mode", "image", "images", "size", "duration",
		"input_reference": // input_reference is specific to Sora
		return true
	default:
		return false
	}
}
- func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
- var err error
- contentType := c.GetHeader("Content-Type")
- var req TaskSubmitReq
- if strings.HasPrefix(contentType, "multipart/form-data") {
- req, err = validateMultipartTaskRequest(c, info, action)
- if err != nil {
- return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
- }
- } else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
- return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
- }
- if taskErr := validatePrompt(req.Prompt); taskErr != nil {
- return taskErr
- }
- if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
- // 兼容单图上传
- req.Images = []string{req.Image}
- }
- if req.HasImage() {
- action = constant.TaskActionGenerate
- if info.ChannelType == constant.ChannelTypeVidu {
- // vidu 增加 首尾帧生视频和参考图生视频
- if len(req.Images) == 2 {
- action = constant.TaskActionFirstTailGenerate
- } else if len(req.Images) > 2 {
- action = constant.TaskActionReferenceGenerate
- }
- }
- }
- storeTaskRequest(c, info, action, req)
- return nil
- }
|