Просмотр исходного кода

Merge pull request #2452 from QuantumNous/fix/oom-request-body-limit

Seefs 1 неделя назад
Родитель
Сommit
b35ae9f693

+ 1 - 0
README.en.md

@@ -305,6 +305,7 @@ docker run --name new-api -d --restart always \
 | `REDIS_CONN_STRING` | Redis connection string | - |
 | `STREAMING_TIMEOUT` | Streaming timeout (seconds) | `300` |
 | `STREAM_SCANNER_MAX_BUFFER_MB` | Max per-line buffer (MB) for the stream scanner; increase when upstream sends huge image/base64 payloads | `64` |
+| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
 | `AZURE_DEFAULT_API_VERSION` | Azure API version | `2025-04-01-preview` |
 | `ERROR_LOG_ENABLED` | Error log switch | `false` |
 

+ 1 - 0
README.fr.md

@@ -301,6 +301,7 @@ docker run --name new-api -d --restart always \
 | `REDIS_CONN_STRING` | Chaine de connexion Redis | - |
 | `STREAMING_TIMEOUT` | Délai d'expiration du streaming (secondes) | `300` |
 | `STREAM_SCANNER_MAX_BUFFER_MB` | Taille max du buffer par ligne (Mo) pour le scanner SSE ; à augmenter quand les sorties image/base64 sont très volumineuses (ex. images 4K) | `64` |
+| `MAX_REQUEST_BODY_MB` | Taille maximale du corps de requête (Mo, comptée **après décompression** ; évite les requêtes énormes/zip bombs qui saturent la mémoire). Dépassement ⇒ `413` | `32` |
 | `AZURE_DEFAULT_API_VERSION` | Version de l'API Azure | `2025-04-01-preview` |
 | `ERROR_LOG_ENABLED` | Interrupteur du journal d'erreurs | `false` |
 

+ 1 - 0
README.ja.md

@@ -310,6 +310,7 @@ docker run --name new-api -d --restart always \
 | `REDIS_CONN_STRING` | Redis接続文字列 | - |
 | `STREAMING_TIMEOUT` | ストリーミング応答のタイムアウト時間(秒) | `300` |
 | `STREAM_SCANNER_MAX_BUFFER_MB` | ストリームスキャナの1行あたりバッファ上限(MB)。4K画像など巨大なbase64 `data:` ペイロードを扱う場合は値を増加させてください | `64` |
+| `MAX_REQUEST_BODY_MB` | リクエストボディ最大サイズ(MB、**解凍後**に計測。巨大リクエスト/zip bomb によるメモリ枯渇を防止)。超過時は `413` | `32` |
 | `AZURE_DEFAULT_API_VERSION` | Azure APIバージョン | `2025-04-01-preview` |
 | `ERROR_LOG_ENABLED` | エラーログスイッチ | `false` |
 

+ 1 - 0
README.md

@@ -306,6 +306,7 @@ docker run --name new-api -d --restart always \
 | `REDIS_CONN_STRING` | Redis 连接字符串                                                  | - |
 | `STREAMING_TIMEOUT` | 流式超时时间(秒)                                                    | `300` |
 | `STREAM_SCANNER_MAX_BUFFER_MB` | 流式扫描器单行最大缓冲(MB),图像生成等超大 `data:` 片段(如 4K 图片 base64)需适当调大 | `64` |
+| `MAX_REQUEST_BODY_MB` | 请求体最大大小(MB,**解压后**计;防止超大请求/zip bomb 导致内存暴涨),超过将返回 `413` | `32` |
 | `AZURE_DEFAULT_API_VERSION` | Azure API 版本                                                 | `2025-04-01-preview` |
 | `ERROR_LOG_ENABLED` | 错误日志开关                                                       | `false` |
 

+ 35 - 6
common/gin.go

