SignalR 728 B

  1. commit c852bdcc332ffb998ec6a5b226e35d5e74d24009
  2. Author: BrennanConroy <[email protected]>
  3. Date: Wed Nov 21 11:47:39 2018 -0800
  4. Avoid zero-byte send in WebSockets (#3326)
  5. diff --git a/src/Common/WebSocketExtensions.cs b/src/Common/WebSocketExtensions.cs
  6. index a15ad78891a..fedb954296a 100644
  7. --- a/src/Common/WebSocketExtensions.cs
  8. +++ b/src/Common/WebSocketExtensions.cs
  9. @@ -39,22 +39,28 @@ namespace System.Net.WebSockets
  10. private static async ValueTask SendMultiSegmentAsync(WebSocket webSocket, ReadOnlySequence<byte> buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default)
  11. {
  12. var position = buffer.Start;
  13. + // Get a segment before the loop so we can be one segment behind while writing
  14. + // This allows us to do a non-zero byte write for the endOfMessage = true send
  15. + buffer.TryGet(ref position, out var prevSegment);
  16. while (buffer.TryGet(ref position, out var segment))
  17. {
  18. #if NETCOREAPP3_0
  19. - await webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: false, cancellationToken);
  20. + await webSocket.SendAsync(prevSegment, webSocketMessageType, endOfMessage: false, cancellationToken);
  21. #else
  22. - var isArray = MemoryMarshal.TryGetArray(segment, out var arraySegment);
  23. + var isArray = MemoryMarshal.TryGetArray(prevSegment, out var arraySegment);
  24. Debug.Assert(isArray);
  25. await webSocket.SendAsync(arraySegment, webSocketMessageType, endOfMessage: false, cancellationToken);
  26. #endif
  27. + prevSegment = segment;
  28. }
  29. - // Empty end of message frame
  30. + // End of message frame
  31. #if NETCOREAPP3_0
  32. - await webSocket.SendAsync(Memory<byte>.Empty, webSocketMessageType, endOfMessage: true, cancellationToken);
  33. + await webSocket.SendAsync(prevSegment, webSocketMessageType, endOfMessage: true, cancellationToken);
  34. #else
  35. - await webSocket.SendAsync(new ArraySegment<byte>(Array.Empty<byte>()), webSocketMessageType, endOfMessage: true, cancellationToken);
  36. + var isArrayEnd = MemoryMarshal.TryGetArray(prevSegment, out var arraySegmentEnd);
  37. + Debug.Assert(isArrayEnd);
  38. + await webSocket.SendAsync(arraySegmentEnd, webSocketMessageType, endOfMessage: true, cancellationToken);
  39. #endif
  40. }
  41. }
  42. diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs
  43. index 0af2f658128..8068853f17b 100644
  44. --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs
  45. +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs
  46. @@ -396,5 +396,36 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
  47. }
  48. }
  49. }
  50. +
  51. + [Fact]
  52. + public async Task MultiSegmentSendWillNotSendEmptyEndOfMessageFrame()
  53. + {
  54. + using (var feature = new TestWebSocketConnectionFeature())
  55. + {
  56. + var serverSocket = await feature.AcceptAsync();
  57. + var sequence = ReadOnlySequenceFactory.CreateSegments(new byte[] { 1 }, new byte[] { 15 });
  58. + Assert.False(sequence.IsSingleSegment);
  59. +
  60. + await serverSocket.SendAsync(sequence, WebSocketMessageType.Text);
  61. +
  62. + // Run the client socket
  63. + var client = feature.Client.ExecuteAndCaptureFramesAsync();
  64. +
  65. + await serverSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", default);
  66. +
  67. + var messages = await client.OrTimeout();
  68. + Assert.Equal(2, messages.Received.Count);
  69. +
  70. + // First message: 1 byte, endOfMessage false
  71. + Assert.Single(messages.Received[0].Buffer);
  72. + Assert.Equal(1, messages.Received[0].Buffer[0]);
  73. + Assert.False(messages.Received[0].EndOfMessage);
  74. +
  75. + // Second message: 1 byte, endOfMessage true
  76. + Assert.Single(messages.Received[1].Buffer);
  77. + Assert.Equal(15, messages.Received[1].Buffer[0]);
  78. + Assert.True(messages.Received[1].EndOfMessage);
  79. + }
  80. + }
  81. }
  82. }