task.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. package model
import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"time"

	"one-api/constant"
	commonRelay "one-api/relay/common"
)
  9. type TaskStatus string
  10. const (
  11. TaskStatusNotStart TaskStatus = "NOT_START"
  12. TaskStatusSubmitted = "SUBMITTED"
  13. TaskStatusQueued = "QUEUED"
  14. TaskStatusInProgress = "IN_PROGRESS"
  15. TaskStatusFailure = "FAILURE"
  16. TaskStatusSuccess = "SUCCESS"
  17. TaskStatusUnknown = "UNKNOWN"
  18. )
  19. type Task struct {
  20. ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
  21. CreatedAt int64 `json:"created_at" gorm:"index"`
  22. UpdatedAt int64 `json:"updated_at"`
  23. TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id,不一定有/ song id\ Task id
  24. Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
  25. UserId int `json:"user_id" gorm:"index"`
  26. ChannelId int `json:"channel_id" gorm:"index"`
  27. Quota int `json:"quota"`
  28. Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
  29. Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态
  30. FailReason string `json:"fail_reason"`
  31. SubmitTime int64 `json:"submit_time" gorm:"index"`
  32. StartTime int64 `json:"start_time" gorm:"index"`
  33. FinishTime int64 `json:"finish_time" gorm:"index"`
  34. Progress string `json:"progress" gorm:"type:varchar(20);index"`
  35. Properties Properties `json:"properties" gorm:"type:json"`
  36. Data json.RawMessage `json:"data" gorm:"type:json"`
  37. }
  38. func (t *Task) SetData(data any) {
  39. b, _ := json.Marshal(data)
  40. t.Data = json.RawMessage(b)
  41. }
  42. func (t *Task) GetData(v any) error {
  43. err := json.Unmarshal(t.Data, &v)
  44. return err
  45. }
  46. type Properties struct {
  47. Input string `json:"input"`
  48. }
  49. func (m *Properties) Scan(val interface{}) error {
  50. bytesValue, _ := val.([]byte)
  51. return json.Unmarshal(bytesValue, m)
  52. }
  53. func (m Properties) Value() (driver.Value, error) {
  54. return json.Marshal(m)
  55. }
  56. // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
  57. type SyncTaskQueryParams struct {
  58. Platform constant.TaskPlatform
  59. ChannelID string
  60. TaskID string
  61. UserID string
  62. Action string
  63. Status string
  64. StartTimestamp int64
  65. EndTimestamp int64
  66. UserIDs []int
  67. }
  68. func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
  69. t := &Task{
  70. UserId: relayInfo.UserId,
  71. SubmitTime: time.Now().Unix(),
  72. Status: TaskStatusNotStart,
  73. Progress: "0%",
  74. ChannelId: relayInfo.ChannelId,
  75. Platform: platform,
  76. }
  77. return t
  78. }
  79. func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
  80. var tasks []*Task
  81. var err error
  82. // 初始化查询构建器
  83. query := DB.Where("user_id = ?", userId)
  84. if queryParams.TaskID != "" {
  85. query = query.Where("task_id = ?", queryParams.TaskID)
  86. }
  87. if queryParams.Action != "" {
  88. query = query.Where("action = ?", queryParams.Action)
  89. }
  90. if queryParams.Status != "" {
  91. query = query.Where("status = ?", queryParams.Status)
  92. }
  93. if queryParams.Platform != "" {
  94. query = query.Where("platform = ?", queryParams.Platform)
  95. }
  96. if queryParams.StartTimestamp != 0 {
  97. // 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
  98. query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
  99. }
  100. if queryParams.EndTimestamp != 0 {
  101. query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
  102. }
  103. // 获取数据
  104. err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
  105. if err != nil {
  106. return nil
  107. }
  108. return tasks
  109. }
  110. func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
  111. var tasks []*Task
  112. var err error
  113. // 初始化查询构建器
  114. query := DB
  115. // 添加过滤条件
  116. if queryParams.ChannelID != "" {
  117. query = query.Where("channel_id = ?", queryParams.ChannelID)
  118. }
  119. if queryParams.Platform != "" {
  120. query = query.Where("platform = ?", queryParams.Platform)
  121. }
  122. if queryParams.UserID != "" {
  123. query = query.Where("user_id = ?", queryParams.UserID)
  124. }
  125. if len(queryParams.UserIDs) != 0 {
  126. query = query.Where("user_id in (?)", queryParams.UserIDs)
  127. }
  128. if queryParams.TaskID != "" {
  129. query = query.Where("task_id = ?", queryParams.TaskID)
  130. }
  131. if queryParams.Action != "" {
  132. query = query.Where("action = ?", queryParams.Action)
  133. }
  134. if queryParams.Status != "" {
  135. query = query.Where("status = ?", queryParams.Status)
  136. }
  137. if queryParams.StartTimestamp != 0 {
  138. query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
  139. }
  140. if queryParams.EndTimestamp != 0 {
  141. query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
  142. }
  143. // 获取数据
  144. err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
  145. if err != nil {
  146. return nil
  147. }
  148. return tasks
  149. }
  150. func GetAllUnFinishSyncTasks(limit int) []*Task {
  151. var tasks []*Task
  152. var err error
  153. // get all tasks progress is not 100%
  154. err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
  155. if err != nil {
  156. return nil
  157. }
  158. return tasks
  159. }
  160. func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
  161. if taskId == "" {
  162. return nil, false, nil
  163. }
  164. var task *Task
  165. var err error
  166. err = DB.Where("task_id = ?", taskId).First(&task).Error
  167. exist, err := RecordExist(err)
  168. if err != nil {
  169. return nil, false, err
  170. }
  171. return task, exist, err
  172. }
  173. func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
  174. if taskId == "" {
  175. return nil, false, nil
  176. }
  177. var task *Task
  178. var err error
  179. err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
  180. First(&task).Error
  181. exist, err := RecordExist(err)
  182. if err != nil {
  183. return nil, false, err
  184. }
  185. return task, exist, err
  186. }
  187. func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
  188. if len(taskIds) == 0 {
  189. return nil, nil
  190. }
  191. var task []*Task
  192. var err error
  193. err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
  194. Find(&task).Error
  195. if err != nil {
  196. return nil, err
  197. }
  198. return task, nil
  199. }
  200. func TaskUpdateProgress(id int64, progress string) error {
  201. return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
  202. }
  203. func (Task *Task) Insert() error {
  204. var err error
  205. err = DB.Create(Task).Error
  206. return err
  207. }
  208. func (Task *Task) Update() error {
  209. var err error
  210. err = DB.Save(Task).Error
  211. return err
  212. }
  213. func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
  214. if len(TaskIds) == 0 {
  215. return nil
  216. }
  217. return DB.Model(&Task{}).
  218. Where("task_id in (?)", TaskIds).
  219. Updates(params).Error
  220. }
  221. func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
  222. if len(taskIDs) == 0 {
  223. return nil
  224. }
  225. return DB.Model(&Task{}).
  226. Where("id in (?)", taskIDs).
  227. Updates(params).Error
  228. }
  229. func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
  230. if len(ids) == 0 {
  231. return nil
  232. }
  233. return DB.Model(&Task{}).
  234. Where("id in (?)", ids).
  235. Updates(params).Error
  236. }
  237. type TaskQuotaUsage struct {
  238. Mode string `json:"mode"`
  239. Count float64 `json:"count"`
  240. }
  241. func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
  242. query := DB.Model(Task{})
  243. // 添加过滤条件
  244. if queryParams.ChannelID != "" {
  245. query = query.Where("channel_id = ?", queryParams.ChannelID)
  246. }
  247. if queryParams.UserID != "" {
  248. query = query.Where("user_id = ?", queryParams.UserID)
  249. }
  250. if len(queryParams.UserIDs) != 0 {
  251. query = query.Where("user_id in (?)", queryParams.UserIDs)
  252. }
  253. if queryParams.TaskID != "" {
  254. query = query.Where("task_id = ?", queryParams.TaskID)
  255. }
  256. if queryParams.Action != "" {
  257. query = query.Where("action = ?", queryParams.Action)
  258. }
  259. if queryParams.Status != "" {
  260. query = query.Where("status = ?", queryParams.Status)
  261. }
  262. if queryParams.StartTimestamp != 0 {
  263. query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
  264. }
  265. if queryParams.EndTimestamp != 0 {
  266. query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
  267. }
  268. err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
  269. return stat, err
  270. }
  271. // TaskCountAllTasks returns total tasks that match the given query params (admin usage)
  272. func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
  273. var total int64
  274. query := DB.Model(&Task{})
  275. if queryParams.ChannelID != "" {
  276. query = query.Where("channel_id = ?", queryParams.ChannelID)
  277. }
  278. if queryParams.Platform != "" {
  279. query = query.Where("platform = ?", queryParams.Platform)
  280. }
  281. if queryParams.UserID != "" {
  282. query = query.Where("user_id = ?", queryParams.UserID)
  283. }
  284. if len(queryParams.UserIDs) != 0 {
  285. query = query.Where("user_id in (?)", queryParams.UserIDs)
  286. }
  287. if queryParams.TaskID != "" {
  288. query = query.Where("task_id = ?", queryParams.TaskID)
  289. }
  290. if queryParams.Action != "" {
  291. query = query.Where("action = ?", queryParams.Action)
  292. }
  293. if queryParams.Status != "" {
  294. query = query.Where("status = ?", queryParams.Status)
  295. }
  296. if queryParams.StartTimestamp != 0 {
  297. query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
  298. }
  299. if queryParams.EndTimestamp != 0 {
  300. query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
  301. }
  302. _ = query.Count(&total).Error
  303. return total
  304. }
  305. // TaskCountAllUserTask returns total tasks for given user
  306. func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
  307. var total int64
  308. query := DB.Model(&Task{}).Where("user_id = ?", userId)
  309. if queryParams.TaskID != "" {
  310. query = query.Where("task_id = ?", queryParams.TaskID)
  311. }
  312. if queryParams.Action != "" {
  313. query = query.Where("action = ?", queryParams.Action)
  314. }
  315. if queryParams.Status != "" {
  316. query = query.Where("status = ?", queryParams.Status)
  317. }
  318. if queryParams.Platform != "" {
  319. query = query.Where("platform = ?", queryParams.Platform)
  320. }
  321. if queryParams.StartTimestamp != 0 {
  322. query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
  323. }
  324. if queryParams.EndTimestamp != 0 {
  325. query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
  326. }
  327. _ = query.Count(&total).Error
  328. return total
  329. }