Răsfoiți Sursa

optimize: MJ 部分调整、优化

MJ
增加simple-change、list接口,
变换和重试操作区别出来,价格与绘图一样
优化图片返回
Xyfacai 2 ani în urmă
părinte
comite
5c747dfee2
5 a modificat fișierele cu 217 adăugiri și 56 ștergeri
  1. 7 1
      common/model-ratio.go
  2. 177 52
      controller/relay-mj.go
  3. 10 2
      controller/relay.go
  4. 21 1
      model/midjourney.go
  5. 2 0
      router/relay-router.go

+ 7 - 1
common/model-ratio.go

@@ -14,7 +14,7 @@ import (
 // 1 === $0.002 / 1K tokens
 // 1 === ¥0.014 / 1k tokens
 var ModelRatio = map[string]float64{
-	"midjourney":                50,
+	//"midjourney":                50,
 	"gpt-4-gizmo-*":             15,
 	"gpt-4":                     15,
 	"gpt-4-0314":                15,
@@ -80,6 +80,12 @@ var ModelRatio = map[string]float64{
 
 var ModelPrice = map[string]float64{
 	"gpt-4-gizmo-*": 0.1,
+	"mj_imagine":    0.1,
+	"mj_variation":  0.1,
+	"mj_reroll":     0.1,
+	"mj_blend":      0.1,
+	"mj_describe":   0.05,
+	"mj_upscale":    0.05,
 }
 
 func ModelPrice2JSONString() string {

+ 177 - 52
controller/relay-mj.go

@@ -57,7 +57,7 @@ type MidjourneyWithoutStatus struct {
 
 func RelayMidjourneyImage(c *gin.Context) {
 	taskId := c.Param("id")
-	midjourneyTask := model.GetByMJId(taskId)
+	midjourneyTask := model.GetByOnlyMJId(taskId)
 	if midjourneyTask == nil {
 		c.JSON(400, gin.H{
 			"error": "midjourney_task_not_found",
@@ -71,14 +71,27 @@ func RelayMidjourneyImage(c *gin.Context) {
 		})
 	}
 	defer resp.Body.Close()
-	data, err := io.ReadAll(resp.Body)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	if resp.StatusCode != http.StatusOK {
+		responseBody, _ := io.ReadAll(resp.Body)
+		c.JSON(resp.StatusCode, gin.H{
+			"error": string(responseBody),
+		})
 		return
 	}
-	c.Header("Content-Type", "image/jpeg")
-	//c.HeaderBar("Content-Length", string(rune(len(data))))
-	c.Data(http.StatusOK, "image/jpeg", data)
+	// 从Content-Type头获取MIME类型
+	contentType := resp.Header.Get("Content-Type")
+	if contentType == "" {
+		// 如果无法确定内容类型,则默认为jpeg
+		contentType = "image/jpeg"
+	}
+	// 设置响应的内容类型
+	c.Writer.Header().Set("Content-Type", contentType)
+	// 将图片流式传输到响应体
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		log.Println("Failed to stream image:", err)
+	}
+	return
 }
 
 func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
@@ -92,7 +105,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 			Result:      "",
 		}
 	}
-	midjourneyTask := model.GetByMJId(midjRequest.MjId)
+	midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
 	if midjourneyTask == nil {
 		return &MidjourneyResponse{
 			Code:        4,
@@ -121,16 +134,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 	return nil
 }
 
-func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
-	taskId := c.Param("id")
-	originTask := model.GetByMJId(taskId)
-	if originTask == nil {
-		return &MidjourneyResponse{
-			Code:        4,
-			Description: "task_no_found",
-		}
-	}
-	var midjourneyTask Midjourney
+func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
 	midjourneyTask.MjId = originTask.MjId
 	midjourneyTask.Progress = originTask.Progress
 	midjourneyTask.PromptEn = originTask.PromptEn
@@ -150,14 +154,65 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 	midjourneyTask.Action = originTask.Action
 	midjourneyTask.Description = originTask.Description
 	midjourneyTask.Prompt = originTask.Prompt
-	jsonMap, err := json.Marshal(midjourneyTask)
-	if err != nil {
-		return &MidjourneyResponse{
-			Code:        4,
-			Description: "unmarshal_response_body_failed",
+	return
+}
+
+func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
+	userId := c.GetInt("id")
+	var err error
+	var respBody []byte
+	switch relayMode {
+	case RelayModeMidjourneyTaskFetch:
+		taskId := c.Param("id")
+		originTask := model.GetByMJId(userId, taskId)
+		if originTask == nil {
+			return &MidjourneyResponse{
+				Code:        4,
+				Description: "task_no_found",
+			}
+		}
+		midjourneyTask := getMidjourneyTaskModel(c, originTask)
+		respBody, err = json.Marshal(midjourneyTask)
+		if err != nil {
+			return &MidjourneyResponse{
+				Code:        4,
+				Description: "unmarshal_response_body_failed",
+			}
+		}
+	case RelayModeMidjourneyTaskFetchByCondition:
+		var condition = struct {
+			IDs []string `json:"ids"`
+		}{}
+		err = c.BindJSON(&condition)
+		if err != nil {
+			return &MidjourneyResponse{
+				Code:        4,
+				Description: "do_request_failed",
+			}
+		}
+		var tasks []Midjourney
+		if len(condition.IDs) != 0 {
+			originTasks := model.GetByMJIds(userId, condition.IDs)
+			for _, originTask := range originTasks {
+				midjourneyTask := getMidjourneyTaskModel(c, originTask)
+				tasks = append(tasks, midjourneyTask)
+			}
+		}
+		if tasks == nil {
+			tasks = make([]Midjourney, 0)
+		}
+		respBody, err = json.Marshal(tasks)
+		if err != nil {
+			return &MidjourneyResponse{
+				Code:        4,
+				Description: "unmarshal_response_body_failed",
+			}
 		}
 	}
-	_, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+
+	_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
 	if err != nil {
 		return &MidjourneyResponse{
 			Code:        4,
@@ -167,6 +222,18 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 	return nil
 }
 
+const (
+	// type 1 根据 mode 价格不同
+	MJSubmitActionImagine   = "IMAGINE"
+	MJSubmitActionVariation = "VARIATION" //变换
+	MJSubmitActionBlend     = "BLEND"     //混图
+
+	MJSubmitActionReroll = "REROLL" //重新生成
+	// type 2 固定价格
+	MJSubmitActionDescribe = "DESCRIBE"
+	MJSubmitActionUpscale  = "UPSCALE" // 放大
+)
+
 func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	imageModel := "midjourney"
 
@@ -186,6 +253,9 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 			}
 		}
 	}
+
+	action := midjRequest.Action
+
 	if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
 		if midjRequest.Prompt == "" {
 			return &MidjourneyResponse{
@@ -199,7 +269,44 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
 		midjRequest.Action = "BLEND"
 	} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
-		originTask := model.GetByMJId(midjRequest.TaskId)
+		mjId := ""
+		if relayMode == RelayModeMidjourneyChange {
+			if midjRequest.TaskId == "" {
+				return &MidjourneyResponse{
+					Code:        4,
+					Description: "taskId_is_required",
+				}
+			} else if midjRequest.Action == "" {
+				return &MidjourneyResponse{
+					Code:        4,
+					Description: "action_is_required",
+				}
+			} else if midjRequest.Index == 0 {
+				return &MidjourneyResponse{
+					Code:        4,
+					Description: "index_can_only_be_1_2_3_4",
+				}
+			}
+			action = midjRequest.Action
+			mjId = midjRequest.TaskId
+		} else if relayMode == RelayModeMidjourneySimpleChange {
+			if midjRequest.Content == "" {
+				return &MidjourneyResponse{
+					Code:        4,
+					Description: "content_is_required",
+				}
+			}
+			params := convertSimpleChangeParams(midjRequest.Content)
+			if params == nil {
+				return &MidjourneyResponse{
+					Code:        4,
+					Description: "content_parse_failed",
+				}
+			}
+			mjId = params.ID
+			action = params.Action
+		}
+		originTask := model.GetByMJId(userId, mjId)
 		if originTask == nil {
 			return &MidjourneyResponse{
 				Code:        4,
@@ -229,23 +336,6 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 			log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
 		}
 		midjRequest.Prompt = originTask.Prompt
-	} else if relayMode == RelayModeMidjourneyChange {
-		if midjRequest.TaskId == "" {
-			return &MidjourneyResponse{
-				Code:        4,
-				Description: "taskId_is_required",
-			}
-		} else if midjRequest.Action == "" {
-			return &MidjourneyResponse{
-				Code:        4,
-				Description: "action_is_required",
-			}
-		} else if midjRequest.Index == 0 {
-			return &MidjourneyResponse{
-				Code:        4,
-				Description: "index_can_only_be_1_2_3_4",
-			}
-		}
 	}
 
 	// map model name
@@ -293,17 +383,17 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 		requestBody = c.Request.Body
 	}
 
-	modelRatio := common.GetModelRatio(imageModel)
+	modelPrice := common.GetModelPrice("mj_" + strings.ToLower(action))
 	groupRatio := common.GetGroupRatio(group)
-	ratio := modelRatio * groupRatio
+	ratio := modelPrice * groupRatio
 	userQuota, err := model.CacheGetUserQuota(userId)
-
-	sizeRatio := 1.0
-	if midjRequest.Action == "UPSCALE" {
-		sizeRatio = 0.2
+	if err != nil {
+		return &MidjourneyResponse{
+			Code:        4,
+			Description: err.Error(),
+		}
 	}
-
-	quota := int(ratio * sizeRatio * 1000)
+	quota := int(ratio * common.QuotaPerUnit)
 
 	if consumeQuota && userQuota-quota < 0 {
 		return &MidjourneyResponse{
@@ -369,7 +459,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 			}
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
-				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, action)
 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
@@ -423,7 +513,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	midjourneyTask := &model.Midjourney{
 		UserId:      userId,
 		Code:        midjResponse.Code,
-		Action:      midjRequest.Action,
+		Action:      action,
 		MjId:        midjResponse.Result,
 		Prompt:      midjRequest.Prompt,
 		PromptEn:    "",
@@ -504,3 +594,38 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	}
 	return nil
 }
+
+type taskChangeParams struct {
+	ID     string
+	Action string
+	Index  int
+}
+
+func convertSimpleChangeParams(content string) *taskChangeParams {
+	split := strings.Split(content, " ")
+	if len(split) != 2 {
+		return nil
+	}
+
+	action := strings.ToLower(split[1])
+	changeParams := &taskChangeParams{}
+	changeParams.ID = split[0]
+
+	if action[0] == 'u' {
+		changeParams.Action = "UPSCALE"
+	} else if action[0] == 'v' {
+		changeParams.Action = "VARIATION"
+	} else if action == "r" {
+		changeParams.Action = "REROLL"
+		return changeParams
+	} else {
+		return nil
+	}
+
+	index, err := strconv.Atoi(action[1:2])
+	if err != nil || index < 1 || index > 4 {
+		return nil
+	}
+	changeParams.Index = index
+	return changeParams
+}

+ 10 - 2
controller/relay.go

@@ -95,8 +95,10 @@ const (
 	RelayModeMidjourneyDescribe
 	RelayModeMidjourneyBlend
 	RelayModeMidjourneyChange
+	RelayModeMidjourneySimpleChange
 	RelayModeMidjourneyNotify
 	RelayModeMidjourneyTaskFetch
+	RelayModeMidjourneyTaskFetchByCondition
 	RelayModeAudio
 )
 
@@ -263,6 +265,7 @@ type MidjourneyRequest struct {
 	State       string   `json:"state"`
 	TaskId      string   `json:"taskId"`
 	Base64Array []string `json:"base64Array"`
+	Content     string   `json:"content"`
 }
 
 type MidjourneyResponse struct {
@@ -342,14 +345,19 @@ func RelayMidjourney(c *gin.Context) {
 		relayMode = RelayModeMidjourneyNotify
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
 		relayMode = RelayModeMidjourneyChange
-	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
+	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
+		relayMode = RelayModeMidjourneyChange
+	} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
 		relayMode = RelayModeMidjourneyTaskFetch
+	} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
+		relayMode = RelayModeMidjourneyTaskFetchByCondition
 	}
+
 	var err *MidjourneyResponse
 	switch relayMode {
 	case RelayModeMidjourneyNotify:
 		err = relayMidjourneyNotify(c)
-	case RelayModeMidjourneyTaskFetch:
+	case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
 		err = relayMidjourneyTask(c, relayMode)
 	default:
 		err = relayMidjourneySubmit(c, relayMode)

+ 21 - 1
model/midjourney.go

@@ -96,7 +96,7 @@ func GetAllUnFinishTasks() []*Midjourney {
 	return tasks
 }
 
-func GetByMJId(mjId string) *Midjourney {
+func GetByOnlyMJId(mjId string) *Midjourney {
 	var mj *Midjourney
 	var err error
 	err = DB.Where("mj_id = ?", mjId).First(&mj).Error
@@ -106,6 +106,26 @@ func GetByMJId(mjId string) *Midjourney {
 	return mj
 }
 
+func GetByMJId(userId int, mjId string) *Midjourney {
+	var mj *Midjourney
+	var err error
+	err = DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
+	if err != nil {
+		return nil
+	}
+	return mj
+}
+
+func GetByMJIds(userId int, mjIds []string) []*Midjourney {
+	var mj []*Midjourney
+	var err error
+	err = DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
+	if err != nil {
+		return nil
+	}
+	return mj
+}
+
 func GetMjByuId(id int) *Midjourney {
 	var mj *Midjourney
 	var err error

+ 2 - 0
router/relay-router.go

@@ -49,10 +49,12 @@ func SetRelayRouter(router *gin.Engine) {
 	{
 		relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
 		relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
+		relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
 		relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
 		relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
 		relayMjRouter.POST("/notify", controller.RelayMidjourney)
 		relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
+		relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
 	}
 	//relayMjRouter.Use()
 }