Просмотр исходного кода

HTTP/3: Add trailers and reset features (#28763)

James Newton-King 5 лет назад
Родитель
Сommit
9deb14da6f

+ 4 - 1
src/Servers/Kestrel/Core/src/CoreStrings.resx

@@ -647,4 +647,7 @@ For more information on configuring HTTPS see https://go.microsoft.com/fwlink/?l
   <data name="Http3StreamAborted" xml:space="preserve">
     <value>The HTTP/3 request stream was aborted.</value>
   </data>
-</root>
+  <data name="Http3StreamResetByApplication" xml:space="preserve">
+    <value>The HTTP/3 stream was reset by the application with error code {errorCode}.</value>
+  </data>
+</root>

+ 6 - 0
src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs

@@ -268,6 +268,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
             _currentIHttpResetFeature = this;
         }
 
+        protected void ResetHttp3Features()
+        {
+            _currentIHttpResponseTrailersFeature = this;
+            _currentIHttpResetFeature = this;
+        }
+
         void IHttpResponseFeature.OnStarting(Func<object, Task> callback, object state)
         {
             OnStarting(callback, state);

+ 49 - 0
src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.FeatureCollection.cs

@@ -0,0 +1,49 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Net.Http;
+using Microsoft.AspNetCore.Connections;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Features;
+using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
+
+namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3
+{
+    internal partial class Http3Stream : IHttpResetFeature,
+                                         IHttpResponseTrailersFeature
+    {
+        private IHeaderDictionary? _userTrailers;
+
+        IHeaderDictionary IHttpResponseTrailersFeature.Trailers
+        {
+            get
+            {
+                if (ResponseTrailers == null)
+                {
+                    ResponseTrailers = new HttpResponseTrailers();
+                    if (HasResponseCompleted)
+                    {
+                        ResponseTrailers.SetReadOnly();
+                    }
+                }
+                return _userTrailers ?? ResponseTrailers;
+            }
+            set
+            {
+                if (value == null)
+                {
+                    throw new ArgumentNullException(nameof(value));
+                }
+
+                _userTrailers = value;
+            }
+        }
+
+        void IHttpResetFeature.Reset(int errorCode)
+        {
+            var abortReason = new ConnectionAbortedException(CoreStrings.FormatHttp3StreamResetByApplication((Http3ErrorCode)errorCode));
+            ApplicationAbort(abortReason, (Http3ErrorCode)errorCode);
+        }
+    }
+}

+ 6 - 4
src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs

@@ -20,7 +20,7 @@ using Microsoft.Net.Http.Headers;
 
 namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3
 {
-    internal abstract class Http3Stream : HttpProtocol, IHttpHeadersHandler, IThreadPoolWorkItem, ITimeoutHandler, IRequestProcessor
+    internal abstract partial class Http3Stream : HttpProtocol, IHttpHeadersHandler, IThreadPoolWorkItem, ITimeoutHandler, IRequestProcessor
     {
         private static ReadOnlySpan<byte> AuthorityBytes => new byte[10] { (byte)':', (byte)'a', (byte)'u', (byte)'t', (byte)'h', (byte)'o', (byte)'r', (byte)'i', (byte)'t', (byte)'y' };
         private static ReadOnlySpan<byte> MethodBytes => new byte[7] { (byte)':', (byte)'m', (byte)'e', (byte)'t', (byte)'h', (byte)'o', (byte)'d' };
@@ -507,12 +507,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3
 
         protected override void OnReset()
         {
+            ResetHttp3Features();
         }
 
-        protected override void ApplicationAbort()
+        protected override void ApplicationAbort() => ApplicationAbort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication), Http3ErrorCode.InternalError);
+
+        private void ApplicationAbort(ConnectionAbortedException abortReason, Http3ErrorCode error)
         {
-            var abortReason = new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication);
-            Abort(abortReason, Http3ErrorCode.InternalError);
+            Abort(abortReason, error);
         }
 
         protected override string CreateRequestId()

+ 119 - 0
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs

@@ -641,5 +641,124 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
 
             await requestStream.WaitForStreamErrorAsync(Http3ErrorCode.ProtocolError, CoreStrings.Http3StreamErrorLessDataThanLength);
         }
