Browse Source

Client to Server Streaming with IAsyncEnumerable (#9310)

Mikael Mengistu 7 years ago
parent
commit
ebb9ad20db

+ 49 - 13
src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs

@@ -42,7 +42,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
         private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1);
 
         private static readonly MethodInfo _sendStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendStreamItems"));
-
+#if NETCOREAPP3_0
+        private static readonly MethodInfo _sendIAsyncStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendIAsyncEnumerableStreamItems"));
+#endif
         // Persistent across all connections
         private readonly ILoggerFactory _loggerFactory;
         private readonly ILogger _logger;
@@ -533,13 +535,11 @@ namespace Microsoft.AspNetCore.SignalR.Client
             }
 
             LaunchStreams(readers, cancellationToken);
-
             return channel;
         }
 
         private Dictionary<string, object> PackageStreamingParams(ref object[] args, out List<string> streamIds)
         {
-            // lazy initialized, to avoid allocating unecessary dictionaries
             Dictionary<string, object> readers = null;
             streamIds = null;
             var newArgs = new List<object>(args.Length);
@@ -572,7 +572,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
             }
 
             args = newArgs.ToArray();
-
             return readers;
         }
 
@@ -590,6 +589,15 @@ namespace Microsoft.AspNetCore.SignalR.Client
                 // For each stream that needs to be sent, run a "send items" task in the background.
                 // This reads from the channel, attaches streamId, and sends to server.
                 // A single background thread here quickly gets messy.
+#if NETCOREAPP3_0
+                if (ReflectionHelper.IsIAsyncEnumerable(reader.GetType()))
+                {
+                    _ = _sendIAsyncStreamItemsMethod
+                        .MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1").GetGenericArguments())
+                        .Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
+                    continue;
+                }
+#endif
                 _ = _sendStreamItemsMethod
                     .MakeGenericMethod(reader.GetType().GetGenericArguments())
                     .Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
@@ -597,24 +605,52 @@ namespace Microsoft.AspNetCore.SignalR.Client
         }
 
         // this is called via reflection using the `_sendStreamItems` field
-        private async Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
+        private Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
         {
-            Log.StartingStream(_logger, streamId);
-
-            var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token).Token;
-
-            string responseError = null;
-            try
+            async Task ReadChannelStream(CancellationTokenSource tokenSource)
             {
-                while (await reader.WaitToReadAsync(combinedToken))
+                while (await reader.WaitToReadAsync(tokenSource.Token))
                 {
-                    while (!combinedToken.IsCancellationRequested && reader.TryRead(out var item))
+                    while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item))
                     {
                         await SendWithLock(new StreamItemMessage(streamId, item));
                         Log.SendingStreamItem(_logger, streamId);
                     }
                 }
             }
+
+            return CommonStreaming(streamId, token, ReadChannelStream);
+        }
+
+#if NETCOREAPP3_0
+        // this is called via reflection using the `_sendIAsyncStreamItemsMethod` field
+        private Task SendIAsyncEnumerableStreamItems<T>(string streamId, IAsyncEnumerable<T> stream, CancellationToken token)
+        {
+            async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
+            {
+                var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);
+
+                await foreach (var streamValue in streamValues)
+                {
+                    await SendWithLock(new StreamItemMessage(streamId, streamValue));
+                    Log.SendingStreamItem(_logger, streamId);
+                }
+            }
+
+            return CommonStreaming(streamId, token, ReadAsyncEnumerableStream);
+        }
+#endif
+
+        private async Task CommonStreaming(string streamId, CancellationToken token, Func<CancellationTokenSource, Task> createAndConsumeStream)
+        {
+            var cts = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token);
+
+            Log.StartingStream(_logger, streamId);
+            string responseError = null;
+            try
+            {
+                await createAndConsumeStream(cts);
+            }
             catch (OperationCanceledException)
             {
                 Log.CancelingStream(_logger, streamId);

+ 101 - 1
src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs

@@ -661,6 +661,106 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
             }
         }
 
