Просмотр исходного кода

fix: fix http status code (close #193)

JustSong 2 лет назад
Родитель
Сommit
8a4cd403fd
4 измененных файлов с 22 добавлено и 22 удалено
  1. 12 12
      controller/relay-text.go
  2. 2 2
      controller/relay.go
  3. 3 3
      middleware/auth.go
  4. 5 5
      middleware/distributor.go

+ 12 - 12
controller/relay-text.go

@@ -76,7 +76,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
-		return errorWrapper(err, "get_user_quota_failed", http.StatusOK)
+		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota > 10*preConsumedQuota {
 		// in this case, we do not pre-consume quota
@@ -86,12 +86,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	if consumeQuota && preConsumedQuota > 0 {
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
-			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
+			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 	if err != nil {
-		return errorWrapper(err, "new_request_failed", http.StatusOK)
+		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}
 	if channelType == common.ChannelTypeAzure {
 		key := c.Request.Header.Get("Authorization")
@@ -106,15 +106,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	client := &http.Client{}
 	resp, err := client.Do(req)
 	if err != nil {
-		return errorWrapper(err, "do_request_failed", http.StatusOK)
+		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 	err = req.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 	err = c.Request.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 	var textResponse TextResponse
 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
@@ -224,22 +224,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		})
 		err = resp.Body.Close()
 		if err != nil {
-			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
+			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 		}
 		return nil
 	} else {
 		if consumeQuota {
 			responseBody, err := io.ReadAll(resp.Body)
 			if err != nil {
-				return errorWrapper(err, "read_response_body_failed", http.StatusOK)
+				return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 			}
 			err = resp.Body.Close()
 			if err != nil {
-				return errorWrapper(err, "close_response_body_failed", http.StatusOK)
+				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 			}
 			err = json.Unmarshal(responseBody, &textResponse)
 			if err != nil {
-				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
+				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 			}
 			if textResponse.Error.Type != "" {
 				return &OpenAIErrorWithStatusCode{
@@ -260,11 +260,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		c.Writer.WriteHeader(resp.StatusCode)
 		_, err = io.Copy(c.Writer, resp.Body)
 		if err != nil {
-			return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
+			return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 		}
 		err = resp.Body.Close()
 		if err != nil {
-			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
+			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 		}
 		return nil
 	}

+ 2 - 2
controller/relay.go

@@ -135,7 +135,7 @@ func RelayNotImplemented(c *gin.Context) {
 		Param:   "",
 		Code:    "api_not_implemented",
 	}
-	c.JSON(http.StatusOK, gin.H{
+	c.JSON(http.StatusNotImplemented, gin.H{
 		"error": err,
 	})
 }
@@ -147,7 +147,7 @@ func RelayNotFound(c *gin.Context) {
 		Param:   "",
 		Code:    "api_not_found",
 	}
-	c.JSON(http.StatusOK, gin.H{
+	c.JSON(http.StatusNotFound, gin.H{
 		"error": err,
 	})
 }

+ 3 - 3
middleware/auth.go

@@ -91,7 +91,7 @@ func TokenAuth() func(c *gin.Context) {
 		key = parts[0]
 		token, err := model.ValidateUserToken(key)
 		if err != nil {
-			c.JSON(http.StatusOK, gin.H{
+			c.JSON(http.StatusUnauthorized, gin.H{
 				"error": gin.H{
 					"message": err.Error(),
 					"type":    "one_api_error",
@@ -101,7 +101,7 @@ func TokenAuth() func(c *gin.Context) {
 			return
 		}
 		if !model.CacheIsUserEnabled(token.UserId) {
-			c.JSON(http.StatusOK, gin.H{
+			c.JSON(http.StatusForbidden, gin.H{
 				"error": gin.H{
 					"message": "用户已被封禁",
 					"type":    "one_api_error",
@@ -123,7 +123,7 @@ func TokenAuth() func(c *gin.Context) {
 			if model.IsAdmin(token.UserId) {
 				c.Set("channelId", parts[1])
 			} else {
-				c.JSON(http.StatusOK, gin.H{
+				c.JSON(http.StatusForbidden, gin.H{
 					"error": gin.H{
 						"message": "普通用户不支持指定渠道",
 						"type":    "one_api_error",

+ 5 - 5
middleware/distributor.go

@@ -24,7 +24,7 @@ func Distribute() func(c *gin.Context) {
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
-				c.JSON(http.StatusOK, gin.H{
+				c.JSON(http.StatusBadRequest, gin.H{
 					"error": gin.H{
 						"message": "无效的渠道 ID",
 						"type":    "one_api_error",
@@ -35,7 +35,7 @@ func Distribute() func(c *gin.Context) {
 			}
 			channel, err = model.GetChannelById(id, true)
 			if err != nil {
-				c.JSON(200, gin.H{
+				c.JSON(http.StatusBadRequest, gin.H{
 					"error": gin.H{
 						"message": "无效的渠道 ID",
 						"type":    "one_api_error",
@@ -45,7 +45,7 @@ func Distribute() func(c *gin.Context) {
 				return
 			}
 			if channel.Status != common.ChannelStatusEnabled {
-				c.JSON(200, gin.H{
+				c.JSON(http.StatusForbidden, gin.H{
 					"error": gin.H{
 						"message": "该渠道已被禁用",
 						"type":    "one_api_error",
@@ -59,7 +59,7 @@ func Distribute() func(c *gin.Context) {
 			var modelRequest ModelRequest
 			err := common.UnmarshalBodyReusable(c, &modelRequest)
 			if err != nil {
-				c.JSON(200, gin.H{
+				c.JSON(http.StatusBadRequest, gin.H{
 					"error": gin.H{
 						"message": "无效的请求",
 						"type":    "one_api_error",
@@ -75,7 +75,7 @@ func Distribute() func(c *gin.Context) {
 			}
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
-				c.JSON(200, gin.H{
+				c.JSON(http.StatusServiceUnavailable, gin.H{
 					"error": gin.H{
 						"message": "无可用渠道",
 						"type":    "one_api_error",