| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- package controller
- import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/model"
- "one-api/service"
- "one-api/setting"
- "time"
- "github.com/gin-gonic/gin"
- )
- func UpdateMidjourneyTaskBulk() {
- //imageModel := "midjourney"
- ctx := context.TODO()
- for {
- time.Sleep(time.Duration(15) * time.Second)
- tasks := model.GetAllUnFinishTasks()
- if len(tasks) == 0 {
- continue
- }
- common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
- taskChannelM := make(map[int][]string)
- taskM := make(map[string]*model.Midjourney)
- nullTaskIds := make([]int, 0)
- for _, task := range tasks {
- if task.MjId == "" {
- // 统计失败的未完成任务
- nullTaskIds = append(nullTaskIds, task.Id)
- continue
- }
- taskM[task.MjId] = task
- taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
- }
- if len(nullTaskIds) > 0 {
- err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
- } else {
- common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
- }
- }
- if len(taskChannelM) == 0 {
- continue
- }
- for channelId, taskIds := range taskChannelM {
- common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- continue
- }
- midjourneyChannel, err := model.CacheGetChannel(channelId)
- if err != nil {
- common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
- err := model.MjBulkUpdate(taskIds, map[string]any{
- "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
- }
- continue
- }
- requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
- body, _ := json.Marshal(map[string]any{
- "ids": taskIds,
- })
- req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
- if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
- continue
- }
- // 设置超时时间
- timeout := time.Second * 15
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("mj-api-secret", midjourneyChannel.Key)
- resp, err := service.GetHttpClient().Do(req)
- if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
- continue
- }
- if resp.StatusCode != http.StatusOK {
- common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
- continue
- }
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
- continue
- }
- var responseItems []dto.MidjourneyDto
- err = json.Unmarshal(responseBody, &responseItems)
- if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
- continue
- }
- resp.Body.Close()
- req.Body.Close()
- cancel()
- for _, responseItem := range responseItems {
- task := taskM[responseItem.MjId]
- useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
- // 如果时间超过一小时,且进度不是100%,则认为任务失败
- if useTime > 3600000 && task.Progress != "100%" {
- responseItem.FailReason = "上游任务超时(超过1小时)"
- responseItem.Status = "FAILURE"
- }
- if !checkMjTaskNeedUpdate(task, responseItem) {
- continue
- }
- task.Code = 1
- task.Progress = responseItem.Progress
- task.PromptEn = responseItem.PromptEn
- task.State = responseItem.State
- task.SubmitTime = responseItem.SubmitTime
- task.StartTime = responseItem.StartTime
- task.FinishTime = responseItem.FinishTime
- task.ImageUrl = responseItem.ImageUrl
- task.Status = responseItem.Status
- task.FailReason = responseItem.FailReason
- if responseItem.Properties != nil {
- propertiesStr, _ := json.Marshal(responseItem.Properties)
- task.Properties = string(propertiesStr)
- }
- if responseItem.Buttons != nil {
- buttonStr, _ := json.Marshal(responseItem.Buttons)
- task.Buttons = string(buttonStr)
- }
- shouldReturnQuota := false
- if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
- common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
- task.Progress = "100%"
- if task.Quota != 0 {
- shouldReturnQuota = true
- }
- }
- err = task.Update()
- if err != nil {
- common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
- } else {
- if shouldReturnQuota {
- err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
- if err != nil {
- common.LogError(ctx, "fail to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- }
- }
- }
- }
- }
- func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
- if oldTask.Code != 1 {
- return true
- }
- if oldTask.Progress != newTask.Progress {
- return true
- }
- if oldTask.PromptEn != newTask.PromptEn {
- return true
- }
- if oldTask.State != newTask.State {
- return true
- }
- if oldTask.SubmitTime != newTask.SubmitTime {
- return true
- }
- if oldTask.StartTime != newTask.StartTime {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
- if oldTask.ImageUrl != newTask.ImageUrl {
- return true
- }
- if oldTask.Status != newTask.Status {
- return true
- }
- if oldTask.FailReason != newTask.FailReason {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
- if oldTask.Progress != "100%" && newTask.FailReason != "" {
- return true
- }
- return false
- }
- func GetAllMidjourney(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- // 解析其他查询参数
- queryParams := model.TaskQueryParams{
- ChannelID: c.Query("channel_id"),
- MjID: c.Query("mj_id"),
- StartTimestamp: c.Query("start_timestamp"),
- EndTimestamp: c.Query("end_timestamp"),
- }
- items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
- total := model.CountAllTasks(queryParams)
- if setting.MjForwardUrlEnabled {
- for i, midjourney := range items {
- midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
- items[i] = midjourney
- }
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(items)
- common.ApiSuccess(c, pageInfo)
- }
- func GetUserMidjourney(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- userId := c.GetInt("id")
- queryParams := model.TaskQueryParams{
- MjID: c.Query("mj_id"),
- StartTimestamp: c.Query("start_timestamp"),
- EndTimestamp: c.Query("end_timestamp"),
- }
- items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
- total := model.CountAllUserTask(userId, queryParams)
- if setting.MjForwardUrlEnabled {
- for i, midjourney := range items {
- midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
- items[i] = midjourney
- }
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(items)
- common.ApiSuccess(c, pageInfo)
- }
|