+        [Theory]
+        [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
+        [LogLevel(LogLevel.Trace)]
+        public async Task CanStreamToServerWithIAsyncEnumerable(string protocolName, HttpTransportType transportType, string path)
+        {
+            var protocol = HubProtocols[protocolName];
+            using (StartServer<Startup>(out var server))
+            {
+                var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
+                try
+                {
+                    async IAsyncEnumerable<string> clientStreamData()
+                    {
+                        var items = new string[] { "A", "B", "C", "D" };
+                        foreach (var item in items)
+                        {
+                            await Task.Delay(10);
+                            yield return item;
+                        }
+                    }
+
+                    await connection.StartAsync().OrTimeout();
+
+                    var stream = clientStreamData();
+
+                    var channel = await connection.StreamAsChannelAsync<string>("StreamEcho", stream).OrTimeout();
+
+                    Assert.Equal("A", await channel.ReadAsync().AsTask().OrTimeout());
+                    Assert.Equal("B", await channel.ReadAsync().AsTask().OrTimeout());
+                    Assert.Equal("C", await channel.ReadAsync().AsTask().OrTimeout());
+                    Assert.Equal("D", await channel.ReadAsync().AsTask().OrTimeout());
+
+                    var results = await channel.ReadAndCollectAllAsync().OrTimeout();
+                    Assert.Empty(results);
+                }
+                catch (Exception ex)
+                {
+                    LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
+                    throw;
+                }
+                finally
+                {
+                    await connection.DisposeAsync().OrTimeout();
+                }
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
+        [LogLevel(LogLevel.Trace)]
+        public async Task CanCancelIAsyncEnumerableClientToServerUpload(string protocolName, HttpTransportType transportType, string path)
+        {
+            var protocol = HubProtocols[protocolName];
+            using (StartServer<Startup>(out var server))
+            {
+                var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
+                try
+                {
+                    async IAsyncEnumerable<int> clientStreamData()
+                    {
+                        for (var i = 0; i < 1000; i++)
+                        {
+                            yield return i;
+                            await Task.Delay(10);
+                        }
+                    }
+
+                    await connection.StartAsync().OrTimeout();
+                    var results = new List<int>();
+                    var stream = clientStreamData();
+                    var cts = new CancellationTokenSource();
+                    var ex = await Assert.ThrowsAsync<OperationCanceledException>(async () =>
+                    {
+                        var channel = await connection.StreamAsChannelAsync<int>("StreamEchoInt", stream, cts.Token).OrTimeout();
+
+                        while (await channel.WaitToReadAsync())
+                        {
+                            while (channel.TryRead(out var item))
+                            {
+                                results.Add(item);
+                                cts.Cancel();
+                            }
+                        }
+                    });
+
+                    Assert.True(results.Count > 0 && results.Count < 1000);
+                    Assert.True(cts.IsCancellationRequested);
+                }
+                catch (Exception ex)
+                {
+                    LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
+                    throw;
+                }
+                finally
+                {
+                    await connection.DisposeAsync().OrTimeout();
+                }
+            }
+        }
+
         [Theory]
         [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
         [LogLevel(LogLevel.Trace)]
@@ -673,7 +773,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
                 try
                 {
                     await connection.StartAsync().OrTimeout();
-                    var stream = connection.StreamAsync<int>("Stream", 1000 );
+                    var stream = connection.StreamAsync<int>("Stream", 1000);
                     var results = new List<int>();
 
                     var cts = new CancellationTokenSource();

+ 30 - 0
src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs

@@ -43,6 +43,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
 
         public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
 
+        public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
+
         public string GetUserIdentifier()
         {
             return Context.UserIdentifier;
@@ -121,6 +123,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
         }
 
         public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
+
+        public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
     }
 
     public class TestHubT : Hub<ITestHub>
@@ -151,6 +155,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
         }
 
         public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
+
+        public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
     }
 
     internal static class TestHubMethodsImpl
@@ -212,6 +218,30 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
 
             return output.Reader;
         }
+
+        public static ChannelReader<int> StreamEchoInt(ChannelReader<int> source)
+        {
+            var output = Channel.CreateUnbounded<int>();
+            _ = Task.Run(async () =>
+            {
+                try
+                {
+                    while (await source.WaitToReadAsync())
+                    {
+                        while (source.TryRead(out var item))
+                        {
+                            await output.Writer.WriteAsync(item);
+                        }
+                    }
+                }
+                finally
+                {
+                    output.Writer.TryComplete();
+                }
+            });
+
+            return output.Reader;
+        }
     }
 
     public interface ITestHub