@@ -18,18 +18,47 @@ import (
 
 const KeyRequestBody = "key_request_body"
 
+var ErrRequestBodyTooLarge = errors.New("request body too large")
+
+func IsRequestBodyTooLargeError(err error) bool {
+	if err == nil {
+		return false
+	}
+	if errors.Is(err, ErrRequestBodyTooLarge) {
+		return true
+	}
+	var mbe *http.MaxBytesError
+	return errors.As(err, &mbe)
+}
+
 func GetRequestBody(c *gin.Context) ([]byte, error) {
-	requestBody, _ := c.Get(KeyRequestBody)
-	if requestBody != nil {
-		return requestBody.([]byte), nil
+	cached, exists := c.Get(KeyRequestBody)
+	if exists && cached != nil {
+		if b, ok := cached.([]byte); ok {
+			return b, nil
+		}
+	}
+	maxMB := constant.MaxRequestBodyMB
+	if maxMB <= 0 {
+		maxMB = 32
 	}
-	requestBody, err := io.ReadAll(c.Request.Body)
+	maxBytes := int64(maxMB) << 20
+
+	limited := io.LimitReader(c.Request.Body, maxBytes+1)
+	body, err := io.ReadAll(limited)
 	if err != nil {
+		_ = c.Request.Body.Close()
+		if IsRequestBodyTooLargeError(err) {
+			return nil, ErrRequestBodyTooLarge
+		}
 		return nil, err
 	}
 	_ = c.Request.Body.Close()
-	c.Set(KeyRequestBody, requestBody)
-	return requestBody.([]byte), nil
+	if int64(len(body)) > maxBytes {
+		return nil, ErrRequestBodyTooLarge
+	}
+	c.Set(KeyRequestBody, body)
+	return body, nil
 }
 
 func UnmarshalBodyReusable(c *gin.Context, v any) error {

+ 2 - 0
common/init.go

@@ -117,6 +117,8 @@ func initConstantEnv() {
 	constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
 	constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
 	constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)
+	// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
+	constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 32)
 	// ForceStreamOption 覆盖请求参数,强制返回usage信息
 	constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
 	constant.CountToken = GetEnvOrDefaultBool("CountToken", true)

+ 1 - 0
constant/env.go

@@ -9,6 +9,7 @@ var CountToken bool
 var GetMediaToken bool
 var GetMediaTokenNotStream bool
 var UpdateTask bool
+var MaxRequestBodyMB int
 var AzureDefaultAPIVersion string
 var GeminiVisionMaxImageNum int
 var NotifyLimitCount int

+ 1 - 1
controller/discord.go

@@ -114,7 +114,7 @@ func DiscordOAuth(c *gin.Context) {
 		DiscordBind(c)
 		return
 	}
-		if !system_setting.GetDiscordSettings().Enabled {
+	if !system_setting.GetDiscordSettings().Enabled {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": "管理员未开启通过 Discord 登录以及注册",

+ 67 - 9
controller/relay.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"io"
 	"log"
@@ -104,7 +105,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 
 	request, err := helper.GetAndValidateRequest(c, relayFormat)
 	if err != nil {
-		newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
+		// Map "request body too large" to 413 so clients can handle it correctly
+		if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
+			newAPIError = types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
+		} else {
+			newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
+		}
 		return
 	}
 
@@ -114,9 +120,17 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		return
 	}
 
-	meta := request.GetTokenCountMeta()
+	needSensitiveCheck := setting.ShouldCheckPromptSensitive()
+	needCountToken := constant.CountToken
+	// Avoid building huge CombineText (strings.Join) when token counting and sensitive check are both disabled.
+	var meta *types.TokenCountMeta
+	if needSensitiveCheck || needCountToken {
+		meta = request.GetTokenCountMeta()
+	} else {
+		meta = fastTokenCountMetaForPricing(request)
+	}
 
-	if setting.ShouldCheckPromptSensitive() {
+	if needSensitiveCheck && meta != nil {
 		contains, words := service.CheckSensitiveText(meta.CombineText)
 		if contains {
 			logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
@@ -165,15 +179,24 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 	}
 
 	for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
-		channel, err := getChannel(c, relayInfo, retryParam)
-		if err != nil {
-			logger.LogError(c, err.Error())
-			newAPIError = err
+		channel, channelErr := getChannel(c, relayInfo, retryParam)
+		if channelErr != nil {
+			logger.LogError(c, channelErr.Error())
+			newAPIError = channelErr
 			break
 		}
 
 		addUsedChannel(c, channel.Id)
-		requestBody, _ := common.GetRequestBody(c)
+		requestBody, bodyErr := common.GetRequestBody(c)
+		if bodyErr != nil {
+			// Ensure consistent 413 for oversized bodies even when error occurs later (e.g., retry path)
+			if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
+				newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
+			} else {
+				newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+			}
+			break
+		}
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 
 		switch relayFormat {
@@ -218,6 +241,33 @@ func addUsedChannel(c *gin.Context, channelId int) {
 	c.Set("use_channel", useChannel)
 }
 
+func fastTokenCountMetaForPricing(request dto.Request) *types.TokenCountMeta {
+	if request == nil {
+		return &types.TokenCountMeta{}
+	}
+	meta := &types.TokenCountMeta{
+		TokenType: types.TokenTypeTokenizer,
+	}
+	switch r := request.(type) {
+	case *dto.GeneralOpenAIRequest:
+		if r.MaxCompletionTokens > r.MaxTokens {
+			meta.MaxTokens = int(r.MaxCompletionTokens)
+		} else {
+			meta.MaxTokens = int(r.MaxTokens)
+		}
+	case *dto.OpenAIResponsesRequest:
+		meta.MaxTokens = int(r.MaxOutputTokens)
+	case *dto.ClaudeRequest:
+		meta.MaxTokens = int(r.MaxTokens)
+	case *dto.ImageRequest:
+		// Pricing for image requests depends on ImagePriceRatio; safe to compute even when CountToken is disabled.
+		return r.GetTokenCountMeta()
+	default:
+		// Best-effort: leave CombineText empty to avoid large allocations.
+	}
+	return meta
+}
+
 func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) {
 	if info.ChannelMeta == nil {
 		autoBan := c.GetBool("auto_ban")
@@ -432,7 +482,15 @@ func RelayTask(c *gin.Context) {
 		logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
 		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 
-		requestBody, _ := common.GetRequestBody(c)
+		requestBody, err := common.GetRequestBody(c)
+		if err != nil {
+			if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
+				taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
+			} else {
+				taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest)
+			}
+			break
+		}
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 		taskErr = taskRelayHandler(c, relayInfo)
 	}

+ 2 - 2
controller/task.go

@@ -88,7 +88,7 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
 	for channelId, taskIds := range taskChannelM {
 		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
 		if err != nil {
-			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
+			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
 		}
 	}
 	return nil
@@ -141,7 +141,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
 		return err
 	}
 	if !responseItems.IsSuccess() {
-		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
+		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
 		return err
 	}
 

+ 3 - 3
controller/topup_creem.go

@@ -7,12 +7,12 @@ import (
 	"encoding/hex"
 	"encoding/json"
 	"fmt"
-	"io"
-	"log"
-	"net/http"
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/setting"
+	"io"
+	"log"
+	"net/http"
 	"time"
 
 	"github.com/gin-gonic/gin"

+ 1 - 1
middleware/distributor.go

@@ -162,7 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 			}
 			midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
 			if mjErr != nil {
-				return nil, false, fmt.Errorf(mjErr.Description)
+				return nil, false, fmt.Errorf("%s", mjErr.Description)
 			}
 			if midjourneyModel == "" {
 				if !success {

+ 44 - 7
middleware/gzip.go

@@ -5,32 +5,69 @@ import (
 	"io"
 	"net/http"
 
+	"github.com/QuantumNous/new-api/constant"
 	"github.com/andybalholm/brotli"
 	"github.com/gin-gonic/gin"
 )
 
+type readCloser struct {
+	io.Reader
+	closeFn func() error
+}
+
+func (rc *readCloser) Close() error {
+	if rc.closeFn != nil {
+		return rc.closeFn()
+	}
+	return nil
+}
+
 func DecompressRequestMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		if c.Request.Body == nil || c.Request.Method == http.MethodGet {
 			c.Next()
 			return
 		}
+		maxMB := constant.MaxRequestBodyMB
+		if maxMB <= 0 {
+			maxMB = 32
+		}
+		maxBytes := int64(maxMB) << 20
+
+		origBody := c.Request.Body
+		wrapMaxBytes := func(body io.ReadCloser) io.ReadCloser {
+			return http.MaxBytesReader(c.Writer, body, maxBytes)
+		}
+
 		switch c.GetHeader("Content-Encoding") {
 		case "gzip":
-			gzipReader, err := gzip.NewReader(c.Request.Body)
+			gzipReader, err := gzip.NewReader(origBody)
 			if err != nil {
+				_ = origBody.Close()
 				c.AbortWithStatus(http.StatusBadRequest)
 				return
 			}
-			defer gzipReader.Close()
-
-			// Replace the request body with the decompressed data
-			c.Request.Body = io.NopCloser(gzipReader)
+			// Replace the request body with the decompressed data, and enforce a max size (post-decompression).
+			c.Request.Body = wrapMaxBytes(&readCloser{
+				Reader: gzipReader,
+				closeFn: func() error {
+					_ = gzipReader.Close()
+					return origBody.Close()
+				},
+			})
 			c.Request.Header.Del("Content-Encoding")
 		case "br":
-			reader := brotli.NewReader(c.Request.Body)
-			c.Request.Body = io.NopCloser(reader)
+			reader := brotli.NewReader(origBody)
+			c.Request.Body = wrapMaxBytes(&readCloser{
+				Reader: reader,
+				closeFn: func() error {
+					return origBody.Close()
+				},
+			})
 			c.Request.Header.Del("Content-Encoding")
+		default:
+			// Even for uncompressed bodies, enforce a max size to avoid huge request allocations.
+			c.Request.Body = wrapMaxBytes(origBody)
 		}
 
 		// Continue processing the request

+ 1 - 1
relay/channel/aws/constants.go

@@ -18,7 +18,7 @@ var awsModelIDMap = map[string]string{
 	"claude-opus-4-1-20250805":   "anthropic.claude-opus-4-1-20250805-v1:0",
 	"claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
 	"claude-haiku-4-5-20251001":  "anthropic.claude-haiku-4-5-20251001-v1:0",
-	"claude-opus-4-5-20251101":  "anthropic.claude-opus-4-5-20251101-v1:0",
+	"claude-opus-4-5-20251101":   "anthropic.claude-opus-4-5-20251101-v1:0",
 	// Nova models
 	"nova-micro-v1:0":   "amazon.nova-micro-v1:0",
 	"nova-lite-v1:0":    "amazon.nova-lite-v1:0",

+ 2 - 2
relay/channel/baidu/relay-baidu.go

@@ -150,7 +150,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
+		return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
 	}
 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
@@ -175,7 +175,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
+		return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
 	}
 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)

+ 1 - 1
relay/channel/coze/relay-coze.go

@@ -208,7 +208,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
 			return
 		}
 
-		common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
+		common.SysLog(fmt.Sprintf("stream event error: %v %v", errorData.Code, errorData.Message))
 	}
 }
 

