Xyfacai 2 лет назад
Родитель
Сommit
fd4ef086dc
4 измененных файлов с 212 добавлено и 11 удалено
  1. 187 10
      controller/midjourney.go
  2. 1 1
      main.go
  3. 18 0
      model/cache.go
  4. 6 0
      model/midjourney.go

+ 187 - 10
controller/midjourney.go

@@ -16,8 +16,10 @@ import (
 	"time"
 )
 
-func UpdateMidjourneyTask() {
+/*func UpdateMidjourneyTask() {
 	//revocer
+	//imageModel := "midjourney"
+	ctx := context.TODO()
 	imageModel := "midjourney"
 	defer func() {
 		if err := recover(); err != nil {
@@ -28,27 +30,28 @@ func UpdateMidjourneyTask() {
 		time.Sleep(time.Duration(15) * time.Second)
 		tasks := model.GetAllUnFinishTasks()
 		if len(tasks) != 0 {
-			log.Printf("检测到未完成的任务数有: %v", len(tasks))
+			common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
 			for _, task := range tasks {
-				log.Printf("未完成的任务信息: %v", task)
+				common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
 				midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
 				if err != nil {
-					log.Printf("UpdateMidjourneyTask: %v", err)
+					common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
 					task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
 					task.Status = "FAILURE"
 					task.Progress = "100%"
 					err := task.Update()
 					if err != nil {
-						log.Printf("UpdateMidjourneyTask error: %v", err)
+						common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
+						continue
 					}
 					continue
 				}
 				requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
-				log.Printf("requestUrl: %s", requestUrl)
+				common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
 
 				req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
 				if err != nil {
-					log.Printf("UpdateMidjourneyTask error: %v", err)
+					common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
 					continue
 				}
 
@@ -111,7 +114,7 @@ func UpdateMidjourneyTask() {
 				task.Status = responseItem.Status
 				task.FailReason = responseItem.FailReason
 				if task.Progress != "100%" && responseItem.FailReason != "" {
-					log.Println(task.MjId + " 构建失败," + task.FailReason)
+					common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
 					task.Progress = "100%"
 					err = model.CacheUpdateUserQuota(task.UserId)
 					if err != nil {
@@ -126,8 +129,8 @@ func UpdateMidjourneyTask() {
 							if err != nil {
 								log.Println("fail to increase user quota")
 							}
-							logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
-							model.RecordLog(task.UserId, 1, logContent)
+							logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
+							model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 						}
 					}
 				}
@@ -142,6 +145,180 @@ func UpdateMidjourneyTask() {
 		}
 	}
 }
+*/
+
+func UpdateMidjourneyTaskBulk() {
+	//revocer
+	defer func() {
+		if err := recover(); err != nil {
+			log.Printf("UpdateMidjourneyTask panic: %v", err)
+		}
+	}()
+	imageModel := "midjourney"
+	ctx := context.TODO()
+	for {
+		time.Sleep(time.Duration(15) * time.Second)
+
+		tasks := model.GetAllUnFinishTasks()
+		if len(tasks) == 0 {
+			continue
+		}
+
+		common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
+		taskChannelM := make(map[int][]string)
+		taskM := make(map[string]*model.Midjourney)
+		for _, task := range tasks {
+			if task.MjId == "" {
+				continue
+			}
+			taskM[task.MjId] = task
+			taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
+		}
+		if len(taskChannelM) == 0 {
+			continue
+		}
+
+		for channelId, taskIds := range taskChannelM {
+			common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+			if len(taskIds) == 0 {
+				continue
+			}
+			midjourneyChannel, err := model.CacheGetChannel(channelId)
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
+				err := model.MjBulkUpdate(taskIds, map[string]any{
+					"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+					"status":      "FAILURE",
+					"progress":    "100%",
+				})
+				if err != nil {
+					common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
+				}
+				continue
+			}
+			requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
+
+			body, _ := json.Marshal(map[string]any{
+				"ids": taskIds,
+			})
+			req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
+				continue
+			}
+			// 设置超时时间
+			timeout := time.Second * 5
+			ctx, cancel := context.WithTimeout(context.Background(), timeout)
+			// 使用带有超时的 context 创建新的请求
+			req = req.WithContext(ctx)
+			req.Header.Set("Content-Type", "application/json")
+			req.Header.Set("mj-api-secret", midjourneyChannel.Key)
+			resp, err := httpClient.Do(req)
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
+				continue
+			}
+			responseBody, err := io.ReadAll(resp.Body)
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
+				continue
+			}
+			var responseItems []Midjourney
+			err = json.Unmarshal(responseBody, &responseItems)
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
+				continue
+			}
+			resp.Body.Close()
+			req.Body.Close()
+			cancel()
+
+			for _, responseItem := range responseItems {
+				task := taskM[responseItem.MjId]
+				if !checkMjTaskNeedUpdate(task, responseItem) {
+					continue
+				}
+
+				task.Code = 1
+				task.Progress = responseItem.Progress
+				task.PromptEn = responseItem.PromptEn
+				task.State = responseItem.State
+				task.SubmitTime = responseItem.SubmitTime
+				task.StartTime = responseItem.StartTime
+				task.FinishTime = responseItem.FinishTime
+				task.ImageUrl = responseItem.ImageUrl
+				task.Status = responseItem.Status
+				task.FailReason = responseItem.FailReason
+				if task.Progress != "100%" && responseItem.FailReason != "" {
+					common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
+					task.Progress = "100%"
+					err = model.CacheUpdateUserQuota(task.UserId)
+					if err != nil {
+						common.LogError(ctx, "error update user quota cache: "+err.Error())
+					} else {
+						modelRatio := common.GetModelRatio(imageModel)
+						groupRatio := common.GetGroupRatio("default")
+						ratio := modelRatio * groupRatio
+						quota := int(ratio * 1 * 1000)
+						if quota != 0 {
+							err = model.IncreaseUserQuota(task.UserId, quota)
+							if err != nil {
+								common.LogError(ctx, "fail to increase user quota: "+err.Error())
+							}
+							logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
+							model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+						}
+					}
+				}
+				err = task.Update()
+				if err != nil {
+					common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+				}
+			}
+		}
+	}
+}
+
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
+	if oldTask.Code != 1 {
+		return true
+	}
+	if oldTask.Progress != newTask.Progress {
+		return true
+	}
+	if oldTask.PromptEn != newTask.PromptEn {
+		return true
+	}
+	if oldTask.State != newTask.State {
+		return true
+	}
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if oldTask.ImageUrl != newTask.ImageUrl {
+		return true
+	}
+	if oldTask.Status != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if oldTask.Progress != "100%" && newTask.FailReason != "" {
+		return true
+	}
+
+	return false
+}
 
 func GetAllMidjourney(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))

