Browse Source

refactor: improve validation logic and error handling in relay-text.go

- Simplified validation checks for MaxTokens and Messages fields.
- Enhanced error messages for better clarity.
- Updated goroutine to avoid passing context unnecessarily.
CalciumIon 1 year ago
parent
commit
a9f739a7e2
1 changed files with 11 additions and 10 deletions
  1. 11 10
      relay/relay-text.go

+ 11 - 10
relay/relay-text.go

@@ -2,11 +2,9 @@ package relay
 
 import (
 	"bytes"
-	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/bytedance/sonic"
 	"io"
 	"math"
 	"net/http"
@@ -20,6 +18,8 @@ import (
 	"strings"
 	"time"
 
+	"github.com/bytedance/sonic"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -36,7 +36,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
 		textRequest.Model = c.Param("model")
 	}
 
-	if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
+	if textRequest.MaxTokens > math.MaxInt32/2 {
 		return nil, errors.New("max_tokens is invalid")
 	}
 	if textRequest.Model == "" {
@@ -48,12 +48,12 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
 			return nil, errors.New("field prompt is required")
 		}
 	case relayconstant.RelayModeChatCompletions:
-		if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+		if len(textRequest.Messages) == 0 {
 			return nil, errors.New("field messages is required")
 		}
 	case relayconstant.RelayModeEmbeddings:
 	case relayconstant.RelayModeModerations:
-		if textRequest.Input == "" || textRequest.Input == nil {
+		if textRequest.Input == nil || textRequest.Input == "" {
 			return nil, errors.New("field input is required")
 		}
 	case relayconstant.RelayModeEdits:
@@ -264,7 +264,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 		return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return 0, 0, service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
+		return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
 	}
 	err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 	if err != nil {
@@ -298,13 +298,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 
 func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
 	if preConsumedQuota != 0 {
-		go func(ctx context.Context) {
-			// return pre-consumed quota
-			err := model.PostConsumeTokenQuota(relayInfo, userQuota, -preConsumedQuota, 0, false)
+		go func() {
+			relayInfoCopy := *relayInfo
+
+			err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
 			if err != nil {
 				common.SysError("error return pre-consumed quota: " + err.Error())
 			}
-		}(c)
+		}()
 	}
 }