+ 12 - 3
src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs

@@ -210,12 +210,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
             return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
         }
 
-        private static StreamItemMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
+        private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
         {
             var headers = ReadHeaders(input, ref offset);
             var invocationId = ReadInvocationId(input, ref offset);
-            var itemType = binder.GetStreamItemType(invocationId);
-            var value = DeserializeObject(input, ref offset, itemType, "item", resolver);
+            object value;
+            try
+            {
+                var itemType = binder.GetStreamItemType(invocationId);
+                value = DeserializeObject(input, ref offset, itemType, "item", resolver);
+            }
+            catch (Exception ex)
+            {
+                return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
+            }
+
             return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
         }
 

+ 26 - 0
src/SignalR/common/Shared/ReflectionHelper.cs

@@ -2,6 +2,8 @@
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. 
 
 using System;
+using System.Collections.Generic;
+using System.Linq;
 using System.Threading.Channels;
 
 namespace Microsoft.AspNetCore.SignalR
@@ -13,6 +15,13 @@ namespace Microsoft.AspNetCore.SignalR
         public static bool IsStreamingType(Type type, bool mustBeDirectType = false)
         {
             // TODO #2594 - add Streams here, to make sending files easy
+
+#if NETCOREAPP3_0
+            if (IsIAsyncEnumerable(type))
+            {
+                return true;
+            }
+#endif
             do
             {
                 if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(ChannelReader<>))
@@ -25,5 +34,22 @@ namespace Microsoft.AspNetCore.SignalR
 
             return false;
         }
+
+#if NETCOREAPP3_0
+        public static bool IsIAsyncEnumerable(Type type)
+        {
+            return type.GetInterfaces().Any(t =>
+            {
+                if (t.IsGenericType)
+                {
+                    return t.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>);
+                }
+                else
+                {
+                    return false;
+                }
+            });
+        }
+#endif
     }
 }

+ 4 - 2
src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTestBase.cs

@@ -277,8 +277,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
             // StreamItemMessage
             new InvalidMessageData("StreamItemMissingId", new byte[] { 0x92, 2, 0x80 }, "Reading 'invocationId' as String failed."),
             new InvalidMessageData("StreamItemInvocationIdBoolean", new byte[] { 0x93, 2, 0x80, 0xc2 }, "Reading 'invocationId' as String failed."),
-            new InvalidMessageData("StreamItemMissing", new byte[] { 0x93, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z' }, "Deserializing object of the `String` type for 'item' failed."),
-            new InvalidMessageData("StreamItemTypeMismatch", new byte[] { 0x94, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Deserializing object of the `String` type for 'item' failed."),
+
+            // These now trigger StreamBindingInvocationFailureMessages
+            //new InvalidMessageData("StreamItemMissing", new byte[] { 0x93, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z' }, "Deserializing object of the `String` type for 'item' failed."),
+            //new InvalidMessageData("StreamItemTypeMismatch", new byte[] { 0x94, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Deserializing object of the `String` type for 'item' failed."),
 
             // CompletionMessage
             new InvalidMessageData("CompletionMissingId", new byte[] { 0x92, 3, 0x80 }, "Reading 'invocationId' as String failed."),