|
|
@@ -562,4 +562,118 @@ public class Http2WebSocketTests : Http2TestBase
|
|
|
|
|
|
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
|
|
|
}
|
|
|
+
|
|
|
+ [Fact]
|
|
|
+ public async Task HEADERS_Received_SecondRequest_ConnectProtocolReset()
|
|
|
+ {
|
|
|
+ var appDelegateTcs = new TaskCompletionSource();
|
|
|
+ var requestCount = 0;
|
|
|
+ await InitializeConnectionAsync(async context =>
|
|
|
+ {
|
|
|
+ requestCount++;
|
|
|
+
|
|
|
+ var connectFeature = context.Features.Get<IHttpExtendedConnectFeature>();
|
|
|
+
|
|
|
+ if (requestCount == 1)
|
|
|
+ {
|
|
|
+ Assert.True(connectFeature.IsExtendedConnect);
|
|
|
+ Assert.Equal(HttpMethods.Connect, context.Request.Method);
|
|
|
+ Assert.Equal("websocket", connectFeature.Protocol);
|
|
|
+ context.Response.StatusCode = StatusCodes.Status201Created; // Any 2XX should work
|
|
|
+
|
|
|
+ var stream = await connectFeature.AcceptAsync();
|
|
|
+ await stream.WriteAsync(new byte[] { 0x01 });
|
|
|
+ await appDelegateTcs.Task;
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ if (connectFeature.Protocol != null)
|
|
|
+ {
|
|
|
+ throw new Exception("ConnectProtocol should be null here. The fact that it is not indicates that we are not resetting properly between requests.");
|
|
|
+ }
|
|
|
+
|
|
|
+ // We've done the test. Now just return the normal echo server behavior.
|
|
|
+ await _echoApplication(context);
|
|
|
+ }
|
|
|
+
|
|
|
+ });
|
|
|
+
|
|
|
+ // HEADERS + END_HEADERS
|
|
|
+ // :method = CONNECT
|
|
|
+ // :protocol = websocket
|
|
|
+ // :scheme = https
|
|
|
+ // :path = /chat
|
|
|
+ // :authority = server.example.com
|
|
|
+ // sec-websocket-protocol = chat, superchat
|
|
|
+ // sec-websocket-extensions = permessage-deflate
|
|
|
+ // sec-websocket-version = 13
|
|
|
+ // origin = http://www.example.com
|
|
|
+ var headers = new[]
|
|
|
+ {
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Method, "CONNECT"),
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Protocol, "websocket"),
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Scheme, "http"),
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Path, "/chat"),
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Authority, "server.example.com"),
|
|
|
+ new KeyValuePair<string, string>(HeaderNames.WebSocketSubProtocols, "chat, superchat"),
|
|
|
+ new KeyValuePair<string, string>(HeaderNames.SecWebSocketExtensions, "permessage-deflate"),
|
|
|
+ new KeyValuePair<string, string>(HeaderNames.SecWebSocketVersion, "13"),
|
|
|
+ new KeyValuePair<string, string>(HeaderNames.Origin, "http://www.example.com"),
|
|
|
+ };
|
|
|
+
|
|
|
+ await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headers);
|
|
|
+ await SendDataAsync(1, Array.Empty<byte>(), endStream: true);
|
|
|
+
|
|
|
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
|
|
|
+ withLength: 36,
|
|
|
+ withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
|
|
|
+ withStreamId: 1);
|
|
|
+
|
|
|
+ var dataFrame = await ExpectAsync(Http2FrameType.DATA,
|
|
|
+ withLength: 1,
|
|
|
+ withFlags: (byte)Http2DataFrameFlags.NONE,
|
|
|
+ withStreamId: 1);
|
|
|
+ Assert.Equal(0x01, dataFrame.Payload.Span[0]);
|
|
|
+
|
|
|
+ appDelegateTcs.TrySetResult();
|
|
|
+ dataFrame = await ExpectAsync(Http2FrameType.DATA,
|
|
|
+ withLength: 0,
|
|
|
+ withFlags: (byte)Http2DataFrameFlags.END_STREAM,
|
|
|
+ withStreamId: 1);
|
|
|
+
|
|
|
+ // TriggerTick will trigger the stream to be returned to the pool so we can assert it
|
|
|
+ TriggerTick();
|
|
|
+
|
|
|
+ // Stream has been returned to the pool
|
|
|
+ Assert.Equal(1, _connection.StreamPool.Count);
|
|
|
+ Assert.True(_connection.StreamPool.TryPeek(out var pooledStream));
|
|
|
+
|
|
|
+ // Next is a plain GET.
|
|
|
+ var headers2 = new[]
|
|
|
+ {
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Method, "GET"),
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Path, "/"),
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Scheme, "http"),
|
|
|
+ new KeyValuePair<string, string>(InternalHeaderNames.Authority, "example.com"),
|
|
|
+ };
|
|
|
+
|
|
|
+ await StartStreamAsync(3, headers2, endStream: false);
|
|
|
+ await SendDataAsync(3, _helloBytes, endStream: true);
|
|
|
+
|
|
|
+ // If the echo server doesn't give us the expected responses, the test has failed.
|
|
|
+ await ExpectAsync(Http2FrameType.HEADERS,
|
|
|
+ withLength: 2,
|
|
|
+ withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
|
|
|
+ withStreamId: 3);
|
|
|
+ await ExpectAsync(Http2FrameType.DATA,
|
|
|
+ withLength: 5,
|
|
|
+ withFlags: (byte)Http2DataFrameFlags.NONE,
|
|
|
+ withStreamId: 3);
|
|
|
+ await ExpectAsync(Http2FrameType.DATA,
|
|
|
+ withLength: 0,
|
|
|
+ withFlags: (byte)Http2DataFrameFlags.END_STREAM,
|
|
|
+ withStreamId: 3);
|
|
|
+
|
|
|
+ await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false);
|
|
|
+ }
|
|
|
}
|