relay_task.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strconv"
  10. "strings"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/model"
  15. "github.com/QuantumNous/new-api/relay/channel"
  16. relaycommon "github.com/QuantumNous/new-api/relay/common"
  17. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  18. "github.com/QuantumNous/new-api/service"
  19. "github.com/QuantumNous/new-api/setting/ratio_setting"
  20. "github.com/gin-gonic/gin"
  21. )
  22. /*
  23. Task 任务通过平台、Action 区分任务
  24. */
  25. func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  26. info.InitChannelMeta(c)
  27. // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
  28. if info.TaskRelayInfo == nil {
  29. info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
  30. }
  31. platform := constant.TaskPlatform(c.GetString("platform"))
  32. if platform == "" {
  33. platform = GetTaskPlatform(c)
  34. }
  35. info.InitChannelMeta(c)
  36. adaptor := GetTaskAdaptor(platform)
  37. if adaptor == nil {
  38. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  39. }
  40. adaptor.Init(info)
  41. // get & validate taskRequest 获取并验证文本请求
  42. taskErr = adaptor.ValidateRequestAndSetAction(c, info)
  43. if taskErr != nil {
  44. return
  45. }
  46. modelName := info.OriginModelName
  47. if modelName == "" {
  48. modelName = service.CoverTaskActionToModelName(platform, info.Action)
  49. }
  50. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  51. if !success {
  52. defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
  53. if !ok {
  54. modelPrice = 0.1
  55. } else {
  56. modelPrice = defaultPrice
  57. }
  58. }
  59. // 预扣
  60. groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
  61. var ratio float64
  62. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
  63. if hasUserGroupRatio {
  64. ratio = modelPrice * userGroupRatio
  65. } else {
  66. ratio = modelPrice * groupRatio
  67. }
  68. // FIXME: 临时修补,支持任务仅按次计费
  69. if !common.StringsContains(constant.TaskPricePatches, modelName) {
  70. if len(info.PriceData.OtherRatios) > 0 {
  71. for _, ra := range info.PriceData.OtherRatios {
  72. if 1.0 != ra {
  73. ratio *= ra
  74. }
  75. }
  76. }
  77. }
  78. println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
  79. userQuota, err := model.GetUserQuota(info.UserId, false)
  80. if err != nil {
  81. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  82. return
  83. }
  84. quota := int(ratio * common.QuotaPerUnit)
  85. if userQuota-quota < 0 {
  86. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  87. return
  88. }
  89. if info.OriginTaskID != "" {
  90. originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
  91. if err != nil {
  92. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  93. return
  94. }
  95. if !exist {
  96. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  97. return
  98. }
  99. if originTask.ChannelId != info.ChannelId {
  100. channel, err := model.GetChannelById(originTask.ChannelId, true)
  101. if err != nil {
  102. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  103. return
  104. }
  105. if channel.Status != common.ChannelStatusEnabled {
  106. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  107. }
  108. c.Set("base_url", channel.GetBaseURL())
  109. c.Set("channel_id", originTask.ChannelId)
  110. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  111. info.ChannelBaseUrl = channel.GetBaseURL()
  112. info.ChannelId = originTask.ChannelId
  113. }
  114. }
  115. // build body
  116. requestBody, err := adaptor.BuildRequestBody(c, info)
  117. if err != nil {
  118. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  119. return
  120. }
  121. // do request
  122. resp, err := adaptor.DoRequest(c, info, requestBody)
  123. if err != nil {
  124. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  125. return
  126. }
  127. // handle response
  128. if resp != nil && resp.StatusCode != http.StatusOK {
  129. responseBody, _ := io.ReadAll(resp.Body)
  130. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  131. return
  132. }
  133. defer func() {
  134. // release quota
  135. if info.ConsumeQuota && taskErr == nil {
  136. err := service.PostConsumeQuota(info, quota, 0, true)
  137. if err != nil {
  138. common.SysLog("error consuming token remain quota: " + err.Error())
  139. }
  140. if quota != 0 {
  141. tokenName := c.GetString("token_name")
  142. //gRatio := groupRatio
  143. //if hasUserGroupRatio {
  144. // gRatio = userGroupRatio
  145. //}
  146. logContent := fmt.Sprintf("操作 %s", info.Action)
  147. // FIXME: 临时修补,支持任务仅按次计费
  148. if common.StringsContains(constant.TaskPricePatches, modelName) {
  149. logContent = fmt.Sprintf("%s,按次计费", logContent)
  150. } else {
  151. if len(info.PriceData.OtherRatios) > 0 {
  152. var contents []string
  153. for key, ra := range info.PriceData.OtherRatios {
  154. if 1.0 != ra {
  155. contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
  156. }
  157. }
  158. if len(contents) > 0 {
  159. logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
  160. }
  161. }
  162. }
  163. other := make(map[string]interface{})
  164. if c != nil && c.Request != nil && c.Request.URL != nil {
  165. other["request_path"] = c.Request.URL.Path
  166. }
  167. other["model_price"] = modelPrice
  168. other["group_ratio"] = groupRatio
  169. if hasUserGroupRatio {
  170. other["user_group_ratio"] = userGroupRatio
  171. }
  172. model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
  173. ChannelId: info.ChannelId,
  174. ModelName: modelName,
  175. TokenName: tokenName,
  176. Quota: quota,
  177. Content: logContent,
  178. TokenId: info.TokenId,
  179. Group: info.UsingGroup,
  180. Other: other,
  181. })
  182. model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
  183. model.UpdateChannelUsedQuota(info.ChannelId, quota)
  184. }
  185. }
  186. }()
  187. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
  188. if taskErr != nil {
  189. return
  190. }
  191. info.ConsumeQuota = true
  192. // insert task
  193. task := model.InitTask(platform, info)
  194. task.TaskID = taskID
  195. task.Quota = quota
  196. task.Data = taskData
  197. task.Action = info.Action
  198. err = task.Insert()
  199. if err != nil {
  200. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  201. return
  202. }
  203. return nil
  204. }
  205. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  206. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  207. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  208. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  209. }
  210. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  211. respBuilder, ok := fetchRespBuilders[relayMode]
  212. if !ok {
  213. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  214. }
  215. respBody, taskErr := respBuilder(c)
  216. if taskErr != nil {
  217. return taskErr
  218. }
  219. if len(respBody) == 0 {
  220. respBody = []byte("{\"code\":\"success\",\"data\":null}")
  221. }
  222. c.Writer.Header().Set("Content-Type", "application/json")
  223. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  224. if err != nil {
  225. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  226. return
  227. }
  228. return
  229. }
  230. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  231. userId := c.GetInt("id")
  232. var condition = struct {
  233. IDs []any `json:"ids"`
  234. Action string `json:"action"`
  235. }{}
  236. err := c.BindJSON(&condition)
  237. if err != nil {
  238. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  239. return
  240. }
  241. var tasks []any
  242. if len(condition.IDs) > 0 {
  243. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  244. if err != nil {
  245. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  246. return
  247. }
  248. for _, task := range taskModels {
  249. tasks = append(tasks, TaskModel2Dto(task))
  250. }
  251. } else {
  252. tasks = make([]any, 0)
  253. }
  254. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  255. Code: "success",
  256. Data: tasks,
  257. })
  258. return
  259. }
  260. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  261. taskId := c.Param("id")
  262. userId := c.GetInt("id")
  263. originTask, exist, err := model.GetByTaskId(userId, taskId)
  264. if err != nil {
  265. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  266. return
  267. }
  268. if !exist {
  269. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  270. return
  271. }
  272. respBody, err = json.Marshal(dto.TaskResponse[any]{
  273. Code: "success",
  274. Data: TaskModel2Dto(originTask),
  275. })
  276. return
  277. }
  278. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  279. taskId := c.Param("task_id")
  280. if taskId == "" {
  281. taskId = c.GetString("task_id")
  282. }
  283. userId := c.GetInt("id")
  284. originTask, exist, err := model.GetByTaskId(userId, taskId)
  285. if err != nil {
  286. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  287. return
  288. }
  289. if !exist {
  290. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  291. return
  292. }
  293. func() {
  294. channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
  295. if err2 != nil {
  296. return
  297. }
  298. if channelModel.Type != constant.ChannelTypeVertexAi {
  299. return
  300. }
  301. baseURL := constant.ChannelBaseURLs[channelModel.Type]
  302. if channelModel.GetBaseURL() != "" {
  303. baseURL = channelModel.GetBaseURL()
  304. }
  305. adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
  306. if adaptor == nil {
  307. return
  308. }
  309. resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
  310. "task_id": originTask.TaskID,
  311. "action": originTask.Action,
  312. })
  313. if err2 != nil || resp == nil {
  314. return
  315. }
  316. defer resp.Body.Close()
  317. body, err2 := io.ReadAll(resp.Body)
  318. if err2 != nil {
  319. return
  320. }
  321. ti, err2 := adaptor.ParseTaskResult(body)
  322. if err2 == nil && ti != nil {
  323. if ti.Status != "" {
  324. originTask.Status = model.TaskStatus(ti.Status)
  325. }
  326. if ti.Progress != "" {
  327. originTask.Progress = ti.Progress
  328. }
  329. if ti.Url != "" {
  330. originTask.FailReason = ti.Url
  331. }
  332. _ = originTask.Update()
  333. var raw map[string]any
  334. _ = json.Unmarshal(body, &raw)
  335. format := "mp4"
  336. if respObj, ok := raw["response"].(map[string]any); ok {
  337. if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
  338. if v0, ok := vids[0].(map[string]any); ok {
  339. if mt, ok := v0["mimeType"].(string); ok && mt != "" {
  340. if strings.Contains(mt, "mp4") {
  341. format = "mp4"
  342. } else {
  343. format = mt
  344. }
  345. }
  346. }
  347. }
  348. }
  349. status := "processing"
  350. switch originTask.Status {
  351. case model.TaskStatusSuccess:
  352. status = "succeeded"
  353. case model.TaskStatusFailure:
  354. status = "failed"
  355. case model.TaskStatusQueued, model.TaskStatusSubmitted:
  356. status = "queued"
  357. }
  358. out := map[string]any{
  359. "error": nil,
  360. "format": format,
  361. "metadata": nil,
  362. "status": status,
  363. "task_id": originTask.TaskID,
  364. "url": originTask.FailReason,
  365. }
  366. respBody, _ = json.Marshal(dto.TaskResponse[any]{
  367. Code: "success",
  368. Data: out,
  369. })
  370. }
  371. }()
  372. if len(respBody) != 0 {
  373. return
  374. }
  375. if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
  376. adaptor := GetTaskAdaptor(originTask.Platform)
  377. if adaptor == nil {
  378. taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
  379. return
  380. }
  381. if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
  382. openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
  383. if err != nil {
  384. taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
  385. return
  386. }
  387. respBody = openAIVideoData
  388. return
  389. }
  390. taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
  391. return
  392. }
  393. respBody, err = json.Marshal(dto.TaskResponse[any]{
  394. Code: "success",
  395. Data: TaskModel2Dto(originTask),
  396. })
  397. if err != nil {
  398. taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
  399. }
  400. return
  401. }
  402. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  403. return &dto.TaskDto{
  404. TaskID: task.TaskID,
  405. Action: task.Action,
  406. Status: string(task.Status),
  407. FailReason: task.FailReason,
  408. SubmitTime: task.SubmitTime,
  409. StartTime: task.StartTime,
  410. FinishTime: task.FinishTime,
  411. Progress: task.Progress,
  412. Data: task.Data,
  413. }
  414. }