Jelajahi Sumber

ResponseCaching: started conversion to pipes (#16961)

* ResponseCaching: started conversion to pipes

* nits

* Use span instead of memory

* CachedResponseBody Tests

* Benchmark

* Reworked benchmark

* Addressed feedback

* Increased timeout
Alessio Franceschelli 6 tahun lalu
induk
melakukan
c848c33cfa

+ 15 - 0
src/Middleware/Middleware.sln

@@ -297,6 +297,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.WebSoc
 EndProject
 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Perf", "Perf", "{4623F52E-2070-4631-8DEE-7D2F48733FFD}"
 EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.ResponseCaching.Microbenchmarks", "perf\ResponseCaching.Microbenchmarks\Microsoft.AspNetCore.ResponseCaching.Microbenchmarks.csproj", "{80C8E810-1206-482E-BE17-961DD2EBFB11}"
+EndProject
 Global
 	GlobalSection(SolutionConfigurationPlatforms) = preSolution
 		Debug|Any CPU = Debug|Any CPU
@@ -1615,6 +1617,18 @@ Global
 		{C4D624B3-749E-41D8-A43B-B304BC3885EA}.Release|x64.Build.0 = Release|Any CPU
 		{C4D624B3-749E-41D8-A43B-B304BC3885EA}.Release|x86.ActiveCfg = Release|Any CPU
 		{C4D624B3-749E-41D8-A43B-B304BC3885EA}.Release|x86.Build.0 = Release|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Debug|Any CPU.Build.0 = Debug|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Debug|x64.ActiveCfg = Debug|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Debug|x64.Build.0 = Debug|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Debug|x86.ActiveCfg = Debug|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Debug|x86.Build.0 = Debug|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Release|Any CPU.ActiveCfg = Release|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Release|Any CPU.Build.0 = Release|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Release|x64.ActiveCfg = Release|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Release|x64.Build.0 = Release|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Release|x86.ActiveCfg = Release|Any CPU
+		{80C8E810-1206-482E-BE17-961DD2EBFB11}.Release|x86.Build.0 = Release|Any CPU
 	EndGlobalSection
 	GlobalSection(SolutionProperties) = preSolution
 		HideSolutionNode = FALSE
@@ -1742,6 +1756,7 @@ Global
 		{92E11EBB-759E-4DA8-AB61-A9977D9F97D0} = {ACA6DDB9-7592-47CE-A740-D15BF307E9E0}
 		{D0CB733B-4CE8-4F6C-BBB9-548EA1A96966} = {D6FA4ABE-E685-4EDD-8B06-D8777E76B472}
 		{C4D624B3-749E-41D8-A43B-B304BC3885EA} = {4623F52E-2070-4631-8DEE-7D2F48733FFD}
+		{80C8E810-1206-482E-BE17-961DD2EBFB11} = {4623F52E-2070-4631-8DEE-7D2F48733FFD}
 	EndGlobalSection
 	GlobalSection(ExtensibilityGlobals) = postSolution
 		SolutionGuid = {83786312-A93B-4BB4-AB06-7C6913A59AFA}

+ 1 - 2
src/Middleware/ResponseCaching/src/CacheEntry/CachedResponse.cs

@@ -2,7 +2,6 @@
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
 
 using System;
-using System.IO;
 using Microsoft.AspNetCore.Http;
 
 namespace Microsoft.AspNetCore.ResponseCaching
@@ -15,6 +14,6 @@ namespace Microsoft.AspNetCore.ResponseCaching
 
         public IHeaderDictionary Headers { get; set; }
 
-        public Stream Body { get; set; }
+        public CachedResponseBody Body { get; set; }
     }
 }

+ 49 - 0
src/Middleware/ResponseCaching/src/CacheEntry/CachedResponseBody.cs

@@ -0,0 +1,49 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using System.IO.Pipelines;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.AspNetCore.ResponseCaching
+{
+    internal class CachedResponseBody
+    {
+        public CachedResponseBody(List<byte[]> segments, long length)
+        {
+            Segments = segments;
+            Length = length;
+        }
+
+        public List<byte[]> Segments { get; }
+
+        public long Length { get; }
+
+        public async Task CopyToAsync(PipeWriter destination, CancellationToken cancellationToken)
+        {
+            if (destination == null)
+            {
+                throw new ArgumentNullException(nameof(destination));
+            }
+
+            foreach (var segment in Segments)
+            {
+                cancellationToken.ThrowIfCancellationRequested();
+
+                Copy(segment, destination);
+
+                await destination.FlushAsync();
+            }
+        }
+
+        private static void Copy(byte[] segment, PipeWriter destination)
+        {
+            var span = destination.GetSpan(segment.Length);
+
+            segment.CopyTo(span);
+            destination.Advance(segment.Length);
+        }
+    }
+}

