relay_task.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strconv"
  10. "strings"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/model"
  15. "github.com/QuantumNous/new-api/relay/channel"
  16. relaycommon "github.com/QuantumNous/new-api/relay/common"
  17. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  18. "github.com/QuantumNous/new-api/service"
  19. "github.com/QuantumNous/new-api/setting/ratio_setting"
  20. "github.com/gin-gonic/gin"
  21. )
  22. /*
  23. Task 任务通过平台、Action 区分任务
  24. */
  25. func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  26. info.InitChannelMeta(c)
  27. // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
  28. if info.TaskRelayInfo == nil {
  29. info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
  30. }
  31. path := c.Request.URL.Path
  32. if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
  33. info.Action = constant.TaskActionRemix
  34. }
  35. // 提取 remix 任务的 video_id
  36. if info.Action == constant.TaskActionRemix {
  37. videoID := c.Param("video_id")
  38. if strings.TrimSpace(videoID) == "" {
  39. return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
  40. }
  41. info.OriginTaskID = videoID
  42. }
  43. platform := constant.TaskPlatform(c.GetString("platform"))
  44. // 获取原始任务信息
  45. if info.OriginTaskID != "" {
  46. originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
  47. if err != nil {
  48. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  49. return
  50. }
  51. if !exist {
  52. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  53. return
  54. }
  55. if info.OriginModelName == "" {
  56. if originTask.Properties.OriginModelName != "" {
  57. info.OriginModelName = originTask.Properties.OriginModelName
  58. } else if originTask.Properties.UpstreamModelName != "" {
  59. info.OriginModelName = originTask.Properties.UpstreamModelName
  60. } else {
  61. var taskData map[string]interface{}
  62. _ = json.Unmarshal(originTask.Data, &taskData)
  63. if m, ok := taskData["model"].(string); ok && m != "" {
  64. info.OriginModelName = m
  65. platform = originTask.Platform
  66. }
  67. }
  68. }
  69. if originTask.ChannelId != info.ChannelId {
  70. channel, err := model.GetChannelById(originTask.ChannelId, true)
  71. if err != nil {
  72. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  73. return
  74. }
  75. if channel.Status != common.ChannelStatusEnabled {
  76. taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
  77. return
  78. }
  79. key, _, newAPIError := channel.GetNextEnabledKey()
  80. if newAPIError != nil {
  81. taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
  82. return
  83. }
  84. common.SetContextKey(c, constant.ContextKeyChannelKey, key)
  85. common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
  86. common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
  87. common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
  88. info.ChannelBaseUrl = channel.GetBaseURL()
  89. info.ChannelId = originTask.ChannelId
  90. info.ChannelType = channel.Type
  91. info.ApiKey = key
  92. platform = originTask.Platform
  93. }
  94. // 使用原始任务的参数
  95. if info.Action == constant.TaskActionRemix {
  96. var taskData map[string]interface{}
  97. _ = json.Unmarshal(originTask.Data, &taskData)
  98. secondsStr, _ := taskData["seconds"].(string)
  99. seconds, _ := strconv.Atoi(secondsStr)
  100. if seconds <= 0 {
  101. seconds = 4
  102. }
  103. sizeStr, _ := taskData["size"].(string)
  104. if info.PriceData.OtherRatios == nil {
  105. info.PriceData.OtherRatios = map[string]float64{}
  106. }
  107. info.PriceData.OtherRatios["seconds"] = float64(seconds)
  108. info.PriceData.OtherRatios["size"] = 1
  109. if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
  110. info.PriceData.OtherRatios["size"] = 1.666667
  111. }
  112. }
  113. }
  114. if platform == "" {
  115. platform = GetTaskPlatform(c)
  116. }
  117. info.InitChannelMeta(c)
  118. adaptor := GetTaskAdaptor(platform)
  119. if adaptor == nil {
  120. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  121. }
  122. adaptor.Init(info)
  123. // get & validate taskRequest 获取并验证文本请求
  124. taskErr = adaptor.ValidateRequestAndSetAction(c, info)
  125. if taskErr != nil {
  126. return
  127. }
  128. modelName := info.OriginModelName
  129. if modelName == "" {
  130. modelName = service.CoverTaskActionToModelName(platform, info.Action)
  131. }
  132. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  133. if !success {
  134. defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
  135. if !ok {
  136. modelPrice = 0.1
  137. } else {
  138. modelPrice = defaultPrice
  139. }
  140. }
  141. // 处理 auto 分组:从 context 获取实际选中的分组
  142. // 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
  143. if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
  144. if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
  145. info.UsingGroup = groupStr
  146. }
  147. }
  148. // 预扣
  149. groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
  150. var ratio float64
  151. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
  152. if hasUserGroupRatio {
  153. ratio = modelPrice * userGroupRatio
  154. } else {
  155. ratio = modelPrice * groupRatio
  156. }
  157. // FIXME: 临时修补,支持任务仅按次计费
  158. if !common.StringsContains(constant.TaskPricePatches, modelName) {
  159. if len(info.PriceData.OtherRatios) > 0 {
  160. for _, ra := range info.PriceData.OtherRatios {
  161. if 1.0 != ra {
  162. ratio *= ra
  163. }
  164. }
  165. }
  166. }
  167. println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
  168. userQuota, err := model.GetUserQuota(info.UserId, false)
  169. if err != nil {
  170. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  171. return
  172. }
  173. quota := int(ratio * common.QuotaPerUnit)
  174. if userQuota-quota < 0 {
  175. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  176. return
  177. }
  178. // build body
  179. requestBody, err := adaptor.BuildRequestBody(c, info)
  180. if err != nil {
  181. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  182. return
  183. }
  184. // do request
  185. resp, err := adaptor.DoRequest(c, info, requestBody)
  186. if err != nil {
  187. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  188. return
  189. }
  190. // handle response
  191. if resp != nil && resp.StatusCode != http.StatusOK {
  192. responseBody, _ := io.ReadAll(resp.Body)
  193. taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  194. return
  195. }
  196. defer func() {
  197. // release quota
  198. if info.ConsumeQuota && taskErr == nil {
  199. err := service.PostConsumeQuota(info, quota, 0, true)
  200. if err != nil {
  201. common.SysLog("error consuming token remain quota: " + err.Error())
  202. }
  203. if quota != 0 {
  204. tokenName := c.GetString("token_name")
  205. //gRatio := groupRatio
  206. //if hasUserGroupRatio {
  207. // gRatio = userGroupRatio
  208. //}
  209. logContent := fmt.Sprintf("操作 %s", info.Action)
  210. // FIXME: 临时修补,支持任务仅按次计费
  211. if common.StringsContains(constant.TaskPricePatches, modelName) {
  212. logContent = fmt.Sprintf("%s,按次计费", logContent)
  213. } else {
  214. if len(info.PriceData.OtherRatios) > 0 {
  215. var contents []string
  216. for key, ra := range info.PriceData.OtherRatios {
  217. if 1.0 != ra {
  218. contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
  219. }
  220. }
  221. if len(contents) > 0 {
  222. logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
  223. }
  224. }
  225. }
  226. other := make(map[string]interface{})
  227. if c != nil && c.Request != nil && c.Request.URL != nil {
  228. other["request_path"] = c.Request.URL.Path
  229. }
  230. other["model_price"] = modelPrice
  231. other["group_ratio"] = groupRatio
  232. if hasUserGroupRatio {
  233. other["user_group_ratio"] = userGroupRatio
  234. }
  235. model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
  236. ChannelId: info.ChannelId,
  237. ModelName: modelName,
  238. TokenName: tokenName,
  239. Quota: quota,
  240. Content: logContent,
  241. TokenId: info.TokenId,
  242. Group: info.UsingGroup,
  243. Other: other,
  244. })
  245. model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
  246. model.UpdateChannelUsedQuota(info.ChannelId, quota)
  247. }
  248. }
  249. }()
  250. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
  251. if taskErr != nil {
  252. return
  253. }
  254. info.ConsumeQuota = true
  255. // insert task
  256. task := model.InitTask(platform, info)
  257. task.TaskID = taskID
  258. task.Quota = quota
  259. task.Data = taskData
  260. task.Action = info.Action
  261. err = task.Insert()
  262. if err != nil {
  263. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  264. return
  265. }
  266. return nil
  267. }
  268. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  269. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  270. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  271. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  272. }
  273. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  274. respBuilder, ok := fetchRespBuilders[relayMode]
  275. if !ok {
  276. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  277. }
  278. respBody, taskErr := respBuilder(c)
  279. if taskErr != nil {
  280. return taskErr
  281. }
  282. if len(respBody) == 0 {
  283. respBody = []byte("{\"code\":\"success\",\"data\":null}")
  284. }
  285. c.Writer.Header().Set("Content-Type", "application/json")
  286. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  287. if err != nil {
  288. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  289. return
  290. }
  291. return
  292. }
  293. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  294. userId := c.GetInt("id")
  295. var condition = struct {
  296. IDs []any `json:"ids"`
  297. Action string `json:"action"`
  298. }{}
  299. err := c.BindJSON(&condition)
  300. if err != nil {
  301. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  302. return
  303. }
  304. var tasks []any
  305. if len(condition.IDs) > 0 {
  306. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  307. if err != nil {
  308. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  309. return
  310. }
  311. for _, task := range taskModels {
  312. tasks = append(tasks, TaskModel2Dto(task))
  313. }
  314. } else {
  315. tasks = make([]any, 0)
  316. }
  317. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  318. Code: "success",
  319. Data: tasks,
  320. })
  321. return
  322. }
  323. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  324. taskId := c.Param("id")
  325. userId := c.GetInt("id")
  326. originTask, exist, err := model.GetByTaskId(userId, taskId)
  327. if err != nil {
  328. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  329. return
  330. }
  331. if !exist {
  332. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  333. return
  334. }
  335. respBody, err = json.Marshal(dto.TaskResponse[any]{
  336. Code: "success",
  337. Data: TaskModel2Dto(originTask),
  338. })
  339. return
  340. }
  341. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  342. taskId := c.Param("task_id")
  343. if taskId == "" {
  344. taskId = c.GetString("task_id")
  345. }
  346. userId := c.GetInt("id")
  347. originTask, exist, err := model.GetByTaskId(userId, taskId)
  348. if err != nil {
  349. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  350. return
  351. }
  352. if !exist {
  353. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  354. return
  355. }
  356. func() {
  357. channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
  358. if err2 != nil {
  359. return
  360. }
  361. if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
  362. return
  363. }
  364. baseURL := constant.ChannelBaseURLs[channelModel.Type]
  365. if channelModel.GetBaseURL() != "" {
  366. baseURL = channelModel.GetBaseURL()
  367. }
  368. proxy := channelModel.GetSetting().Proxy
  369. adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
  370. if adaptor == nil {
  371. return
  372. }
  373. resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
  374. "task_id": originTask.TaskID,
  375. "action": originTask.Action,
  376. }, proxy)
  377. if err2 != nil || resp == nil {
  378. return
  379. }
  380. defer resp.Body.Close()
  381. body, err2 := io.ReadAll(resp.Body)
  382. if err2 != nil {
  383. return
  384. }
  385. ti, err2 := adaptor.ParseTaskResult(body)
  386. if err2 == nil && ti != nil {
  387. if ti.Status != "" {
  388. originTask.Status = model.TaskStatus(ti.Status)
  389. }
  390. if ti.Progress != "" {
  391. originTask.Progress = ti.Progress
  392. }
  393. if ti.Url != "" {
  394. if strings.HasPrefix(ti.Url, "data:") {
  395. } else {
  396. originTask.FailReason = ti.Url
  397. }
  398. }
  399. _ = originTask.Update()
  400. var raw map[string]any
  401. _ = json.Unmarshal(body, &raw)
  402. format := "mp4"
  403. if respObj, ok := raw["response"].(map[string]any); ok {
  404. if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
  405. if v0, ok := vids[0].(map[string]any); ok {
  406. if mt, ok := v0["mimeType"].(string); ok && mt != "" {
  407. if strings.Contains(mt, "mp4") {
  408. format = "mp4"
  409. } else {
  410. format = mt
  411. }
  412. }
  413. }
  414. }
  415. }
  416. status := "processing"
  417. switch originTask.Status {
  418. case model.TaskStatusSuccess:
  419. status = "succeeded"
  420. case model.TaskStatusFailure:
  421. status = "failed"
  422. case model.TaskStatusQueued, model.TaskStatusSubmitted:
  423. status = "queued"
  424. }
  425. if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
  426. out := map[string]any{
  427. "error": nil,
  428. "format": format,
  429. "metadata": nil,
  430. "status": status,
  431. "task_id": originTask.TaskID,
  432. "url": originTask.FailReason,
  433. }
  434. respBody, _ = json.Marshal(dto.TaskResponse[any]{
  435. Code: "success",
  436. Data: out,
  437. })
  438. }
  439. }
  440. }()
  441. if len(respBody) != 0 {
  442. return
  443. }
  444. if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
  445. adaptor := GetTaskAdaptor(originTask.Platform)
  446. if adaptor == nil {
  447. taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
  448. return
  449. }
  450. if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
  451. openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
  452. if err != nil {
  453. taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
  454. return
  455. }
  456. respBody = openAIVideoData
  457. return
  458. }
  459. taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
  460. return
  461. }
  462. respBody, err = json.Marshal(dto.TaskResponse[any]{
  463. Code: "success",
  464. Data: TaskModel2Dto(originTask),
  465. })
  466. if err != nil {
  467. taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
  468. }
  469. return
  470. }
  471. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  472. return &dto.TaskDto{
  473. TaskID: task.TaskID,
  474. Action: task.Action,
  475. Status: string(task.Status),
  476. FailReason: task.FailReason,
  477. SubmitTime: task.SubmitTime,
  478. StartTime: task.StartTime,
  479. FinishTime: task.FinishTime,
  480. Progress: task.Progress,
  481. Data: task.Data,
  482. }
  483. }