+ 1 - 1
main.go

@@ -81,7 +81,7 @@ func main() {
 		}
 		go controller.AutomaticallyTestChannels(frequency)
 	}
-	go controller.UpdateMidjourneyTask()
+	go controller.UpdateMidjourneyTaskBulk()
 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 		common.BatchUpdateEnabled = true
 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

+ 18 - 0
model/cache.go

@@ -133,6 +133,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
 }
 
 var group2model2channels map[string]map[string][]*Channel
+var channelsIDM map[int]*Channel
 var channelSyncLock sync.RWMutex
 
 func InitChannelCache() {
@@ -149,10 +150,12 @@ func InitChannelCache() {
 		groups[ability.Group] = true
 	}
 	newGroup2model2channels := make(map[string]map[string][]*Channel)
+	newChannelsIDM := make(map[int]*Channel)
 	for group := range groups {
 		newGroup2model2channels[group] = make(map[string][]*Channel)
 	}
 	for _, channel := range channels {
+		newChannelsIDM[channel.Id] = channel
 		groups := strings.Split(channel.Group, ",")
 		for _, group := range groups {
 			models := strings.Split(channel.Models, ",")
@@ -177,6 +180,7 @@ func InitChannelCache() {
 
 	channelSyncLock.Lock()
 	group2model2channels = newGroup2model2channels
+	channelsIDM = newChannelsIDM
 	channelSyncLock.Unlock()
 	common.SysLog("channels synced from database")
 }
@@ -217,3 +221,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	idx := rand.Intn(endIdx)
 	return channels[idx], nil
 }
+
+func CacheGetChannel(id int) (*Channel, error) {
+	if !common.MemoryCacheEnabled {
+		return GetChannelById(id, true)
+	}
+	channelSyncLock.RLock()
+	defer channelSyncLock.RUnlock()
+
+	c, ok := channelsIDM[id]
+	if !ok {
+		return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
+	}
+	return c, nil
+}

+ 6 - 0
model/midjourney.go

@@ -131,3 +131,9 @@ func (midjourney *Midjourney) Update() error {
 	err = DB.Save(midjourney).Error
 	return err
 }
+
+func MjBulkUpdate(taskIDs []string, params map[string]any) error {
+	return DB.Model(&Midjourney{}).
+		Where("mj_id in (?)", taskIDs).
+		Updates(params).Error
+}