+ 2 - 4
src/Middleware/ResponseCaching/src/MemoryCachedResponse.cs

@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
 
 using System;
@@ -15,8 +15,6 @@ namespace Microsoft.AspNetCore.ResponseCaching
 
         public IHeaderDictionary Headers { get; set; } = new HeaderDictionary();
 
-        public List<byte[]> BodySegments { get; set; }
-
-        public long BodyLength { get; set; }
+        public CachedResponseBody Body { get; set; }
     }
 }

+ 3 - 6
src/Middleware/ResponseCaching/src/MemoryResponseCache.cs

@@ -2,6 +2,7 @@
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
 
 using System;
+using System.Buffers;
 using System.Threading.Tasks;
 using Microsoft.Extensions.Caching.Memory;
 
@@ -27,7 +28,7 @@ namespace Microsoft.AspNetCore.ResponseCaching
                     Created = memoryCachedResponse.Created,
                     StatusCode = memoryCachedResponse.StatusCode,
                     Headers = memoryCachedResponse.Headers,
-                    Body = new SegmentReadStream(memoryCachedResponse.BodySegments, memoryCachedResponse.BodyLength)
+                    Body = memoryCachedResponse.Body
                 };
             }
             else
@@ -40,9 +41,6 @@ namespace Microsoft.AspNetCore.ResponseCaching
         {
             if (entry is CachedResponse cachedResponse)
             {
-                var segmentStream = new SegmentWriteStream(StreamUtilities.BodySegmentSize);
-                cachedResponse.Body.CopyTo(segmentStream);
-
                 _cache.Set(
                     key,
                     new MemoryCachedResponse
@@ -50,8 +48,7 @@ namespace Microsoft.AspNetCore.ResponseCaching
                         Created = cachedResponse.Created,
                         StatusCode = cachedResponse.StatusCode,
                         Headers = cachedResponse.Headers,
-                        BodySegments = segmentStream.GetSegments(),
-                        BodyLength = segmentStream.Length
+                        Body = cachedResponse.Body
                     },
                     new MemoryCacheEntryOptions
                     {

+ 6 - 6
src/Middleware/ResponseCaching/src/ResponseCachingMiddleware.cs

@@ -192,7 +192,7 @@ namespace Microsoft.AspNetCore.ResponseCaching
                     {
                         try
                         {
-                            await body.CopyToAsync(response.Body, StreamUtilities.BodySegmentSize, context.HttpContext.RequestAborted);
+                            await body.CopyToAsync(response.BodyWriter, context.HttpContext.RequestAborted);
                         }
                         catch (OperationCanceledException)
                         {
@@ -343,19 +343,19 @@ namespace Microsoft.AspNetCore.ResponseCaching
             if (context.ShouldCacheResponse && context.ResponseCachingStream.BufferingEnabled)
             {
                 var contentLength = context.HttpContext.Response.ContentLength;
-                var bufferStream = context.ResponseCachingStream.GetBufferStream();
-                if (!contentLength.HasValue || contentLength == bufferStream.Length
-                    || (bufferStream.Length == 0
+                var cachedResponseBody = context.ResponseCachingStream.GetCachedResponseBody();
+                if (!contentLength.HasValue || contentLength == cachedResponseBody.Length
+                    || (cachedResponseBody.Length == 0
                         && HttpMethods.IsHead(context.HttpContext.Request.Method)))
                 {
                     var response = context.HttpContext.Response;
                     // Add a content-length if required
                     if (!response.ContentLength.HasValue && StringValues.IsNullOrEmpty(response.Headers[HeaderNames.TransferEncoding]))
                     {
-                        context.CachedResponse.Headers[HeaderNames.ContentLength] = HeaderUtilities.FormatNonNegativeInt64(bufferStream.Length);
+                        context.CachedResponse.Headers[HeaderNames.ContentLength] = HeaderUtilities.FormatNonNegativeInt64(cachedResponseBody.Length);
                     }
 
-                    context.CachedResponse.Body = bufferStream;
+                    context.CachedResponse.Body = cachedResponseBody;
                     _logger.ResponseCached();
                     _cache.Set(context.StorageVaryKey ?? context.BaseKey, context.CachedResponse, context.CachedResponseValidFor);
                 }

+ 2 - 2
src/Middleware/ResponseCaching/src/Streams/ResponseCachingStream.cs

@@ -45,13 +45,13 @@ namespace Microsoft.AspNetCore.ResponseCaching
             }
         }
 
-        internal Stream GetBufferStream()
+        internal CachedResponseBody GetCachedResponseBody()
         {
             if (!BufferingEnabled)
             {
                 throw new InvalidOperationException("Buffer stream cannot be retrieved since buffering is disabled.");
             }
-            return new SegmentReadStream(_segmentWriteStream.GetSegments(), _segmentWriteStream.Length);
+            return new CachedResponseBody(_segmentWriteStream.GetSegments(), _segmentWriteStream.Length);
         }
 
         internal void DisableBuffering()

+ 0 - 225
src/Middleware/ResponseCaching/src/Streams/SegmentReadStream.cs

@@ -1,225 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace Microsoft.AspNetCore.ResponseCaching
-{
-    internal class SegmentReadStream : Stream
-    {
-        private readonly List<byte[]> _segments;
-        private readonly long _length;
-        private int _segmentIndex;
-        private int _segmentOffset;
-        private long _position;
-
-        internal SegmentReadStream(List<byte[]> segments, long length)
-        {
-            _segments = segments ?? throw new ArgumentNullException(nameof(segments));
-            _length = length;
-        }
-
-        public override bool CanRead => true;
-
-        public override bool CanSeek => true;
-
-        public override bool CanWrite => false;
-
-        public override long Length => _length;
-
-        public override long Position
-        {
-            get
-            {
-                return _position;
-            }
-            set
-            {
-                // The stream only supports a full rewind. This will need an update if random access becomes a required feature.
-                if (value != 0)
-                {
-                    throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(Position)} can only be set to 0.");
-                }
-
-                _position = 0;
-                _segmentOffset = 0;
-                _segmentIndex = 0;
-            }
-        }
-
-        public override void Flush()
-        {
-            throw new NotSupportedException("The stream does not support writing.");
-        }
-
-        public override int Read(byte[] buffer, int offset, int count)
-        {
-            if (buffer == null)
-            {
-                throw new ArgumentNullException(nameof(buffer));
-            }
-            if (offset < 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(offset), offset, "Non-negative number required.");
-            }
-            // Read of length 0 will return zero and indicate end of stream.
-            if (count <= 0 )
-            {
-                throw new ArgumentOutOfRangeException(nameof(count), count, "Positive number required.");
-            }
-            if (count > buffer.Length - offset)
-            {
-                throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection.");
-            }
-
-            if (_segmentIndex == _segments.Count)
-            {
-                return 0;
-            }
-
-            var bytesRead = 0;
-            while (count > 0)
-            {
-                if (_segmentOffset == _segments[_segmentIndex].Length)
-                {
-                    // Move to the next segment
-                    _segmentIndex++;
-                    _segmentOffset = 0;
-
-                    if (_segmentIndex == _segments.Count)
-                    {
-                        break;
-                    }
-                }
-
-                // Read up to the end of the segment
-                var segmentBytesRead = Math.Min(count, _segments[_segmentIndex].Length - _segmentOffset);
-                Buffer.BlockCopy(_segments[_segmentIndex], _segmentOffset, buffer, offset, segmentBytesRead);
-                bytesRead += segmentBytesRead;
-                _segmentOffset += segmentBytesRead;
-                _position += segmentBytesRead;
-                offset += segmentBytesRead;
-                count -= segmentBytesRead;
-            }
-
-            return bytesRead;
-        }
-
-        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
-        {
-            return Task.FromResult(Read(buffer, offset, count));
-        }
-
-        public override int ReadByte()
-        {
-            if (Position == Length)
-            {
-                return -1;
-            }
-
-            if (_segmentOffset == _segments[_segmentIndex].Length)
-            {
-                // Move to the next segment
-                _segmentIndex++;
-                _segmentOffset = 0;
-            }
-
-            var byteRead = _segments[_segmentIndex][_segmentOffset];
-            _segmentOffset++;
-            _position++;
-
-            return byteRead;
-        }
-
-        public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
-        {
-            var tcs = new TaskCompletionSource<int>(state);
-
-            try
-            {
-                tcs.TrySetResult(Read(buffer, offset, count));
-            }
-            catch (Exception ex)
-            {
-                tcs.TrySetException(ex);
-            }
-
-            if (callback != null)
-            {
-                // Offload callbacks to avoid stack dives on sync completions.
-                var ignored = Task.Run(() =>
-                {
-                    try
-                    {
-                        callback(tcs.Task);
-                    }
-                    catch (Exception)
-                    {
-                        // Suppress exceptions on background threads.
-                    }
-                });
-            }
-
-            return tcs.Task;
-        }
-
-        public override int EndRead(IAsyncResult asyncResult)
-        {
-            if (asyncResult == null)
-            {
-                throw new ArgumentNullException(nameof(asyncResult));
-            }
-            return ((Task<int>)asyncResult).GetAwaiter().GetResult();
-        }
-
-        public override long Seek(long offset, SeekOrigin origin)
-        {
-            // The stream only supports a full rewind. This will need an update if random access becomes a required feature.
-            if (origin != SeekOrigin.Begin)
-            {
-                throw new ArgumentException(nameof(origin), $"{nameof(Seek)} can only be set to {nameof(SeekOrigin.Begin)}.");
-            }
-            if (offset != 0)
-            {
-                throw new ArgumentOutOfRangeException(nameof(offset), offset, $"{nameof(Seek)} can only be set to 0.");
-            }
-
-            Position = 0;
-            return Position;
-        }
-
-        public override void SetLength(long value)
-        {
-            throw new NotSupportedException("The stream does not support writing.");
-        }
-
-        public override void Write(byte[] buffer, int offset, int count)
-        {
-            throw new NotSupportedException("The stream does not support writing.");
-        }
-
-        public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
-        {
-            if (destination == null)
-            {
-                throw new ArgumentNullException(nameof(destination));
-            }
-            if (!destination.CanWrite)
-            {
-                throw new NotSupportedException("The destination stream does not support writing.");
-            }
-
-            for (; _segmentIndex < _segments.Count; _segmentIndex++, _segmentOffset = 0)
-            {
-                cancellationToken.ThrowIfCancellationRequested();
-                var bytesCopied = _segments[_segmentIndex].Length - _segmentOffset;
-                await destination.WriteAsync(_segments[_segmentIndex], _segmentOffset, bytesCopied, cancellationToken);
-                _position += bytesCopied;
-            }
-        }
-    }
-}

+ 128 - 0
src/Middleware/ResponseCaching/test/CachedResponseBodyTests.cs

@@ -0,0 +1,128 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO.Pipelines;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Microsoft.AspNetCore.ResponseCaching.Tests
+{
+    public class CachedResponseBodyTests
+    {
+        private readonly int _timeout = Debugger.IsAttached ? -1 : 5000;
+
+        [Fact]
+        public void GetSegments()
+        {
+            var segments = new List<byte[]>();
+            var body = new CachedResponseBody(segments, 0);
+
+            Assert.Same(segments, body.Segments);
+        }
+
+        [Fact]
+        public void GetLength()
+        {
+            var segments = new List<byte[]>();
+            var body = new CachedResponseBody(segments, 42);
+
+            Assert.Equal(42, body.Length);
+        }
+
+        [Fact]
+        public async Task Copy_DoNothingWhenNoSegments()
+        {
+            var segments = new List<byte[]>();
+            var receivedSegments = new List<byte[]>();
+            var body = new CachedResponseBody(segments, 0);
+
+            var pipe = new Pipe();
+            using var cts = new CancellationTokenSource(_timeout);
+
+            var receiverTask = ReceiveDataAsync(pipe.Reader, receivedSegments, cts.Token);
+            var copyTask = body.CopyToAsync(pipe.Writer, cts.Token).ContinueWith(_ => pipe.Writer.CompleteAsync());
+
+            await Task.WhenAll(receiverTask, copyTask);
+
+            Assert.Empty(receivedSegments);
+        }
+
+        [Fact]
+        public async Task Copy_SingleSegment()
+        {
+            var segments = new List<byte[]>
+            {
+                new byte[] { 1 }
+            };
+            var receivedSegments = new List<byte[]>();
+            var body = new CachedResponseBody(segments, 0);
+
+            var pipe = new Pipe();
+
+            using var cts = new CancellationTokenSource(_timeout);
+
+            var receiverTask = ReceiveDataAsync(pipe.Reader, receivedSegments, cts.Token);
+            var copyTask = CopyDataAsync(body, pipe.Writer, cts.Token);
+
+            await Task.WhenAll(receiverTask, copyTask);
+
+            Assert.Equal(segments, receivedSegments);
+        }
+
+        [Fact]
+        public async Task Copy_MultipleSegments()
+        {
+            var segments = new List<byte[]>
+            {
+                new byte[] { 1 },
+                new byte[] { 2, 3 }
+            };
+            var receivedSegments = new List<byte[]>();
+            var body = new CachedResponseBody(segments, 0);
+
+            var pipe = new Pipe();
+
+            using var cts = new CancellationTokenSource(_timeout);
+
+            var receiverTask = ReceiveDataAsync(pipe.Reader, receivedSegments, cts.Token);
+            var copyTask = CopyDataAsync(body, pipe.Writer, cts.Token);
+
+            await Task.WhenAll(receiverTask, copyTask);
+
+            Assert.Equal(new byte[] { 1, 2, 3 }, receivedSegments.SelectMany(x => x).ToArray());
+        }
+
+        async Task CopyDataAsync(CachedResponseBody body, PipeWriter writer, CancellationToken cancellationToken)
+        {
+            await body.CopyToAsync(writer, cancellationToken);
+            await writer.CompleteAsync();
+        }
+
+        async Task ReceiveDataAsync(PipeReader reader, List<byte[]> receivedSegments, CancellationToken cancellationToken)
+        {
+            while (true)
+            {
+                var result = await reader.ReadAsync(cancellationToken);
+                var buffer = result.Buffer;
+
+                foreach(var memory in buffer)
+                {
+                    receivedSegments.Add(memory.ToArray());
+                }
+
+                reader.AdvanceTo(buffer.End, buffer.End);
+
+                if (result.IsCompleted)
+                {
+                    break;
+                }
+            }
+            await reader.CompleteAsync();
+        }
+    }
+}

+ 4 - 4
src/Middleware/ResponseCaching/test/ResponseCachingMiddlewareTests.cs

@@ -63,7 +63,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests
                 new CachedResponse()
                 {
                     Headers = new HeaderDictionary(),
-                    Body = new SegmentReadStream(new List<byte[]>(0), 0)
+                    Body = new CachedResponseBody(new List<byte[]>(0), 0)
                 },
                 TimeSpan.Zero);
 
@@ -91,7 +91,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests
                     {
                         { "MyHeader", "NewValue" }
                     },
-                    Body = new SegmentReadStream(new List<byte[]>(0), 0)
+                    Body = new CachedResponseBody(new List<byte[]>(0), 0)
                 },
                 TimeSpan.Zero);
 
@@ -140,7 +140,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests
                 new CachedResponse()
                 {
                     Headers = new HeaderDictionary(),
-                    Body = new SegmentReadStream(new List<byte[]>(0), 0)
+                    Body = new CachedResponseBody(new List<byte[]>(0), 0)
                 },
                 TimeSpan.Zero);
 
