| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- commit c852bdcc332ffb998ec6a5b226e35d5e74d24009
- Author: BrennanConroy <[email protected]>
- Date: Wed Nov 21 11:47:39 2018 -0800
- Avoid zero-byte send in WebSockets (#3326)
- diff --git a/src/Common/WebSocketExtensions.cs b/src/Common/WebSocketExtensions.cs
- index a15ad78891a..fedb954296a 100644
- --- a/src/Common/WebSocketExtensions.cs
- +++ b/src/Common/WebSocketExtensions.cs
- @@ -39,22 +39,28 @@ namespace System.Net.WebSockets
- private static async ValueTask SendMultiSegmentAsync(WebSocket webSocket, ReadOnlySequence<byte> buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default)
- {
- var position = buffer.Start;
- + // Get a segment before the loop so we can be one segment behind while writing
- + // This allows us to do a non-zero byte write for the endOfMessage = true send
- + buffer.TryGet(ref position, out var prevSegment);
- while (buffer.TryGet(ref position, out var segment))
- {
- #if NETCOREAPP3_0
- - await webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: false, cancellationToken);
- + await webSocket.SendAsync(prevSegment, webSocketMessageType, endOfMessage: false, cancellationToken);
- #else
- - var isArray = MemoryMarshal.TryGetArray(segment, out var arraySegment);
- + var isArray = MemoryMarshal.TryGetArray(prevSegment, out var arraySegment);
- Debug.Assert(isArray);
- await webSocket.SendAsync(arraySegment, webSocketMessageType, endOfMessage: false, cancellationToken);
- #endif
- + prevSegment = segment;
- }
-
- - // Empty end of message frame
- + // End of message frame
- #if NETCOREAPP3_0
- - await webSocket.SendAsync(Memory<byte>.Empty, webSocketMessageType, endOfMessage: true, cancellationToken);
- + await webSocket.SendAsync(prevSegment, webSocketMessageType, endOfMessage: true, cancellationToken);
- #else
- - await webSocket.SendAsync(new ArraySegment<byte>(Array.Empty<byte>()), webSocketMessageType, endOfMessage: true, cancellationToken);
- + var isArrayEnd = MemoryMarshal.TryGetArray(prevSegment, out var arraySegmentEnd);
- + Debug.Assert(isArrayEnd);
- + await webSocket.SendAsync(arraySegmentEnd, webSocketMessageType, endOfMessage: true, cancellationToken);
- #endif
- }
- }
- diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs
- index 0af2f658128..8068853f17b 100644
- --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs
- +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs
- @@ -396,5 +396,36 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
- }
- }
- }
- +
- + [Fact]
- + public async Task MultiSegmentSendWillNotSendEmptyEndOfMessageFrame()
- + {
- + using (var feature = new TestWebSocketConnectionFeature())
- + {
- + var serverSocket = await feature.AcceptAsync();
- + var sequence = ReadOnlySequenceFactory.CreateSegments(new byte[] { 1 }, new byte[] { 15 });
- + Assert.False(sequence.IsSingleSegment);
- +
- + await serverSocket.SendAsync(sequence, WebSocketMessageType.Text);
- +
- + // Run the client socket
- + var client = feature.Client.ExecuteAndCaptureFramesAsync();
- +
- + await serverSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", default);
- +
- + var messages = await client.OrTimeout();
- + Assert.Equal(2, messages.Received.Count);
- +
- + // First message: 1 byte, endOfMessage false
- + Assert.Single(messages.Received[0].Buffer);
- + Assert.Equal(1, messages.Received[0].Buffer[0]);
- + Assert.False(messages.Received[0].EndOfMessage);
- +
- + // Second message: 1 byte, endOfMessage true
- + Assert.Single(messages.Received[1].Buffer);
- + Assert.Equal(15, messages.Received[1].Buffer[0]);
- + Assert.True(messages.Received[1].EndOfMessage);
- + }
- + }
- }
- }
|