+ 1 - 1
relay/channel/task/jimeng/adaptor.go

@@ -196,7 +196,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	if jResp.Code != 10000 {
-		taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
+		taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
 		return
 	}
 

+ 1 - 1
relay/channel/task/kling/adaptor.go

@@ -186,7 +186,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 	if kResp.Code != 0 {
-		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
+		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("%s", kResp.Message), "task_failed", http.StatusBadRequest)
 		return
 	}
 	ov := dto.NewOpenAIVideo()

+ 1 - 1
relay/channel/task/suno/adaptor.go

@@ -105,7 +105,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 	if !sunoResponse.IsSuccess() {
-		taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
+		taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
 		return
 	}
 

+ 1 - 1
relay/relay_task.go

@@ -196,7 +196,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 	// handle response
 	if resp != nil && resp.StatusCode != http.StatusOK {
 		responseBody, _ := io.ReadAll(resp.Body)
-		taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
+		taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
 		return
 	}
 

+ 3 - 3
setting/system_setting/discord.go

@@ -3,9 +3,9 @@ package system_setting
 import "github.com/QuantumNous/new-api/setting/config"
 
 type DiscordSettings struct {
-	Enabled               bool   `json:"enabled"`
-	ClientId              string `json:"client_id"`
-	ClientSecret          string `json:"client_secret"`
+	Enabled      bool   `json:"enabled"`
+	ClientId     string `json:"client_id"`
+	ClientSecret string `json:"client_secret"`
 }
 
 // 默认配置

+ 8 - 0
types/error.go

@@ -94,6 +94,14 @@ type NewAPIError struct {
 	StatusCode     int
 }
 
+// Unwrap enables errors.Is / errors.As to work with NewAPIError by exposing the underlying error.
+func (e *NewAPIError) Unwrap() error {
+	if e == nil {
+		return nil
+	}
+	return e.Err
+}
+
 func (e *NewAPIError) GetErrorCode() ErrorCode {
 	if e == nil {
 		return ""