+
+        [Fact]
+        public async Task ResponseTrailers_WithoutData_Sent()
+        {
+            var headers = new[]
+            {
+                new KeyValuePair<string, string>(HeaderNames.Method, "Custom"),
+                new KeyValuePair<string, string>(HeaderNames.Path, "/"),
+                new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+                new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
+            };
+
+            var requestStream = await InitializeConnectionAndStreamsAsync(context =>
+            {
+                var trailersFeature = context.Features.Get<IHttpResponseTrailersFeature>();
+
+                trailersFeature.Trailers.Add("Trailer1", "Value1");
+                trailersFeature.Trailers.Add("Trailer2", "Value2");
+
+                return Task.CompletedTask;
+            });
+
+            var doneWithHeaders = await requestStream.SendHeadersAsync(headers, endStream: true);
+
+            var responseHeaders = await requestStream.ExpectHeadersAsync();
+
+            var responseTrailers = await requestStream.ExpectHeadersAsync();
+
+            Assert.Equal(2, responseTrailers.Count);
+            Assert.Equal("Value1", responseTrailers["Trailer1"]);
+            Assert.Equal("Value2", responseTrailers["Trailer2"]);
+        }
+
+        [Fact]
+        public async Task ResponseTrailers_WithData_Sent()
+        {
+            var headers = new[]
+            {
+                new KeyValuePair<string, string>(HeaderNames.Method, "Custom"),
+                new KeyValuePair<string, string>(HeaderNames.Path, "/"),
+                new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+                new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
+            };
+
+            var requestStream = await InitializeConnectionAndStreamsAsync(async context =>
+            {
+                var trailersFeature = context.Features.Get<IHttpResponseTrailersFeature>();
+
+                trailersFeature.Trailers.Add("Trailer1", "Value1");
+                trailersFeature.Trailers.Add("Trailer2", "Value2");
+
+                await context.Response.WriteAsync("Hello world");
+            });
+
+            var doneWithHeaders = await requestStream.SendHeadersAsync(headers, endStream: true);
+
+            var responseHeaders = await requestStream.ExpectHeadersAsync();
+            var responseData = await requestStream.ExpectDataAsync();
+            Assert.Equal("Hello world", Encoding.ASCII.GetString(responseData.ToArray()));
+
+            var responseTrailers = await requestStream.ExpectHeadersAsync();
+
+            Assert.Equal(2, responseTrailers.Count);
+            Assert.Equal("Value1", responseTrailers["Trailer1"]);
+            Assert.Equal("Value2", responseTrailers["Trailer2"]);
+        }
+
+        [Fact]
+        public async Task ResponseTrailers_WithExeption500_Cleared()
+        {
+            var headers = new[]
+            {
+                new KeyValuePair<string, string>(HeaderNames.Method, "Custom"),
+                new KeyValuePair<string, string>(HeaderNames.Path, "/"),
+                new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+                new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
+            };
+
+            var requestStream = await InitializeConnectionAndStreamsAsync(context =>
+            {
+                var trailersFeature = context.Features.Get<IHttpResponseTrailersFeature>();
+
+                trailersFeature.Trailers.Add("Trailer1", "Value1");
+                trailersFeature.Trailers.Add("Trailer2", "Value2");
+
+                throw new NotImplementedException("Test Exception");
+            });
+
+            var doneWithHeaders = await requestStream.SendHeadersAsync(headers, endStream: true);
+
+            var responseHeaders = await requestStream.ExpectHeadersAsync();
+
+            await requestStream.ExpectReceiveEndOfStream();
+        }
+
+        [Fact]
+        public async Task ResetStream_ReturnStreamError()
+        {
+            var headers = new[]
+            {
+                new KeyValuePair<string, string>(HeaderNames.Method, "Custom"),
+                new KeyValuePair<string, string>(HeaderNames.Path, "/"),
+                new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+                new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
+            };
+
+            var requestStream = await InitializeConnectionAndStreamsAsync(context =>
+            {
+                var resetFeature = context.Features.Get<IHttpResetFeature>();
+
+                resetFeature.Reset((int)Http3ErrorCode.RequestCancelled);
+
+                return Task.CompletedTask;
+            });
+
+            var doneWithHeaders = await requestStream.SendHeadersAsync(headers, endStream: true);
+
+            await requestStream.WaitForStreamErrorAsync(Http3ErrorCode.RequestCancelled, CoreStrings.FormatHttp3StreamResetByApplication(Http3ErrorCode.RequestCancelled));
+        }
     }
 }

+ 1 - 0
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TestBase.cs

@@ -396,6 +396,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
             internal async Task<Dictionary<string, string>> ExpectHeadersAsync()
             {
                 var http3WithPayload = await ReceiveFrameAsync();
+                _decodedHeaders.Clear();
                 _qpackDecoder.Decode(http3WithPayload.PayloadSequence, this);
                 return _decodedHeaders;
             }

+ 2 - 0
src/Shared/runtime/Http3/QPack/QPackDecoder.cs

@@ -173,6 +173,8 @@ namespace System.Net.Http.QPack
             {
                 Decode(segment.Span, handler);
             }
+
+            _state = State.RequiredInsertCount;
         }
 
         public void Decode(ReadOnlySpan<byte> headerBlock, IHttpHeadersHandler handler)