Browse Source

Reset ConnectProtocol during reset (#47660)

Aditya Mandaleeka 2 years ago
parent
commit
a5f8c0a59c

+ 1 - 0
src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs

@@ -366,6 +366,7 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
         IsExtendedConnectRequest = false;
         IsExtendedConnectAccepted = false;
         IsWebTransportRequest = false;
+        ConnectProtocol = null;
 
         var remoteEndPoint = RemoteEndPoint;
         RemoteIpAddress = remoteEndPoint?.Address;

+ 114 - 0
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2WebSocketTests.cs

@@ -562,4 +562,118 @@ public class Http2WebSocketTests : Http2TestBase
 
         await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
     }
+
+    [Fact]
+    public async Task HEADERS_Received_SecondRequest_ConnectProtocolReset()
+    {
+        var appDelegateTcs = new TaskCompletionSource();
+        var requestCount = 0;
+        await InitializeConnectionAsync(async context =>
+        {
+            requestCount++;
+
+            var connectFeature = context.Features.Get<IHttpExtendedConnectFeature>();
+
+            if (requestCount == 1)
+            {
+                Assert.True(connectFeature.IsExtendedConnect);
+                Assert.Equal(HttpMethods.Connect, context.Request.Method);
+                Assert.Equal("websocket", connectFeature.Protocol);
+                context.Response.StatusCode = StatusCodes.Status201Created; // Any 2XX should work
+
+                var stream = await connectFeature.AcceptAsync();
+                await stream.WriteAsync(new byte[] { 0x01 });
+                await appDelegateTcs.Task;
+            }
+            else
+            {
+                if (connectFeature.Protocol != null)
+                {
+                    throw new Exception("ConnectProtocol should be null here. The fact that it is not indicates that we are not resetting properly between requests.");
+                }
+
+                // We've done the test. Now just return the normal echo server behavior.
+                await _echoApplication(context);
+            }
+
+        });
+
+        // HEADERS + END_HEADERS
+        // :method = CONNECT
+        // :protocol = websocket
+        // :scheme = https
+        // :path = /chat
+        // :authority = server.example.com
+        // sec-websocket-protocol = chat, superchat
+        // sec-websocket-extensions = permessage-deflate
+        // sec-websocket-version = 13
+        // origin = http://www.example.com
+        var headers = new[]
+        {
+            new KeyValuePair<string, string>(InternalHeaderNames.Method, "CONNECT"),
+            new KeyValuePair<string, string>(InternalHeaderNames.Protocol, "websocket"),
+            new KeyValuePair<string, string>(InternalHeaderNames.Scheme, "http"),
+            new KeyValuePair<string, string>(InternalHeaderNames.Path, "/chat"),
+            new KeyValuePair<string, string>(InternalHeaderNames.Authority, "server.example.com"),
+            new KeyValuePair<string, string>(HeaderNames.WebSocketSubProtocols, "chat, superchat"),
+            new KeyValuePair<string, string>(HeaderNames.SecWebSocketExtensions, "permessage-deflate"),
+            new KeyValuePair<string, string>(HeaderNames.SecWebSocketVersion, "13"),
+            new KeyValuePair<string, string>(HeaderNames.Origin, "http://www.example.com"),
+        };
+
+        await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headers);
+        await SendDataAsync(1, Array.Empty<byte>(), endStream: true);
+
+        var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+            withLength: 36,
+            withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
+            withStreamId: 1);
+
+        var dataFrame = await ExpectAsync(Http2FrameType.DATA,
+            withLength: 1,
+            withFlags: (byte)Http2DataFrameFlags.NONE,
+            withStreamId: 1);
+        Assert.Equal(0x01, dataFrame.Payload.Span[0]);
+
+        appDelegateTcs.TrySetResult();
+        dataFrame = await ExpectAsync(Http2FrameType.DATA,
+            withLength: 0,
+            withFlags: (byte)Http2DataFrameFlags.END_STREAM,
+            withStreamId: 1);
+
+        // TriggerTick will trigger the stream to be returned to the pool so we can assert it
+        TriggerTick();
+
+        // Stream has been returned to the pool
+        Assert.Equal(1, _connection.StreamPool.Count);
+        Assert.True(_connection.StreamPool.TryPeek(out var pooledStream));
+
+        // Next is a plain GET.
+        var headers2 = new[]
+        {
+            new KeyValuePair<string, string>(InternalHeaderNames.Method, "GET"),
+            new KeyValuePair<string, string>(InternalHeaderNames.Path, "/"),
+            new KeyValuePair<string, string>(InternalHeaderNames.Scheme, "http"),
+            new KeyValuePair<string, string>(InternalHeaderNames.Authority, "example.com"),
+        };
+
+        await StartStreamAsync(3, headers2, endStream: false);
+        await SendDataAsync(3, _helloBytes, endStream: true);
+
+        // If the echo server doesn't give us the expected responses, the test has failed.
+        await ExpectAsync(Http2FrameType.HEADERS,
+            withLength: 2,
+            withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
+            withStreamId: 3);
+        await ExpectAsync(Http2FrameType.DATA,
+            withLength: 5,
+            withFlags: (byte)Http2DataFrameFlags.NONE,
+            withStreamId: 3);
+        await ExpectAsync(Http2FrameType.DATA,
+            withLength: 0,
+            withFlags: (byte)Http2DataFrameFlags.END_STREAM,
+            withStreamId: 3);
+
+        await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false);
+    }
 }