CaIon hace 2 años
padre
commit
43be1982d7

+ 1 - 1
controller/channel-test.go

@@ -174,7 +174,7 @@ func testAllChannels(notify bool) error {
 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
-			if shouldDisableChannel(openaiErr) {
+			if shouldDisableChannel(openaiErr, -1) {
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
 			channel.UpdateResponseTime(milliseconds)

+ 5 - 5
controller/midjourney.go

@@ -15,13 +15,13 @@ import (
 
 func UpdateMidjourneyTask() {
 	//revocer
-	defer func() {
-		if err := recover(); err != nil {
-			log.Printf("UpdateMidjourneyTask: %v", err)
-		}
-	}()
 	imageModel := "midjourney"
 	for {
+		defer func() {
+			if err := recover(); err != nil {
+				log.Printf("UpdateMidjourneyTask: %v", err)
+			}
+		}()
 		time.Sleep(time.Duration(15) * time.Second)
 		tasks := model.GetAllUnFinishTasks()
 		if len(tasks) != 0 {

+ 1 - 13
controller/relay-text.go

@@ -7,7 +7,6 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
-	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
@@ -331,18 +330,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 
 		if resp.StatusCode != http.StatusOK {
-			//print resp body
-			body, err := io.ReadAll(resp.Body)
-			if err != nil {
-				log.Println("read resp err body failed", err)
-			}
-			log.Println("resp body:", string(body))
-			errStr := fmt.Sprintf("bad status code: %d", resp.StatusCode)
-			if resp.StatusCode == 503 {
-				errStr = string(body)
-			}
-			return errorWrapper(
-				fmt.Errorf(errStr), "bad_status_code", resp.StatusCode)
+			return relayErrorHandler(resp)
 		}
 	}
 

+ 35 - 1
controller/relay-utils.go

@@ -1,10 +1,14 @@
 package controller
 
 import (
+	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
+	"io"
+	"net/http"
 	"one-api/common"
+	"strconv"
 )
 
 var stopFinishReason = "stop"
@@ -95,13 +99,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
 	}
 }
 
-func shouldDisableChannel(err *OpenAIError) bool {
+func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
 	if !common.AutomaticDisableChannelEnabled {
 		return false
 	}
 	if err == nil {
 		return false
 	}
+	if statusCode == http.StatusUnauthorized {
+		return true
+	}
 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
 		return true
 	}
@@ -115,3 +122,30 @@ func setEventStreamHeaders(c *gin.Context) {
 	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 	c.Writer.Header().Set("X-Accel-Buffering", "no")
 }
+
+func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
+	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
+		StatusCode: resp.StatusCode,
+		OpenAIError: OpenAIError{
+			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
+			Type:    "one_api_error",
+			Code:    "bad_response_status_code",
+			Param:   strconv.Itoa(resp.StatusCode),
+		},
+	}
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return
+	}
+	var textResponse TextResponse
+	err = json.Unmarshal(responseBody, &textResponse)
+	if err != nil {
+		return
+	}
+	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
+	return
+}

+ 1 - 1
controller/relay.go

@@ -207,7 +207,7 @@ func Relay(c *gin.Context) {
 		channelId := c.GetInt("channel_id")
 		common.SysError(fmt.Sprintf("relay error (channel #%d): %v ", channelId, err))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if shouldDisableChannel(&err.OpenAIError) {
+		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
 			disableChannel(channelId, channelName, err.Message)