Просмотр исходного кода

Improve TestServer support for Response.StartAsync (#10189)

James Newton-King 6 лет назад
Родитель
Сommit
d5207af367

+ 8 - 4
src/Hosting/TestHost/src/HttpContextBuilder.cs

@@ -20,10 +20,11 @@ namespace Microsoft.AspNetCore.TestHost
         
         private readonly TaskCompletionSource<HttpContext> _responseTcs = new TaskCompletionSource<HttpContext>(TaskCreationOptions.RunContinuationsAsynchronously);
         private readonly ResponseStream _responseStream;
-        private readonly ResponseFeature _responseFeature = new ResponseFeature();
+        private readonly ResponseFeature _responseFeature;
         private readonly RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature();
         private readonly ResponseTrailersFeature _responseTrailersFeature = new ResponseTrailersFeature();
         private bool _pipelineFinished;
+        private bool _returningResponse;
         private Context _testContext;
         private Action<HttpContext> _responseReadCompleteCallback;
 
@@ -33,6 +34,7 @@ namespace Microsoft.AspNetCore.TestHost
             AllowSynchronousIO = allowSynchronousIO;
             _preserveExecutionContext = preserveExecutionContext;
             _httpContext = new DefaultHttpContext();
+            _responseFeature = new ResponseFeature(Abort);
 
             var request = _httpContext.Request;
             request.Protocol = "HTTP/1.1";
@@ -40,6 +42,7 @@ namespace Microsoft.AspNetCore.TestHost
 
             _httpContext.Features.Set<IHttpBodyControlFeature>(this);
             _httpContext.Features.Set<IHttpResponseFeature>(_responseFeature);
+            _httpContext.Features.Set<IHttpResponseStartFeature>(_responseFeature);
             _httpContext.Features.Set<IHttpRequestLifetimeFeature>(_requestLifetimeFeature);
             _httpContext.Features.Set<IHttpResponseTrailersFeature>(_responseTrailersFeature);
 
@@ -132,12 +135,13 @@ namespace Microsoft.AspNetCore.TestHost
 
         internal async Task ReturnResponseMessageAsync()
         {
-            // Check if the response has already started because the TrySetResult below could happen a bit late
+            // Check if the response is already returning because the TrySetResult below could happen a bit late
             // (as it happens on a different thread) by which point the CompleteResponseAsync could run and calls this
             // method again.
-            if (!_responseFeature.HasStarted)
+            if (!_returningResponse)
             {
-                // Sets HasStarted
+                _returningResponse = true;
+
                 try
                 {
                     await _responseFeature.FireOnSendingHeadersAsync();

+ 32 - 6
src/Hosting/TestHost/src/ResponseFeature.cs

@@ -3,21 +3,24 @@
 
 using System;
 using System.IO;
+using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Http.Features;
 
 namespace Microsoft.AspNetCore.TestHost
 {
-    internal class ResponseFeature : IHttpResponseFeature
+    internal class ResponseFeature : IHttpResponseFeature, IHttpResponseStartFeature
     {
+        private readonly HeaderDictionary _headers = new HeaderDictionary();
+        private readonly Action<Exception> _abort;
+
         private Func<Task> _responseStartingAsync = () => Task.FromResult(true);
         private Func<Task> _responseCompletedAsync = () => Task.FromResult(true);
-        private HeaderDictionary _headers = new HeaderDictionary();
         private int _statusCode;
         private string _reasonPhrase;
 
-        public ResponseFeature()
+        public ResponseFeature(Action<Exception> abort)
         {
             Headers = _headers;
             Body = new MemoryStream();
@@ -25,6 +28,7 @@ namespace Microsoft.AspNetCore.TestHost
             // 200 is the default status code all the way down to the host, so we set it
             // here to be consistent with the rest of the hosts when writing tests.
             StatusCode = 200;
+            _abort = abort;
         }
 
         public int StatusCode
@@ -98,14 +102,36 @@ namespace Microsoft.AspNetCore.TestHost
 
         public async Task FireOnSendingHeadersAsync()
         {
-            await _responseStartingAsync();
-            HasStarted = true;
-            _headers.IsReadOnly = true;
+            if (!HasStarted)
+            {
+                try
+                {
+                    await _responseStartingAsync();
+                }
+                finally
+                {
+                    HasStarted = true;
+                    _headers.IsReadOnly = true;
+                }
+            }
         }
 
         public Task FireOnResponseCompletedAsync()
         {
             return _responseCompletedAsync();
         }
+
+        public async Task StartAsync(CancellationToken token = default)
+        {
+            try
+            {
+                await FireOnSendingHeadersAsync();
+            }
+            catch (Exception ex)
+            {
+                _abort(ex);
+                throw;
+            }
+        }
     }
 }

+ 42 - 0
src/Hosting/TestHost/test/ClientHandlerTests.cs

@@ -153,6 +153,48 @@ namespace Microsoft.AspNetCore.TestHost
                 });
         }
 
+        [Fact]
+        public async Task ResponseStartAsync()
+        {
+            var hasStartedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
+            var hasAssertedResponseTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            bool? preHasStarted = null;
+            bool? postHasStarted = null;
+            var handler = new ClientHandler(PathString.Empty, new DummyApplication(async context =>
+            {
+                preHasStarted = context.Response.HasStarted;
+
+                await context.Response.StartAsync();
+
+                postHasStarted = context.Response.HasStarted;
+
+                hasStartedTcs.TrySetResult(null);
+
+                await hasAssertedResponseTcs.Task;
+            }));
+
+            var invoker = new HttpMessageInvoker(handler);
+            var message = new HttpRequestMessage(HttpMethod.Post, "https://example.com/");
+
+            var responseTask = invoker.SendAsync(message, CancellationToken.None);
+
+            // Ensure StartAsync has been called in response
+            await hasStartedTcs.Task;
+
+            // Delay so async thread would have had time to attempt to return response
+            await Task.Delay(100);
+            Assert.False(responseTask.IsCompleted, "HttpResponse.StartAsync does not return response");
+
+            // Asserted that response return was checked, allow response to finish
+            hasAssertedResponseTcs.TrySetResult(null);
+
+            await responseTask;
+
+            Assert.False(preHasStarted);
+            Assert.True(postHasStarted);
+        }
+
         [Fact]
         public async Task ResubmitRequestWorks()
         {

+ 26 - 5
src/Hosting/TestHost/test/ResponseFeatureTests.cs

@@ -13,7 +13,7 @@ namespace Microsoft.AspNetCore.TestHost
         public async Task StatusCode_DefaultsTo200()
         {
             // Arrange & Act
-            var responseInformation = new ResponseFeature();
+            var responseInformation = CreateResponseFeature();
 
             // Assert
             Assert.Equal(200, responseInformation.StatusCode);
@@ -25,11 +25,27 @@ namespace Microsoft.AspNetCore.TestHost
             Assert.True(responseInformation.Headers.IsReadOnly);
         }
 
+        [Fact]
+        public async Task StartAsync_StartsResponse()
+        {
+            // Arrange & Act
+            var responseInformation = CreateResponseFeature();
+
+            // Assert
+            Assert.Equal(200, responseInformation.StatusCode);
+            Assert.False(responseInformation.HasStarted);
+
+            await responseInformation.StartAsync();
+
+            Assert.True(responseInformation.HasStarted);
+            Assert.True(responseInformation.Headers.IsReadOnly);
+        }
+
         [Fact]
         public void OnStarting_ThrowsWhenHasStarted()
         {
             // Arrange
-            var responseInformation = new ResponseFeature();
+            var responseInformation = CreateResponseFeature();
             responseInformation.HasStarted = true;
 
             // Act & Assert
@@ -45,7 +61,7 @@ namespace Microsoft.AspNetCore.TestHost
         [Fact]
         public void StatusCode_ThrowsWhenHasStarted()
         {
-            var responseInformation = new ResponseFeature();
+            var responseInformation = CreateResponseFeature();
             responseInformation.HasStarted = true;
 
             Assert.Throws<InvalidOperationException>(() => responseInformation.StatusCode = 400);
@@ -55,7 +71,7 @@ namespace Microsoft.AspNetCore.TestHost
         [Fact]
         public void StatusCode_MustBeGreaterThan99()
         {
-            var responseInformation = new ResponseFeature();
+            var responseInformation = CreateResponseFeature();
 
             Assert.Throws<ArgumentOutOfRangeException>(() => responseInformation.StatusCode = 99);
             Assert.Throws<ArgumentOutOfRangeException>(() => responseInformation.StatusCode = 0);
@@ -64,5 +80,10 @@ namespace Microsoft.AspNetCore.TestHost
             responseInformation.StatusCode = 200;
             responseInformation.StatusCode = 1000;
         }
+
+        private ResponseFeature CreateResponseFeature()
+        {
+            return new ResponseFeature(ex => { });
+        }
     }
-}
+}

+ 4 - 0
src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs

@@ -7,11 +7,13 @@ using System.Diagnostics;
 using System.IO;
 using System.Linq;
 using System.Net;
+using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.AspNetCore.Builder;
 using Microsoft.AspNetCore.Hosting;
 using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Features;
 using Microsoft.AspNetCore.TestHost;
 using Microsoft.Extensions.DependencyInjection;
 using Xunit;
@@ -114,6 +116,8 @@ namespace Microsoft.AspNetCore.Diagnostics
                     // add response buffering
                     app.Use(async (httpContext, next) =>
                     {
+                        httpContext.Features.Set<IHttpResponseStartFeature>(null);
+
                         var response = httpContext.Response;
                         var originalResponseBody = response.Body;
                         var bufferingStream = new MemoryStream();