@@ -164,7 +164,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests
                 "BaseKey",
                 new CachedResponse()
                 {
-                    Body = new SegmentReadStream(new List<byte[]>(0), 0)
+                    Body = new CachedResponseBody(new List<byte[]>(0), 0)
                 },
                 TimeSpan.Zero);
 

+ 0 - 284
src/Middleware/ResponseCaching/test/SegmentReadStreamTests.cs

@@ -1,284 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Linq;
-using Xunit;
-
-namespace Microsoft.AspNetCore.ResponseCaching.Tests
-{
-    public class SegmentReadStreamTests
-    {
-        public class TestStreamInitInfo
-        {
-            internal List<byte[]> Segments { get; set; }
-            internal int SegmentSize { get; set; }
-            internal long Length { get; set; }
-        }
-
-        public static TheoryData<TestStreamInitInfo> TestStreams
-        {
-            get
-            {
-                return new TheoryData<TestStreamInitInfo>
-                {
-                    // Partial Segment
-                    new TestStreamInitInfo()
-                    {
-                        Segments = new List<byte[]>(new[]
-                        {
-                            new byte[] { 0, 1, 2, 3, 4 },
-                            new byte[] { 5, 6, 7, 8, 9 },
-                            new byte[] { 10, 11, 12 },
-                        }),
-                        SegmentSize = 5,
-                        Length = 13
-                    },
-                    // Full Segments
-                    new TestStreamInitInfo()
-                    {
-                        Segments = new List<byte[]>(new[]
-                        {
-                            new byte[] { 0, 1, 2, 3, 4 },
-                            new byte[] { 5, 6, 7, 8, 9 },
-                            new byte[] { 10, 11, 12, 13, 14 },
-                        }),
-                        SegmentSize = 5,
-                        Length = 15
-                    }
-                };
-            }
-        }
-
-        [Fact]
-        public void SegmentReadStream_NullSegments_Throws()
-        {
-            Assert.Throws<ArgumentNullException>(() => new SegmentReadStream(null, 0));
-        }
-
-        [Fact]
-        public void Position_ResetToZero_Succeeds()
-        {
-            var stream = new SegmentReadStream(new List<byte[]>(), 0);
-
-            // This should not throw
-            stream.Position = 0;
-        }
-
-        [Theory]
-        [InlineData(1)]
-        [InlineData(-1)]
-        [InlineData(100)]
-        [InlineData(long.MaxValue)]
-        [InlineData(long.MinValue)]
-        public void Position_SetToNonZero_Throws(long position)
-        {
-            var stream = new SegmentReadStream(new List<byte[]>(new[] { new byte[100] }), 100);
-
-            Assert.Throws<ArgumentOutOfRangeException>(() => stream.Position = position);
-        }
-
-        [Fact]
-        public void WriteOperations_Throws()
-        {
-            var stream = new SegmentReadStream(new List<byte[]>(), 0);
-
-
-            Assert.Throws<NotSupportedException>(() => stream.Flush());
-            Assert.Throws<NotSupportedException>(() => stream.Write(new byte[1], 0, 0));
-        }
-
-        [Fact]
-        public void SetLength_Throws()
-        {
-            var stream = new SegmentReadStream(new List<byte[]>(), 0);
-
-            Assert.Throws<NotSupportedException>(() => stream.SetLength(0));
-        }
-
-        [Theory]
-        [InlineData(SeekOrigin.Current)]
-        [InlineData(SeekOrigin.End)]
-        public void Seek_NotBegin_Throws(SeekOrigin origin)
-        {
-            var stream = new SegmentReadStream(new List<byte[]>(), 0);
-
-            Assert.Throws<ArgumentException>(() => stream.Seek(0, origin));
-        }
-
-        [Theory]
-        [InlineData(1)]
-        [InlineData(-1)]
-        [InlineData(100)]
-        [InlineData(long.MaxValue)]
-        [InlineData(long.MinValue)]
-        public void Seek_NotZero_Throws(long offset)
-        {
-            var stream = new SegmentReadStream(new List<byte[]>(), 0);
-
-            Assert.Throws<ArgumentOutOfRangeException>(() => stream.Seek(offset, SeekOrigin.Begin));
-        }
-
-        [Theory]
-        [MemberData(nameof(TestStreams))]
-        public void ReadByte_CanReadAllBytes(TestStreamInitInfo info)
-        {
-            var stream = new SegmentReadStream(info.Segments, info.Length);
-
-            for (var i = 0; i < stream.Length; i++)
-            {
-                Assert.Equal(i, stream.Position);
-                Assert.Equal(i, stream.ReadByte());
-            }
-            Assert.Equal(stream.Length, stream.Position);
-            Assert.Equal(-1, stream.ReadByte());
-            Assert.Equal(stream.Length, stream.Position);
-        }
-
-        [Theory]
-        [MemberData(nameof(TestStreams))]
-        public void Read_CountLessThanSegmentSize_CanReadAllBytes(TestStreamInitInfo info)
-        {
-            var stream = new SegmentReadStream(info.Segments, info.Length);
-            var count = info.SegmentSize - 1;
-
-            for (var i = 0; i < stream.Length; i+=count)
-            {
-                var output = new byte[count];
-                var expectedOutput = new byte[count];
-                var expectedBytesRead = Math.Min(count, stream.Length - i);
-                for (var j = 0; j < expectedBytesRead; j++)
-                {
-                    expectedOutput[j] = (byte)(i + j);
-                }
-                Assert.Equal(i, stream.Position);
-                Assert.Equal(expectedBytesRead, stream.Read(output, 0, count));
-                Assert.True(expectedOutput.SequenceEqual(output));
-            }
-            Assert.Equal(stream.Length, stream.Position);
-            Assert.Equal(0, stream.Read(new byte[count], 0, count));
-            Assert.Equal(stream.Length, stream.Position);
-        }
-
-        [Theory]
-        [MemberData(nameof(TestStreams))]
-        public void Read_CountEqualSegmentSize_CanReadAllBytes(TestStreamInitInfo info)
-        {
-            var stream = new SegmentReadStream(info.Segments, info.Length);
-            var count = info.SegmentSize;
-
-            for (var i = 0; i < stream.Length; i += count)
-            {
-                var output = new byte[count];
-                var expectedOutput = new byte[count];
-                var expectedBytesRead = Math.Min(count, stream.Length - i);
-                for (var j = 0; j < expectedBytesRead; j++)
-                {
-                    expectedOutput[j] = (byte)(i + j);
-                }
-                Assert.Equal(i, stream.Position);
-                Assert.Equal(expectedBytesRead, stream.Read(output, 0, count));
-                Assert.True(expectedOutput.SequenceEqual(output));
-            }
-            Assert.Equal(stream.Length, stream.Position);
-            Assert.Equal(0, stream.Read(new byte[count], 0, count));
-            Assert.Equal(stream.Length, stream.Position);
-        }
-
-        [Theory]
-        [MemberData(nameof(TestStreams))]
-        public void Read_CountGreaterThanSegmentSize_CanReadAllBytes(TestStreamInitInfo info)
-        {
-            var stream = new SegmentReadStream(info.Segments, info.Length);
-            var count = info.SegmentSize + 1;
-
-            for (var i = 0; i < stream.Length; i += count)
-            {
-                var output = new byte[count];
-                var expectedOutput = new byte[count];
-                var expectedBytesRead = Math.Min(count, stream.Length - i);
-                for (var j = 0; j < expectedBytesRead; j++)
-                {
-                    expectedOutput[j] = (byte)(i + j);
-                }
-                Assert.Equal(i, stream.Position);
-                Assert.Equal(expectedBytesRead, stream.Read(output, 0, count));
-                Assert.True(expectedOutput.SequenceEqual(output));
-            }
-            Assert.Equal(stream.Length, stream.Position);
-            Assert.Equal(0, stream.Read(new byte[count], 0, count));
-            Assert.Equal(stream.Length, stream.Position);
-        }
-
-        [Theory]
-        [MemberData(nameof(TestStreams))]
-        public void CopyToAsync_CopiesAllBytes(TestStreamInitInfo info)
-        {
-            var stream = new SegmentReadStream(info.Segments, info.Length);
-            var writeStream = new SegmentWriteStream(info.SegmentSize);
-
-            stream.CopyTo(writeStream);
-
-            Assert.Equal(stream.Length, stream.Position);
-            Assert.Equal(stream.Length, writeStream.Length);
-            var writeSegments = writeStream.GetSegments();
-            for (var i = 0; i < info.Segments.Count; i++)
-            {
-                Assert.True(writeSegments[i].SequenceEqual(info.Segments[i]));
-            }
-        }
-
-        [Theory]
-        [MemberData(nameof(TestStreams))]
-        public void CopyToAsync_CopiesFromCurrentPosition(TestStreamInitInfo info)
-        {
-            var skippedBytes = info.SegmentSize;
-            var writeStream = new SegmentWriteStream((int)info.Length);
-            var stream = new SegmentReadStream(info.Segments, info.Length);
-            stream.Read(new byte[skippedBytes], 0, skippedBytes);
-
-            stream.CopyTo(writeStream);
-
-            Assert.Equal(stream.Length, stream.Position);
-            Assert.Equal(stream.Length - skippedBytes, writeStream.Length);
-            var writeSegments = writeStream.GetSegments();
-
-            for (var i = skippedBytes; i < info.Length; i++)
-            {
-                Assert.Equal(info.Segments[i / info.SegmentSize][i % info.SegmentSize], writeSegments[0][i - skippedBytes]);
-            }
-        }
-
-        [Theory]
-        [MemberData(nameof(TestStreams))]
-        public void CopyToAsync_CopiesFromStart_AfterReset(TestStreamInitInfo info)
-        {
-            var skippedBytes = info.SegmentSize;
-            var writeStream = new SegmentWriteStream(info.SegmentSize);
-            var stream = new SegmentReadStream(info.Segments, info.Length);
-            stream.Read(new byte[skippedBytes], 0, skippedBytes);
-
-            stream.CopyTo(writeStream);
-
-            // Assert bytes read from current location to the end
-            Assert.Equal(stream.Length, stream.Position);
-            Assert.Equal(stream.Length - skippedBytes, writeStream.Length);
-
-            // Reset
-            stream.Position = 0;
-            writeStream = new SegmentWriteStream(info.SegmentSize);
-
-            stream.CopyTo(writeStream);
-
-            Assert.Equal(stream.Length, stream.Position);
-            Assert.Equal(stream.Length, writeStream.Length);
-            var writeSegments = writeStream.GetSegments();
-            for (var i = 0; i < info.Segments.Count; i++)
-            {
-                Assert.True(writeSegments[i].SequenceEqual(info.Segments[i]));
-            }
-        }
-    }
-}

