Selaa lähdekoodia

Return response trailers from test server (#10135)

James Newton-King 6 vuotta sitten
vanhempi
sitoutus
11311d4f9d

+ 14 - 1
src/Hosting/TestHost/src/ClientHandler.cs

@@ -119,9 +119,22 @@ namespace Microsoft.AspNetCore.TestHost
                 responseBody = context.Response.Body;
             });
 
+            var response = new HttpResponseMessage();
+
+            // Copy trailers to the response message when the response stream is complete
+            contextBuilder.RegisterResponseReadCompleteCallback(context =>
+            {
+                var responseTrailersFeature = context.Features.Get<IHttpResponseTrailersFeature>();
+
+                foreach (var trailer in responseTrailersFeature.Trailers)
+                {
+                    bool success = response.TrailingHeaders.TryAddWithoutValidation(trailer.Key, (IEnumerable<string>)trailer.Value);
+                    Contract.Assert(success, "Bad trailer");
+                }
+            });
+
             var httpContext = await contextBuilder.SendAsync(cancellationToken);
 
-            var response = new HttpResponseMessage();
             response.StatusCode = (HttpStatusCode)httpContext.Response.StatusCode;
             response.ReasonPhrase = httpContext.Features.Get<IHttpResponseFeature>().ReasonPhrase;
             response.RequestMessage = request;

+ 15 - 7
src/Hosting/TestHost/src/HttpContextBuilder.cs

@@ -17,13 +17,15 @@ namespace Microsoft.AspNetCore.TestHost
         private readonly IHttpApplication<Context> _application;
         private readonly bool _preserveExecutionContext;
         private readonly HttpContext _httpContext;
-
-        private TaskCompletionSource<HttpContext> _responseTcs = new TaskCompletionSource<HttpContext>(TaskCreationOptions.RunContinuationsAsynchronously);
-        private ResponseStream _responseStream;
-        private ResponseFeature _responseFeature = new ResponseFeature();
-        private RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature();
+        
+        private readonly TaskCompletionSource<HttpContext> _responseTcs = new TaskCompletionSource<HttpContext>(TaskCreationOptions.RunContinuationsAsynchronously);
+        private readonly ResponseStream _responseStream;
+        private readonly ResponseFeature _responseFeature = new ResponseFeature();
+        private readonly RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature();
+        private readonly ResponseTrailersFeature _responseTrailersFeature = new ResponseTrailersFeature();
         private bool _pipelineFinished;
         private Context _testContext;
+        private Action<HttpContext> _responseReadCompleteCallback;
 
         internal HttpContextBuilder(IHttpApplication<Context> application, bool allowSynchronousIO, bool preserveExecutionContext)
         {
@@ -39,8 +41,9 @@ namespace Microsoft.AspNetCore.TestHost
             _httpContext.Features.Set<IHttpBodyControlFeature>(this);
             _httpContext.Features.Set<IHttpResponseFeature>(_responseFeature);
             _httpContext.Features.Set<IHttpRequestLifetimeFeature>(_requestLifetimeFeature);
-            
-            _responseStream = new ResponseStream(ReturnResponseMessageAsync, AbortRequest, () => AllowSynchronousIO);
+            _httpContext.Features.Set<IHttpResponseTrailersFeature>(_responseTrailersFeature);
+
+            _responseStream = new ResponseStream(ReturnResponseMessageAsync, AbortRequest, () => AllowSynchronousIO, () => _responseReadCompleteCallback?.Invoke(_httpContext));
             _responseFeature.Body = _responseStream;
         }
 
@@ -56,6 +59,11 @@ namespace Microsoft.AspNetCore.TestHost
             configureContext(_httpContext);
         }
 
+        internal void RegisterResponseReadCompleteCallback(Action<HttpContext> responseReadCompleteCallback)
+        {
+            _responseReadCompleteCallback = responseReadCompleteCallback;
+        }
+
         /// <summary>
         /// Start processing the request.
         /// </summary>

+ 17 - 7
src/Hosting/TestHost/src/ResponseStream.cs

