relay_utils.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package common
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "strings"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/constant"
  9. "github.com/QuantumNous/new-api/dto"
  10. "github.com/gin-gonic/gin"
  11. "github.com/samber/lo"
  12. )
  13. type HasPrompt interface {
  14. GetPrompt() string
  15. }
  16. type HasImage interface {
  17. HasImage() bool
  18. }
  19. func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
  20. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  21. if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
  22. switch channelType {
  23. case constant.ChannelTypeOpenAI:
  24. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
  25. case constant.ChannelTypeAzure:
  26. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
  27. }
  28. }
  29. return fullRequestURL
  30. }
  31. func GetAPIVersion(c *gin.Context) string {
  32. query := c.Request.URL.Query()
  33. apiVersion := query.Get("api-version")
  34. if apiVersion == "" {
  35. apiVersion = c.GetString("api_version")
  36. }
  37. return apiVersion
  38. }
  39. func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
  40. return &dto.TaskError{
  41. Code: code,
  42. Message: err.Error(),
  43. StatusCode: statusCode,
  44. LocalError: localError,
  45. Error: err,
  46. }
  47. }
  48. func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
  49. info.Action = action
  50. c.Set("task_request", requestObj)
  51. }
  52. func validatePrompt(prompt string) *dto.TaskError {
  53. if strings.TrimSpace(prompt) == "" {
  54. return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
  55. }
  56. return nil
  57. }
  58. func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
  59. var req TaskSubmitReq
  60. if _, err := c.MultipartForm(); err != nil {
  61. return req, err
  62. }
  63. formData := c.Request.PostForm
  64. req = TaskSubmitReq{
  65. Prompt: formData.Get("prompt"),
  66. Model: formData.Get("model"),
  67. Mode: formData.Get("mode"),
  68. Image: formData.Get("image"),
  69. Size: formData.Get("size"),
  70. Metadata: make(map[string]interface{}),
  71. }
  72. if durationStr := formData.Get("seconds"); durationStr != "" {
  73. if duration, err := strconv.Atoi(durationStr); err == nil {
  74. req.Duration = duration
  75. }
  76. }
  77. if images := formData["images"]; len(images) > 0 {
  78. req.Images = images
  79. }
  80. for key, values := range formData {
  81. if len(values) > 0 && !isKnownTaskField(key) {
  82. if intVal, err := strconv.Atoi(values[0]); err == nil {
  83. req.Metadata[key] = intVal
  84. } else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
  85. req.Metadata[key] = floatVal
  86. } else {
  87. req.Metadata[key] = values[0]
  88. }
  89. }
  90. }
  91. return req, nil
  92. }
  93. func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
  94. contentType := c.GetHeader("Content-Type")
  95. var prompt string
  96. var model string
  97. var seconds int
  98. var size string
  99. var hasInputReference bool
  100. if strings.HasPrefix(contentType, "multipart/form-data") {
  101. form, err := common.ParseMultipartFormReusable(c)
  102. if err != nil {
  103. return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
  104. }
  105. defer form.RemoveAll()
  106. prompts, ok := form.Value["prompt"]
  107. if !ok || len(prompts) == 0 {
  108. return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
  109. }
  110. prompt = prompts[0]
  111. if _, ok := form.Value["model"]; !ok {
  112. return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
  113. }
  114. model = form.Value["model"][0]
  115. if _, ok := form.File["input_reference"]; ok {
  116. hasInputReference = true
  117. }
  118. if ss, ok := form.Value["seconds"]; ok {
  119. sInt := common.String2Int(ss[0])
  120. if sInt > seconds {
  121. seconds = common.String2Int(ss[0])
  122. }
  123. }
  124. if sz, ok := form.Value["size"]; ok {
  125. size = sz[0]
  126. }
  127. } else {
  128. var req TaskSubmitReq
  129. if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  130. return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
  131. }
  132. prompt = req.Prompt
  133. model = req.Model
  134. seconds = req.Duration
  135. if strings.TrimSpace(req.Model) == "" {
  136. return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
  137. }
  138. if req.HasImage() {
  139. hasInputReference = true
  140. }
  141. }
  142. if taskErr := validatePrompt(prompt); taskErr != nil {
  143. return taskErr
  144. }
  145. action := constant.TaskActionTextGenerate
  146. if hasInputReference {
  147. action = constant.TaskActionGenerate
  148. }
  149. if strings.HasPrefix(model, "sora-2") {
  150. if size == "" {
  151. size = "720x1280"
  152. }
  153. if seconds <= 0 {
  154. seconds = 4
  155. }
  156. if model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) {
  157. return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
  158. }
  159. if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
  160. return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
  161. }
  162. info.PriceData.OtherRatios = map[string]float64{
  163. "seconds": float64(seconds),
  164. "size": 1,
  165. }
  166. if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
  167. info.PriceData.OtherRatios["size"] = 1.666667
  168. }
  169. }
  170. info.Action = action
  171. return nil
  172. }
  173. func isKnownTaskField(field string) bool {
  174. knownFields := map[string]bool{
  175. "prompt": true,
  176. "model": true,
  177. "mode": true,
  178. "image": true,
  179. "images": true,
  180. "size": true,
  181. "duration": true,
  182. "input_reference": true, // Sora 特有字段
  183. }
  184. return knownFields[field]
  185. }
  186. func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
  187. var err error
  188. contentType := c.GetHeader("Content-Type")
  189. var req TaskSubmitReq
  190. if strings.HasPrefix(contentType, "multipart/form-data") {
  191. req, err = validateMultipartTaskRequest(c, info, action)
  192. if err != nil {
  193. return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
  194. }
  195. } else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  196. return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
  197. }
  198. if taskErr := validatePrompt(req.Prompt); taskErr != nil {
  199. return taskErr
  200. }
  201. if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
  202. // 兼容单图上传
  203. req.Images = []string{req.Image}
  204. }
  205. if req.HasImage() {
  206. action = constant.TaskActionGenerate
  207. if info.ChannelType == constant.ChannelTypeVidu {
  208. // vidu 增加 首尾帧生视频和参考图生视频
  209. if len(req.Images) == 2 {
  210. action = constant.TaskActionFirstTailGenerate
  211. } else if len(req.Images) > 2 {
  212. action = constant.TaskActionReferenceGenerate
  213. }
  214. }
  215. }
  216. storeTaskRequest(c, info, action, req)
  217. return nil
  218. }