Browse Source

ResponseCompression DisposeAsync BasicMiddleware/#247 (#4604)

Chris Ross 7 years ago
parent
commit
a08f4b5a83

+ 3 - 7
src/Middleware/ResponseCompression/src/BodyWrapperStream.cs

@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
 
 using System;
@@ -39,13 +39,9 @@ namespace Microsoft.AspNetCore.ResponseCompression
             _innerSendFileFeature = innerSendFileFeature;
         }
 
-        protected override void Dispose(bool disposing)
+        internal ValueTask FinishCompressionAsync()
         {
-            if (_compressionStream != null)
-            {
-                _compressionStream.Dispose();
-                _compressionStream = null;
-            }
+            return _compressionStream?.DisposeAsync() ?? new ValueTask();
         }
 
         public override bool CanRead => false;

+ 2 - 4
src/Middleware/ResponseCompression/src/ResponseCompressionMiddleware.cs

@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
 
 using System;
@@ -68,9 +68,7 @@ namespace Microsoft.AspNetCore.ResponseCompression
             try
             {
                 await _next(context);
-                // This is not disposed via a using statement because we don't want to flush the compression buffer for unhandled exceptions,
-                // that may cause secondary exceptions.
-                bodyWrapperStream.Dispose();
+                await bodyWrapperStream.FinishCompressionAsync();
             }
             finally
             {

+ 145 - 0
src/Middleware/ResponseCompression/test/ResponseCompressionMiddlewareTest.cs

@@ -902,6 +902,57 @@ namespace Microsoft.AspNetCore.ResponseCompression.Tests
             Assert.False(fakeSendFile.Invoked);
         }
 
+        [Theory]
+        [MemberData(nameof(SupportedEncodings))]
+        public async Task Dispose_SyncWriteOrFlushNotCalled(string encoding)
+        {
+            var responseReceived = new ManualResetEvent(false);
+
+            var builder = new WebHostBuilder()
+                .ConfigureServices(services =>
+                {
+                    services.AddResponseCompression();
+                })
+                .Configure(app =>
+                {
+                    app.Use((context, next) =>
+                    {
+                        context.Response.Body = new NoSyncWrapperStream(context.Response.Body);
+                        return next();
+                    });
+                    app.UseResponseCompression();
+                    app.Run(async context =>
+                    {
+                        context.Response.Headers[HeaderNames.ContentMD5] = "MD5";
+                        context.Response.ContentType = TextPlain;
+                        await context.Response.WriteAsync(new string('a', 10));
+                        await context.Response.Body.FlushAsync();
+                        Assert.True(responseReceived.WaitOne(TimeSpan.FromSeconds(3)));
+                        await context.Response.WriteAsync(new string('a', 90));
+                    });
+                });
+
+            var server = new TestServer(builder);
+            var client = server.CreateClient();
+
+            var request = new HttpRequestMessage(HttpMethod.Get, "");
+            request.Headers.AcceptEncoding.ParseAdd(encoding);
+
+            var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
+
+            Assert.False(response.Content.Headers.TryGetValues(HeaderNames.ContentMD5, out _));
+            Assert.Single(response.Content.Headers.ContentEncoding, encoding);
+
+            var body = await response.Content.ReadAsStreamAsync();
+            var read = await body.ReadAsync(new byte[100], 0, 100);
+            Assert.True(read > 0);
+
+            responseReceived.Set();
+
+            read = await body.ReadAsync(new byte[100], 0, 100);
+            Assert.True(read > 0);
+        }
+
         private async Task<(HttpResponseMessage, List<WriteContext>)> InvokeMiddleware(
             int uncompressedBodyLength,
             string[] requestAcceptEncodings,
@@ -1039,5 +1090,99 @@ namespace Microsoft.AspNetCore.ResponseCompression.Tests
 
             public int ExpectedBodyLength { get; }
         }
+
+        private class NoSyncWrapperStream : Stream
+        {
+            private Stream _body;
+
+            public NoSyncWrapperStream(Stream body)
+            {
+                _body = body;
+            }
+
+            public override bool CanRead => _body.CanRead;
+
+            public override bool CanSeek => _body.CanSeek;
+
+            public override bool CanWrite => _body.CanWrite;
+
+            public override long Length => _body.Length;
+
+            public override long Position
+            {
+                get => throw new InvalidOperationException("This shouldn't be called");
+                set => throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            public override void Flush()
+            {
+                throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            public override int Read(byte[] buffer, int offset, int count)
+            {
+                throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            public override long Seek(long offset, SeekOrigin origin)
+            {
+                throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            public override void SetLength(long value)
+            {
+                throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            public override void Write(byte[] buffer, int offset, int count)
+            {
+                throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+            {
+                return _body.WriteAsync(buffer, offset, count, cancellationToken);
+            }
+
+            public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+            {
+                return _body.WriteAsync(buffer, cancellationToken);
+            }
+
+            public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
+            {
+                return _body.BeginWrite(buffer, offset, count, callback, state);
+            }
+
+            public override void EndWrite(IAsyncResult asyncResult)
+            {
+                _body.EndWrite(asyncResult);
+            }
+
+            public override void Close()
+            {
+                throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            protected override void Dispose(bool disposing)
+            {
+                throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            public override ValueTask DisposeAsync()
+            {
+                return _body.DisposeAsync();
+            }
+
+            public override void CopyTo(Stream destination, int bufferSize)
+            {
+                throw new InvalidOperationException("This shouldn't be called");
+            }
+
+            public override Task FlushAsync(CancellationToken cancellationToken)
+            {
+                return _body.FlushAsync(cancellationToken);
+            }
+        }
     }
 }