فهرست منبع

修复多路复用bug

CaIon 2 سال پیش
والد
کامیت
fe94656260
5فایلهای تغییر یافته به همراه42 افزوده شده و 8 حذف شده
  1. 4 0
      controller/midjourney.go
  2. 5 1
      controller/relay-mj.go
  3. 18 1
      controller/relay-text.go
  4. 2 2
      controller/relay.go
  5. 13 4
      middleware/auth.go

+ 4 - 0
controller/midjourney.go

@@ -37,22 +37,26 @@ func UpdateMidjourneyTask() {
 			jsonStr, err := json.Marshal(requestBody)
 			if err != nil {
 				log.Printf("UpdateMidjourneyTask: %v", err)
+				continue
 			}
 			req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr))
 			if err != nil {
 				log.Printf("UpdateMidjourneyTask: %v", err)
+				continue
 			}
 			req.Header.Set("Content-Type", "application/json")
 			req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu")
 			resp, err := httpClient.Do(req)
 			if err != nil {
 				log.Printf("UpdateMidjourneyTask: %v", err)
+				continue
 			}
 			defer resp.Body.Close()
 			var response []Midjourney
 			err = json.NewDecoder(resp.Body).Decode(&response)
 			if err != nil {
 				log.Printf("UpdateMidjourneyTask: %v", err)
+				continue
 			}
 			for _, responseItem := range response {
 				var midjourneyTask *model.Midjourney

+ 5 - 1
controller/relay-mj.go

@@ -248,6 +248,10 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+	//mjToken := ""
+	//if c.Request.Header.Get("Authorization") != "" {
+	//	mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1]
+	//}
 	req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
 	// print request header
 	log.Printf("request header: %s", req.Header)
@@ -353,7 +357,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 		Progress:    "0%",
 		FailReason:  "",
 	}
-	if midjResponse.Code == 4 {
+	if midjResponse.Code == 4 || midjResponse.Code == 24 {
 		midjourneyTask.FailReason = midjResponse.Description
 	}
 	err = midjourneyTask.Insert()

+ 18 - 1
controller/relay-text.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
+	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
@@ -278,6 +279,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 
 	if apiType != APITypeXunfei { // cause xunfei use websocket
 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+		// 设置GetBody函数,该函数返回一个新的io.ReadCloser,该io.ReadCloser返回与原始请求体相同的数据
+		req.GetBody = func() (io.ReadCloser, error) {
+			return io.NopCloser(requestBody), nil
+		}
 		if err != nil {
 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 		}
@@ -308,7 +313,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+
 		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
+		req.Close = true
 		resp, err = httpClient.Do(req)
 		if err != nil {
 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
@@ -324,8 +331,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 
 		if resp.StatusCode != http.StatusOK {
+			//print resp body
+			body, err := io.ReadAll(resp.Body)
+			if err != nil {
+				log.Println("read resp err body failed", err)
+			}
+			log.Println("resp body:", string(body))
+			errStr := fmt.Sprintf("bad status code: %d", resp.StatusCode)
+			if resp.StatusCode == 503 {
+				errStr = string(body)
+			}
 			return errorWrapper(
-				fmt.Errorf("bad status code: %d", resp.StatusCode), "bad_status_code", resp.StatusCode)
+				fmt.Errorf(errStr), "bad_status_code", resp.StatusCode)
 		}
 	}
 

+ 2 - 2
controller/relay.go

@@ -205,7 +205,7 @@ func Relay(c *gin.Context) {
 			})
 		}
 		channelId := c.GetInt("channel_id")
-		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
+		common.SysError(fmt.Sprintf("relay error (channel #%d): %v ", channelId, err))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
 		if shouldDisableChannel(&err.OpenAIError) {
 			channelId := c.GetInt("channel_id")
@@ -259,7 +259,7 @@ func RelayMidjourney(c *gin.Context) {
 		//	channelId := c.GetInt("channel_id")
 		//	channelName := c.GetString("channel_name")
 		//	disableChannel(channelId, channelName, err.Result)
-		//}
+		//};''''''''''''''''''''''''''''''''
 	}
 }
 

+ 13 - 4
middleware/auth.go

@@ -85,10 +85,19 @@ func RootAuth() func(c *gin.Context) {
 func TokenAuth() func(c *gin.Context) {
 	return func(c *gin.Context) {
 		key := c.Request.Header.Get("Authorization")
-		key = strings.TrimPrefix(key, "Bearer ")
-		key = strings.TrimPrefix(key, "sk-")
-		parts := strings.Split(key, "-")
-		key = parts[0]
+		parts := make([]string, 0)
+		if key == "" {
+			key = c.Request.Header.Get("mj-api-secret")
+			key = strings.TrimPrefix(key, "Bearer ")
+			key = strings.TrimPrefix(key, "sk-")
+			parts := strings.Split(key, "-")
+			key = parts[0]
+		} else {
+			key = strings.TrimPrefix(key, "Bearer ")
+			key = strings.TrimPrefix(key, "sk-")
+			parts := strings.Split(key, "-")
+			key = parts[0]
+		}
 		token, err := model.ValidateUserToken(key)
 		if err != nil {
 			c.JSON(http.StatusUnauthorized, gin.H{