Browse Source

Fix Kestrel overpooling of HTTP/2 and HTTP/3 request headers (#40087)

* Fix Kestrel overpooling of HTTP/2 and HTTP/3 request headers

* Fix build
James Newton-King 4 years ago
parent
commit
e371b5d160

+ 2 - 0
src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs

@@ -1079,6 +1079,8 @@ internal partial class Http2Connection : IHttp2StreamLifetimeHandler, IHttpStrea
 
             if (endHeaders)
             {
+                _currentHeadersStream.OnHeadersComplete();
+
                 StartStream();
                 ResetRequestHeaderParsingState();
             }

+ 1 - 1
src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs

@@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3;
 
 internal class Http3Connection : IHttp3StreamLifetimeHandler, IRequestProcessor
 {
-    private static readonly object StreamPersistentStateKey = new object();
+    internal static readonly object StreamPersistentStateKey = new object();
 
     // Internal for unit testing
     internal readonly Dictionary<long, IHttp3Stream> _streams = new Dictionary<long, IHttp3Stream>();

+ 2 - 0
src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs

@@ -811,6 +811,8 @@ internal abstract partial class Http3Stream : HttpProtocol, IHttp3Stream, IHttpS
 
         InputRemaining = HttpRequestHeaders.ContentLength;
 
+        OnHeadersComplete();
+
         // If the stream is complete after receiving the headers then run OnEndStreamReceived.
         // If there is a bad content length then this will throw before the request delegate is called.
         if (isCompleted)

+ 69 - 0
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs

@@ -7,6 +7,7 @@ using System.Globalization;
 using System.Net.Http;
 using System.Net.Http.HPack;
 using System.Net.Security;
+using System.Reflection;
 using System.Security.Authentication;
 using System.Text;
 using Microsoft.AspNetCore.Connections;
@@ -14,6 +15,7 @@ using Microsoft.AspNetCore.Connections.Features;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Http.Features;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
+using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
 using Microsoft.AspNetCore.Testing;
@@ -207,6 +209,73 @@ public class Http2ConnectionTests : Http2TestBase
         await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false);
     }
 
+    [Fact]
+    public async Task RequestHeaderStringReuse_MultipleStreams_KnownHeaderClearedIfNotReused()
+    {
+        const BindingFlags privateFlags = BindingFlags.NonPublic | BindingFlags.Instance;
+
+        IEnumerable<KeyValuePair<string, string>> requestHeaders1 = new[]
+        {
+            new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
+            new KeyValuePair<string, string>(HeaderNames.Path, "/hello"),
+            new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+            new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
+            new KeyValuePair<string, string>(HeaderNames.ContentType, "application/json")
+        };
+
+        // Note: No content-type
+        IEnumerable<KeyValuePair<string, string>> requestHeaders2 = new[]
+        {
+            new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
+            new KeyValuePair<string, string>(HeaderNames.Path, "/hello"),
+            new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+            new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80")
+        };
+
+        await InitializeConnectionAsync(_noopApplication);
+
+        await StartStreamAsync(1, requestHeaders1, endStream: true);
+
+        await ExpectAsync(Http2FrameType.HEADERS,
+            withLength: 36,
+            withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+            withStreamId: 1);
+
+        // TriggerTick will trigger the stream to be returned to the pool so we can assert it
+        TriggerTick();
+
+        // Stream has been returned to the pool
+        Assert.Equal(1, _connection.StreamPool.Count);
+        Assert.True(_connection.StreamPool.TryPeek(out var stream1));
+
+        // Hacky but required because header references is private.
+        var headerReferences1 = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(stream1.RequestHeaders);
+        var contentTypeValue1 = (StringValues)headerReferences1.GetType().GetField("_ContentType").GetValue(headerReferences1);
+
+        await StartStreamAsync(3, requestHeaders2, endStream: true);
+
+        await ExpectAsync(Http2FrameType.HEADERS,
+            withLength: 6,
+            withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+            withStreamId: 3);
+
+        // TriggerTick will trigger the stream to be returned to the pool so we can assert it
+        TriggerTick();
+
+        // Stream has been returned to the pool
+        Assert.Equal(1, _connection.StreamPool.Count);
+        Assert.True(_connection.StreamPool.TryPeek(out var stream2));
+
+        // Hacky but required because header references is private.
+        var headerReferences2 = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(stream2.RequestHeaders);
+        var contentTypeValue2 = (StringValues)headerReferences2.GetType().GetField("_ContentType").GetValue(headerReferences2);
+
+        Assert.Equal("application/json", contentTypeValue1);
+        Assert.Equal(StringValues.Empty, contentTypeValue2);
+
+        await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false);
+    }
+
     private class ResponseTrailersWrapper : IHeaderDictionary
     {
         readonly IHeaderDictionary _innerHeaders;

+ 46 - 0
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs

@@ -7,12 +7,15 @@ using System.Collections;
 using System.Collections.Generic;
 using System.Globalization;
 using System.Net.Http;
+using System.Reflection;
 using System.Text;
 using System.Text.RegularExpressions;
 using System.Threading.Tasks;
 using Microsoft.AspNetCore.Connections;
+using Microsoft.AspNetCore.Connections.Features;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Http.Features;
+using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3;
 using Microsoft.AspNetCore.Testing;
 using Microsoft.Extensions.Logging;
@@ -372,6 +375,49 @@ public class Http3ConnectionTests : Http3TestBase
         Assert.Same(authority1, authority2);
     }
 
+    [Fact]
+    public async Task RequestHeaderStringReuse_MultipleStreams_KnownHeaderClearedIfNotReused()
+    {
+        const BindingFlags privateFlags = BindingFlags.NonPublic | BindingFlags.Instance;
+
+        KeyValuePair<string, string>[] requestHeaders1 = new[]
+        {
+            new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
+            new KeyValuePair<string, string>(HeaderNames.Path, "/hello"),
+            new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+            new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
+            new KeyValuePair<string, string>(HeaderNames.ContentType, "application/json")
+        };
+
+        // Note: No content-type
+        KeyValuePair<string, string>[] requestHeaders2 = new[]
+        {
+            new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
+            new KeyValuePair<string, string>(HeaderNames.Path, "/hello"),
+            new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+            new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80")
+        };
+
+        await Http3Api.InitializeConnectionAsync(_echoApplication);
+
+        var streamContext1 = await MakeRequestAsync(0, requestHeaders1, sendData: true, waitForServerDispose: true);
+        var http3Stream1 = (Http3Stream)streamContext1.Features.Get<IPersistentStateFeature>().State[Http3Connection.StreamPersistentStateKey];
+
+        // Hacky but required because header references is private.
+        var headerReferences1 = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(http3Stream1.RequestHeaders);
+        var contentTypeValue1 = (StringValues)headerReferences1.GetType().GetField("_ContentType").GetValue(headerReferences1);
+
+        var streamContext2 = await MakeRequestAsync(1, requestHeaders2, sendData: true, waitForServerDispose: true);
+        var http3Stream2 = (Http3Stream)streamContext2.Features.Get<IPersistentStateFeature>().State[Http3Connection.StreamPersistentStateKey];
+
+        // Hacky but required because header references is private.
+        var headerReferences2 = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(http3Stream2.RequestHeaders);
+        var contentTypeValue2 = (StringValues)headerReferences1.GetType().GetField("_ContentType").GetValue(headerReferences2);
+
+        Assert.Equal("application/json", contentTypeValue1);
+        Assert.Equal(StringValues.Empty, contentTypeValue2);
+    }
+
     [Theory]
     [InlineData(10)]
     [InlineData(100)]