Browse Source

Throw BadHttpRequestException when decompressed data exceeds request body size limit (#61812)

feiyun0112 10 months ago
parent
commit
7e446c6adb

+ 3 - 1
src/Middleware/RequestDecompression/src/RequestDecompressionMiddleware.cs

@@ -62,7 +62,9 @@ internal sealed partial class RequestDecompressionMiddleware
                 context.GetEndpoint()?.Metadata?.GetMetadata<IRequestSizeLimitMetadata>()?.MaxRequestBodySize
                     ?? context.Features.Get<IHttpMaxRequestBodySizeFeature>()?.MaxRequestBodySize;
 
-            context.Request.Body = new SizeLimitedStream(decompressionStream, sizeLimit);
+            context.Request.Body = new SizeLimitedStream(decompressionStream, sizeLimit, static (long sizeLimit) => throw new BadHttpRequestException(
+                    $"The decompressed request body is larger than the request body size limit {sizeLimit}.",
+                    StatusCodes.Status413PayloadTooLarge));
             await _next(context);
         }
         finally

+ 4 - 4
src/Middleware/RequestDecompression/test/RequestDecompressionMiddlewareTests.cs

@@ -499,8 +499,8 @@ public class RequestDecompressionMiddlewareTests
         if (exceedsLimit)
         {
             Assert.NotNull(exception);
-            Assert.IsAssignableFrom<InvalidOperationException>(exception);
-            Assert.Equal("The maximum number of bytes have been read.", exception.Message);
+            Assert.IsAssignableFrom<BadHttpRequestException>(exception);
+            Assert.Equal(StatusCodes.Status413PayloadTooLarge, ((BadHttpRequestException)exception).StatusCode);
         }
         else
         {
@@ -583,8 +583,8 @@ public class RequestDecompressionMiddlewareTests
         if (exceedsLimit)
         {
             Assert.NotNull(exception);
-            Assert.IsAssignableFrom<InvalidOperationException>(exception);
-            Assert.Equal("The maximum number of bytes have been read.", exception.Message);
+            Assert.IsAssignableFrom<BadHttpRequestException>(exception);
+            Assert.Equal(StatusCodes.Status413PayloadTooLarge, ((BadHttpRequestException)exception).StatusCode);
         }
         else
         {

+ 21 - 4
src/Shared/SizeLimitedStream.cs

@@ -1,19 +1,22 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+#nullable enable
+
 internal sealed class SizeLimitedStream : Stream
 {
     private readonly Stream _innerStream;
     private readonly long? _sizeLimit;
-
+    private readonly Action<long>? _handleSizeLimit;
     private long _totalBytesRead;
 
-    public SizeLimitedStream(Stream innerStream, long? sizeLimit)
+    public SizeLimitedStream(Stream innerStream, long? sizeLimit, Action<long>? handleSizeLimit = null)
     {
         ArgumentNullException.ThrowIfNull(innerStream);
 
         _innerStream = innerStream;
         _sizeLimit = sizeLimit;
+        _handleSizeLimit = handleSizeLimit;
     }
 
     public override bool CanRead => _innerStream.CanRead;
@@ -48,7 +51,14 @@ internal sealed class SizeLimitedStream : Stream
         _totalBytesRead += bytesRead;
         if (_totalBytesRead > _sizeLimit)
         {
-            throw new InvalidOperationException("The maximum number of bytes have been read.");
+            if (_handleSizeLimit != null)
+            {
+                _handleSizeLimit(_sizeLimit.Value);
+            }
+            else
+            {
+                throw new InvalidOperationException("The maximum number of bytes have been read.");
+            }
         }
 
         return bytesRead;
@@ -81,7 +91,14 @@ internal sealed class SizeLimitedStream : Stream
         _totalBytesRead += bytesRead;
         if (_totalBytesRead > _sizeLimit)
         {
-            throw new InvalidOperationException("The maximum number of bytes have been read.");
+            if (_handleSizeLimit != null)
+            {
+                _handleSizeLimit(_sizeLimit.Value);
+            }
+            else
+            {
+                throw new InvalidOperationException("The maximum number of bytes have been read.");
+            }
         }
 
         return bytesRead;