@@ -17,22 +17,24 @@ namespace Microsoft.AspNetCore.TestHost
     internal class ResponseStream : Stream
     {
         private bool _complete;
+        private bool _readerComplete;
         private bool _aborted;
         private Exception _abortException;
-        private SemaphoreSlim _writeLock;
-
-        private Func<Task> _onFirstWriteAsync;
         private bool _firstWrite;
-        private Action _abortRequest;
-        private Func<bool> _allowSynchronousIO;
 
-        private Pipe _pipe = new Pipe();
+        private readonly SemaphoreSlim _writeLock;
+        private readonly Func<Task> _onFirstWriteAsync;
+        private readonly Action _abortRequest;
+        private readonly Func<bool> _allowSynchronousIO;
+        private readonly Action _readComplete;
+        private readonly Pipe _pipe = new Pipe();
 
-        internal ResponseStream(Func<Task> onFirstWriteAsync, Action abortRequest, Func<bool> allowSynchronousIO)
+        internal ResponseStream(Func<Task> onFirstWriteAsync, Action abortRequest, Func<bool> allowSynchronousIO, Action readComplete)
         {
             _onFirstWriteAsync = onFirstWriteAsync ?? throw new ArgumentNullException(nameof(onFirstWriteAsync));
             _abortRequest = abortRequest ?? throw new ArgumentNullException(nameof(abortRequest));
             _allowSynchronousIO = allowSynchronousIO ?? throw new ArgumentNullException(nameof(allowSynchronousIO));
+            _readComplete = readComplete;
             _firstWrite = true;
             _writeLock = new SemaphoreSlim(1, 1);
         }
@@ -108,6 +110,12 @@ namespace Microsoft.AspNetCore.TestHost
         {
             VerifyBuffer(buffer, offset, count, allowEmpty: false);
             CheckAborted();
+
+            if (_readerComplete)
+            {
+                return 0;
+            }
+
             var registration = cancellationToken.Register(Cancel);
             try
             {
@@ -116,6 +124,8 @@ namespace Microsoft.AspNetCore.TestHost
                 if (result.Buffer.IsEmpty && result.IsCompleted)
                 {
                     _pipe.Reader.Complete();
+                    _readComplete();
+                    _readerComplete = true;
                     return 0;
                 }
 

+ 13 - 0
src/Hosting/TestHost/src/ResponseTrailersFeature.cs

@@ -0,0 +1,13 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Features;
+
+namespace Microsoft.AspNetCore.TestHost
+{
+    internal class ResponseTrailersFeature : IHttpResponseTrailersFeature
+    {
+        public IHeaderDictionary Trailers { get; set; } = new HeaderDictionary();
+    }
+}

+ 65 - 0
src/Hosting/TestHost/test/ClientHandlerTests.cs

@@ -88,6 +88,71 @@ namespace Microsoft.AspNetCore.TestHost
             return httpClient.GetAsync("https://example.com/");
         }
 
+        [Fact]
+        public async Task ServerTrailersSetOnResponseAfterContentRead()
+        {
+            var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            var handler = new ClientHandler(PathString.Empty, new DummyApplication(async context =>
+            {
+                context.Response.AppendTrailer("StartTrailer", "Value!");
+
+                await context.Response.WriteAsync("Hello World");
+                await context.Response.Body.FlushAsync();
+
+                // Pause writing response to ensure trailers are written at the end
+                await tcs.Task;
+
+                await context.Response.WriteAsync("Bye World");
+                await context.Response.Body.FlushAsync();
+
+                context.Response.AppendTrailer("EndTrailer", "Value!");
+            }));
+
+            var invoker = new HttpMessageInvoker(handler);
+            var message = new HttpRequestMessage(HttpMethod.Post, "https://example.com/");
+
+            var response = await invoker.SendAsync(message, CancellationToken.None);
+
+            Assert.Empty(response.TrailingHeaders);
+
+            var responseBody = await response.Content.ReadAsStreamAsync();
+
+            int read = await responseBody.ReadAsync(new byte[100], 0, 100);
+            Assert.Equal(11, read);
+
+            Assert.Empty(response.TrailingHeaders);
+
+            var readTask = responseBody.ReadAsync(new byte[100], 0, 100);
+            Assert.False(readTask.IsCompleted);
+            tcs.TrySetResult(null);
+
+            read = await readTask;
+            Assert.Equal(9, read);
+
+            Assert.Empty(response.TrailingHeaders);
+
+            // Read nothing because we're at the end of the response
+            read = await responseBody.ReadAsync(new byte[100], 0, 100);
+            Assert.Equal(0, read);
+
+            // Ensure additional reads after end don't effect trailers
+            read = await responseBody.ReadAsync(new byte[100], 0, 100);
+            Assert.Equal(0, read);
+
+            Assert.Collection(response.TrailingHeaders,
+                kvp =>
+                {
+                    Assert.Equal("StartTrailer", kvp.Key);
+                    Assert.Equal("Value!", kvp.Value.Single());
+                },
+                kvp =>
+                {
+                    Assert.Equal("EndTrailer", kvp.Key);
+                    Assert.Equal("Value!", kvp.Value.Single());
+                });
+        }
+
         [Fact]
         public async Task ResubmitRequestWorks()
         {