+ 1 - 0
src/Middleware/perf/ResponseCaching.Microbenchmarks/AssemblyInfo.cs

@@ -0,0 +1 @@
+[assembly: BenchmarkDotNet.Attributes.AspNetCoreBenchmark]

+ 13 - 0
src/Middleware/perf/ResponseCaching.Microbenchmarks/Microsoft.AspNetCore.ResponseCaching.Microbenchmarks.csproj

@@ -0,0 +1,13 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <OutputType>Exe</OutputType>
+    <TargetFramework>$(DefaultNetCoreTargetFramework)</TargetFramework>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <Reference Include="BenchmarkDotNet" />
+    <Reference Include="Microsoft.AspNetCore.BenchmarkRunner.Sources" />
+    <Reference Include="Microsoft.AspNetCore.ResponseCaching" />
+  </ItemGroup>
+</Project>

+ 144 - 0
src/Middleware/perf/ResponseCaching.Microbenchmarks/ResponseCachingBenchmark.cs

@@ -0,0 +1,144 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.IO;
+using System.IO.Pipelines;
+using System.Threading;
+using System.Threading.Tasks;
+using BenchmarkDotNet.Attributes;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Features;
+using Microsoft.AspNetCore.ResponseCaching;
+using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.ObjectPool;
+using Microsoft.Extensions.Options;
+using Microsoft.Net.Http.Headers;
+
+namespace Microsoft.AspNetCore.WebSockets.Microbenchmarks
+{
+    public class ResponseCachingBenchmark
+    {
+        private static readonly string _cacheControl = $"{CacheControlHeaderValue.PublicString}, {CacheControlHeaderValue.MaxAgeString}={int.MaxValue}";
+
+        private ResponseCachingMiddleware _middleware;
+        private readonly byte[] _data = new byte[1 * 1024 * 1024];
+
+        [Params(
+            100,
+            64 * 1024,
+            1 * 1024 * 1024
+        )]
+        public int Size { get; set; }
+
+        [GlobalSetup]
+        public void Setup()
+        {
+            _middleware = new ResponseCachingMiddleware(
+                    async context => {
+                        context.Response.Headers[HeaderNames.CacheControl] = _cacheControl;
+                        await context.Response.BodyWriter.WriteAsync(new ReadOnlyMemory<byte>(_data, 0, Size));
+                    },
+                    Options.Create(new ResponseCachingOptions
+                    {
+                        SizeLimit = int.MaxValue, // ~2GB
+                        MaximumBodySize = 1 * 1024 * 1024,
+                    }),
+                    NullLoggerFactory.Instance,
+                    new DefaultObjectPoolProvider()
+                );
+
+            // no need to actually cache as there is a warm-up fase
+        }
+
+        [Benchmark]
+        public async Task Cache()
+        {
+            var pipe = new Pipe();
+            var consumer = ConsumeAsync(pipe.Reader, CancellationToken.None);
+            DefaultHttpContext context = CreateHttpContext(pipe);
+            context.Request.Method = HttpMethods.Get;
+            context.Request.Path = "/a";
+
+            // don't serve from cache but store result
+            context.Request.Headers[HeaderNames.CacheControl] = CacheControlHeaderValue.NoCacheString;
+
+            await _middleware.Invoke(context);
+
+            await pipe.Writer.CompleteAsync();
+            await consumer;
+        }
+
+        [Benchmark]
+        public async Task ServeFromCache()
+        {
+            var pipe = new Pipe();
+            var consumer = ConsumeAsync(pipe.Reader, CancellationToken.None);
+            DefaultHttpContext context = CreateHttpContext(pipe);
+            context.Request.Method = HttpMethods.Get;
+            context.Request.Path = "/b";
+
+            await _middleware.Invoke(context);
+
+            await pipe.Writer.CompleteAsync();
+            await consumer;
+        }
+
+        private static DefaultHttpContext CreateHttpContext(Pipe pipe)
+        {
+            var features = new FeatureCollection();
+            features.Set<IHttpRequestFeature>(new HttpRequestFeature());
+            features.Set<IHttpResponseFeature>(new HttpResponseFeature());
+            features.Set<IHttpResponseBodyFeature>(new PipeResponseBodyFeature(pipe.Writer));
+            var context = new DefaultHttpContext(features);
+            return context;
+        }
+
+        private async ValueTask ConsumeAsync(PipeReader reader, CancellationToken cancellationToken)
+        {
+            while (true)
+            {
+                var result = await reader.ReadAsync(cancellationToken);
+                var buffer = result.Buffer;
+
+                reader.AdvanceTo(buffer.End, buffer.End);
+
+                if (result.IsCompleted)
+                {
+                    break;
+                }
+            }
+
+            await reader.CompleteAsync();
+        }
+
+        private class PipeResponseBodyFeature : IHttpResponseBodyFeature
+        {
+            public PipeResponseBodyFeature(PipeWriter pipeWriter)
+            {
+                Writer = pipeWriter;
+            }
+
+            public Stream Stream => Writer.AsStream();
+
+            public PipeWriter Writer { get; }
+
+            public Task CompleteAsync() => Writer.CompleteAsync().AsTask();
+
+            public void DisableBuffering()
+            {
+                throw new NotImplementedException();
+            }
+
+            public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellationToken = default)
+            {
+                throw new NotImplementedException();
+            }
+
+            public Task StartAsync(CancellationToken cancellationToken = default)
+            {
+                throw new NotImplementedException();
+            }
+        }
+    }
+}