relay_task.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. "one-api/model"
  13. relaycommon "one-api/relay/common"
  14. relayconstant "one-api/relay/constant"
  15. "one-api/service"
  16. "one-api/setting/ratio_setting"
  17. "strconv"
  18. "strings"
  19. "github.com/gin-gonic/gin"
  20. )
  21. /*
  22. Task 任务通过平台、Action 区分任务
  23. */
  24. func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  25. info.InitChannelMeta(c)
  26. // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
  27. if info.TaskRelayInfo == nil {
  28. info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
  29. }
  30. platform := constant.TaskPlatform(c.GetString("platform"))
  31. if platform == "" {
  32. platform = GetTaskPlatform(c)
  33. }
  34. info.InitChannelMeta(c)
  35. adaptor := GetTaskAdaptor(platform)
  36. if adaptor == nil {
  37. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  38. }
  39. adaptor.Init(info)
  40. // get & validate taskRequest 获取并验证文本请求
  41. taskErr = adaptor.ValidateRequestAndSetAction(c, info)
  42. if taskErr != nil {
  43. return
  44. }
  45. modelName := info.OriginModelName
  46. if modelName == "" {
  47. modelName = service.CoverTaskActionToModelName(platform, info.Action)
  48. }
  49. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  50. if !success {
  51. defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
  52. if !ok {
  53. modelPrice = 0.1
  54. } else {
  55. modelPrice = defaultPrice
  56. }
  57. }
  58. // 预扣
  59. groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
  60. var ratio float64
  61. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
  62. if hasUserGroupRatio {
  63. ratio = modelPrice * userGroupRatio
  64. } else {
  65. ratio = modelPrice * groupRatio
  66. }
  67. userQuota, err := model.GetUserQuota(info.UserId, false)
  68. if err != nil {
  69. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  70. return
  71. }
  72. quota := int(ratio * common.QuotaPerUnit)
  73. if userQuota-quota < 0 {
  74. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  75. return
  76. }
  77. if info.OriginTaskID != "" {
  78. originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
  79. if err != nil {
  80. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  81. return
  82. }
  83. if !exist {
  84. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  85. return
  86. }
  87. if originTask.ChannelId != info.ChannelId {
  88. channel, err := model.GetChannelById(originTask.ChannelId, true)
  89. if err != nil {
  90. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  91. return
  92. }
  93. if channel.Status != common.ChannelStatusEnabled {
  94. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  95. }
  96. c.Set("base_url", channel.GetBaseURL())
  97. c.Set("channel_id", originTask.ChannelId)
  98. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  99. info.ChannelBaseUrl = channel.GetBaseURL()
  100. info.ChannelId = originTask.ChannelId
  101. }
  102. }
  103. // build body
  104. requestBody, err := adaptor.BuildRequestBody(c, info)
  105. if err != nil {
  106. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  107. return
  108. }
  109. // do request
  110. resp, err := adaptor.DoRequest(c, info, requestBody)
  111. if err != nil {
  112. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  113. return
  114. }
  115. // handle response
  116. if resp != nil && resp.StatusCode != http.StatusOK {
  117. responseBody, _ := io.ReadAll(resp.Body)
  118. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  119. return
  120. }
  121. defer func() {
  122. // release quota
  123. if info.ConsumeQuota && taskErr == nil {
  124. err := service.PostConsumeQuota(info, quota, 0, true)
  125. if err != nil {
  126. common.SysLog("error consuming token remain quota: " + err.Error())
  127. }
  128. if quota != 0 {
  129. tokenName := c.GetString("token_name")
  130. gRatio := groupRatio
  131. if hasUserGroupRatio {
  132. gRatio = userGroupRatio
  133. }
  134. logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action)
  135. other := make(map[string]interface{})
  136. other["model_price"] = modelPrice
  137. other["group_ratio"] = groupRatio
  138. if hasUserGroupRatio {
  139. other["user_group_ratio"] = userGroupRatio
  140. }
  141. model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
  142. ChannelId: info.ChannelId,
  143. ModelName: modelName,
  144. TokenName: tokenName,
  145. Quota: quota,
  146. Content: logContent,
  147. TokenId: info.TokenId,
  148. Group: info.UsingGroup,
  149. Other: other,
  150. })
  151. model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
  152. model.UpdateChannelUsedQuota(info.ChannelId, quota)
  153. }
  154. }
  155. }()
  156. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
  157. if taskErr != nil {
  158. return
  159. }
  160. info.ConsumeQuota = true
  161. // insert task
  162. task := model.InitTask(platform, info)
  163. task.TaskID = taskID
  164. task.Quota = quota
  165. task.Data = taskData
  166. task.Action = info.Action
  167. err = task.Insert()
  168. if err != nil {
  169. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  170. return
  171. }
  172. return nil
  173. }
  174. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  175. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  176. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  177. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  178. }
  179. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  180. respBuilder, ok := fetchRespBuilders[relayMode]
  181. if !ok {
  182. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  183. }
  184. respBody, taskErr := respBuilder(c)
  185. if taskErr != nil {
  186. return taskErr
  187. }
  188. if len(respBody) == 0 {
  189. respBody = []byte("{\"code\":\"success\",\"data\":null}")
  190. }
  191. c.Writer.Header().Set("Content-Type", "application/json")
  192. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  193. if err != nil {
  194. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  195. return
  196. }
  197. return
  198. }
  199. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  200. userId := c.GetInt("id")
  201. var condition = struct {
  202. IDs []any `json:"ids"`
  203. Action string `json:"action"`
  204. }{}
  205. err := c.BindJSON(&condition)
  206. if err != nil {
  207. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  208. return
  209. }
  210. var tasks []any
  211. if len(condition.IDs) > 0 {
  212. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  213. if err != nil {
  214. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  215. return
  216. }
  217. for _, task := range taskModels {
  218. tasks = append(tasks, TaskModel2Dto(task))
  219. }
  220. } else {
  221. tasks = make([]any, 0)
  222. }
  223. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  224. Code: "success",
  225. Data: tasks,
  226. })
  227. return
  228. }
  229. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  230. taskId := c.Param("id")
  231. userId := c.GetInt("id")
  232. originTask, exist, err := model.GetByTaskId(userId, taskId)
  233. if err != nil {
  234. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  235. return
  236. }
  237. if !exist {
  238. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  239. return
  240. }
  241. respBody, err = json.Marshal(dto.TaskResponse[any]{
  242. Code: "success",
  243. Data: TaskModel2Dto(originTask),
  244. })
  245. return
  246. }
  247. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  248. taskId := c.Param("task_id")
  249. if taskId == "" {
  250. taskId = c.GetString("task_id")
  251. }
  252. userId := c.GetInt("id")
  253. originTask, exist, err := model.GetByTaskId(userId, taskId)
  254. if err != nil {
  255. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  256. return
  257. }
  258. if !exist {
  259. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  260. return
  261. }
  262. func() {
  263. channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
  264. if err2 != nil {
  265. return
  266. }
  267. if channelModel.Type != constant.ChannelTypeVertexAi {
  268. return
  269. }
  270. baseURL := constant.ChannelBaseURLs[channelModel.Type]
  271. if channelModel.GetBaseURL() != "" {
  272. baseURL = channelModel.GetBaseURL()
  273. }
  274. adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
  275. if adaptor == nil {
  276. return
  277. }
  278. resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
  279. "task_id": originTask.TaskID,
  280. "action": originTask.Action,
  281. })
  282. if err2 != nil || resp == nil {
  283. return
  284. }
  285. defer resp.Body.Close()
  286. body, err2 := io.ReadAll(resp.Body)
  287. if err2 != nil {
  288. return
  289. }
  290. ti, err2 := adaptor.ParseTaskResult(body)
  291. if err2 == nil && ti != nil {
  292. if ti.Status != "" {
  293. originTask.Status = model.TaskStatus(ti.Status)
  294. }
  295. if ti.Progress != "" {
  296. originTask.Progress = ti.Progress
  297. }
  298. if ti.Url != "" {
  299. originTask.FailReason = ti.Url
  300. }
  301. _ = originTask.Update()
  302. var raw map[string]any
  303. _ = json.Unmarshal(body, &raw)
  304. format := "mp4"
  305. if respObj, ok := raw["response"].(map[string]any); ok {
  306. if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
  307. if v0, ok := vids[0].(map[string]any); ok {
  308. if mt, ok := v0["mimeType"].(string); ok && mt != "" {
  309. if strings.Contains(mt, "mp4") {
  310. format = "mp4"
  311. } else {
  312. format = mt
  313. }
  314. }
  315. }
  316. }
  317. }
  318. status := "processing"
  319. switch originTask.Status {
  320. case model.TaskStatusSuccess:
  321. status = "succeeded"
  322. case model.TaskStatusFailure:
  323. status = "failed"
  324. case model.TaskStatusQueued, model.TaskStatusSubmitted:
  325. status = "queued"
  326. }
  327. out := map[string]any{
  328. "error": nil,
  329. "format": format,
  330. "metadata": nil,
  331. "status": status,
  332. "task_id": originTask.TaskID,
  333. "url": originTask.FailReason,
  334. }
  335. respBody, _ = json.Marshal(dto.TaskResponse[any]{
  336. Code: "success",
  337. Data: out,
  338. })
  339. }
  340. }()
  341. if len(respBody) == 0 {
  342. respBody, err = json.Marshal(dto.TaskResponse[any]{
  343. Code: "success",
  344. Data: TaskModel2Dto(originTask),
  345. })
  346. }
  347. return
  348. }
  349. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  350. return &dto.TaskDto{
  351. TaskID: task.TaskID,
  352. Action: task.Action,
  353. Status: string(task.Status),
  354. FailReason: task.FailReason,
  355. SubmitTime: task.SubmitTime,
  356. StartTime: task.StartTime,
  357. FinishTime: task.FinishTime,
  358. Progress: task.Progress,
  359. Data: task.Data,
  360. }
  361. }