CaIon 1 год назад
Родитель
Сommit
fd3a41bacb
4 измененных файлов с 22 добавлено и 14 удалено
  1. 1 0
      dto/midjourney.go
  2. 13 14
      relay/relay-mj.go
  3. 2 0
      web/src/components/MjLogsTable.js
  4. 6 0
      web/src/pages/Channel/EditChannel.js

+ 1 - 0
dto/midjourney.go

@@ -25,6 +25,7 @@ type MidjourneyDto struct {
 	MjId        string `json:"id"`
 	Action      string `json:"action"`
 	CustomId    string `json:"customId"`
+	BotType     string `json:"botType"`
 	Prompt      string `json:"prompt"`
 	PromptEn    string `json:"promptEn"`
 	Description string `json:"description"`

+ 13 - 14
relay/relay-mj.go

@@ -112,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
 	return nil
 }
 
-func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
+func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
 	midjourneyTask.MjId = originTask.MjId
 	midjourneyTask.Progress = originTask.Progress
 	midjourneyTask.PromptEn = originTask.PromptEn
@@ -181,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
 				Description: "task_no_found",
 			}
 		}
-		midjourneyTask := getMidjourneyTaskDto(c, originTask)
+		midjourneyTask := coverMidjourneyTaskDto(c, originTask)
 		respBody, err = json.Marshal(midjourneyTask)
 		if err != nil {
 			return &dto.MidjourneyResponse{
@@ -204,7 +204,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
 		if len(condition.IDs) != 0 {
 			originTasks := model.GetByMJIds(userId, condition.IDs)
 			for _, originTask := range originTasks {
-				midjourneyTask := getMidjourneyTaskDto(c, originTask)
+				midjourneyTask := coverMidjourneyTaskDto(c, originTask)
 				tasks = append(tasks, midjourneyTask)
 			}
 		}
@@ -403,23 +403,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		}
 	}
 	//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
-
+	timeout := time.Second * 30
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	// 使用带有超时的 context 创建新的请求
+	req = req.WithContext(ctx)
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	//mjToken := ""
-	//if c.Request.Header.Get("ApiKey") != "" {
-	//	mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
-	//}
-	//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
 	req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
 	// print request header
-	log.Printf("request header: %s", req.Header)
-	log.Printf("request body: %s", midjRequest.Prompt)
+	//log.Printf("request header: %s", req.Header)
+	//log.Printf("request body: %s", midjRequest.Prompt)
 
+	defer cancel()
 	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return &dto.MidjourneyResponse{
-			Code:        4,
+			Code:        5,
 			Description: "do_request_failed",
 		}
 	}
@@ -427,14 +426,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 	err = req.Body.Close()
 	if err != nil {
 		return &dto.MidjourneyResponse{
-			Code:        4,
+			Code:        5,
 			Description: "close_request_body_failed",
 		}
 	}
 	err = c.Request.Body.Close()
 	if err != nil {
 		return &dto.MidjourneyResponse{
-			Code:        4,
+			Code:        5,
 			Description: "close_request_body_failed",
 		}
 	}

+ 2 - 0
web/src/components/MjLogsTable.js

@@ -35,6 +35,8 @@ function renderType(type) {
             return <Tag color="yellow" size='large'>图生文</Tag>;
         case 'BLEAND':
             return <Tag color="lime" size='large'>图混合</Tag>;
+        case 'REROLL':
+            return <Tag color="indigo" size='large'>重绘</Tag>;
         case 'INPAINT':
             return <Tag color="violet" size='large'>局部重绘</Tag>;
         case 'INPAINT_PRE':

+ 6 - 0
web/src/pages/Channel/EditChannel.js

@@ -95,6 +95,12 @@ const EditChannel = (props) => {
                 case 26:
                     localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
                     break;
+                case 2:
+                    localModels = ['midjourney'];
+                    break;
+                case 5:
+                    localModels = ['midjourney'];
+                    break;
             }
             setInputs((inputs) => ({...inputs, models: localModels}));
         }