akarnokd пре 6 година
родитељ
комит
f1daa4f495

+ 110 - 0
Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Amb.cs

@@ -5,6 +5,7 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
 
@@ -228,5 +229,114 @@ namespace Tests
                 await xs.DisposeAsync();
             }
         }
+
+
+        [Fact]
+        public async Task Amb_First_GetAsyncEnumerator_Crashes()
+        {
+            var source = new FailingGetAsyncEnumerator<int>().Amb(AsyncEnumerableEx.Never<int>());
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                await xs.MoveNextAsync();
+
+                Assert.False(true, "Should not have gotten here");
+            }
+            catch (InvalidOperationException)
+            {
+                // we expect this
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Second_GetAsyncEnumerator_Crashes()
+        {
+            var source = AsyncEnumerableEx.Never<int>().Amb(new FailingGetAsyncEnumerator<int>());
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                await xs.MoveNextAsync();
+
+                Assert.False(true, "Should not have gotten here");
+            }
+            catch (InvalidOperationException)
+            {
+                // we expect this
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_First_GetAsyncEnumerator_Crashes()
+        {
+            var source = AsyncEnumerableEx.Amb(
+                new FailingGetAsyncEnumerator<int>(),
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerableEx.Never<int>()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                await xs.MoveNextAsync();
+
+                Assert.False(true, "Should not have gotten here");
+            }
+            catch (InvalidOperationException)
+            {
+                // we expect this
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_Last_GetAsyncEnumerator_Crashes()
+        {
+            var source = AsyncEnumerableEx.Amb(
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerableEx.Never<int>(),
+                new FailingGetAsyncEnumerator<int>()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                await xs.MoveNextAsync();
+
+                Assert.False(true, "Should not have gotten here");
+            }
+            catch (InvalidOperationException)
+            {
+                // we expect this
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        private class FailingGetAsyncEnumerator<T> : IAsyncEnumerable<T>
+        {
+            public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
+            {
+                throw new InvalidOperationException();
+            }
+        }
     }
 }

+ 17 - 39
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Amb.cs

@@ -32,19 +32,8 @@ namespace System.Linq
                 // i.e., see Never()
                 //
 
-                var firstCancelToken = new CancellationTokenSource();
-                var secondCancelToken = new CancellationTokenSource();
-
-                //
-                // The incoming cancellationToken should still be able to cancel both
-                //
-
-                var bothRegistry = cancellationToken.Register(() =>
-                {
-                    firstCancelToken.Cancel();
-                    secondCancelToken.Cancel();
-                });
-
+                var firstCancelToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+                var secondCancelToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
 
                 try
                 {
@@ -68,16 +57,18 @@ namespace System.Linq
                 }
                 catch
                 {
+                    secondCancelToken.Cancel();
+                    firstCancelToken.Cancel();
+
                     // NB: AwaitMoveNextAsyncAndDispose checks for null for both arguments, reducing the need for many null
                     //     checks over here.
 
                     var cleanup = new[]
                     {
-                        AwaitMoveNextAsyncAndDispose(secondMoveNext, secondEnumerator),
-                        AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator)
+                        AwaitMoveNextAsyncAndDispose(secondMoveNext, secondEnumerator, secondCancelToken.Token),
+                        AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator, firstCancelToken.Token)
                     };
 
-                    bothRegistry.Dispose();
 
                     await Task.WhenAll(cleanup).ConfigureAwait(false);
 
@@ -105,13 +96,13 @@ namespace System.Linq
                 {
                     winner = firstEnumerator;
                     secondCancelToken.Cancel();
-                    disposeLoser = AwaitMoveNextAsyncAndDispose(secondMoveNext, secondEnumerator);
+                    disposeLoser = AwaitMoveNextAsyncAndDispose(secondMoveNext, secondEnumerator, secondCancelToken.Token);
                 }
                 else
                 {
                     winner = secondEnumerator;
                     firstCancelToken.Cancel();
-                    disposeLoser = AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator);
+                    disposeLoser = AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator, firstCancelToken.Token);
                 }
 
                 try
@@ -133,8 +124,6 @@ namespace System.Linq
                 }
                 finally
                 {
-                    bothRegistry.Dispose();
-
                     //
                     // REVIEW: This behavior differs from the original implementation in that we never discard any in flight
                     //         asynchronous operations. If an exception occurs while enumerating the winner, it can be
@@ -171,15 +160,8 @@ namespace System.Linq
                 var individualTokenSources = new CancellationTokenSource[n];
                 for (var i = 0; i < n; i++)
                 {
-                    individualTokenSources[i] = new CancellationTokenSource();
+                    individualTokenSources[i] = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
                 }
-                var allIndividualDispose = cancellationToken.Register(() =>
-                {
-                    foreach (var tokenSource in individualTokenSources)
-                    {
-                        tokenSource.Cancel();
-                    }
-                });
 
                 try
                 {
@@ -195,14 +177,12 @@ namespace System.Linq
                 {
                     var cleanup = new Task[n];
 
-                    for (var i = 0; i < n; i++)
+                    for (var i = n - 1; i >= 0; i--)
                     {
-                        cleanup[i] = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i]);
-
                         individualTokenSources[i].Cancel();
-                    }
 
-                    allIndividualDispose.Dispose();
+                        cleanup[i] = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i], individualTokenSources[i].Token);
+                    }
 
                     await Task.WhenAll(cleanup).ConfigureAwait(false);
 
@@ -227,12 +207,12 @@ namespace System.Linq
 
                 var loserCleanupTasks = new List<Task>(n - 1);
 
-                for (var i = 0; i < n; i++)
+                for (var i = n - 1; i >= 0; i--)
                 {
                     if (i != winnerIndex)
                     {
                         individualTokenSources[i].Cancel();
-                        var loserCleanupTask = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i]);
+                        var loserCleanupTask = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i], individualTokenSources[i].Token);
                         loserCleanupTasks.Add(loserCleanupTask);
                     }
                 }
@@ -258,8 +238,6 @@ namespace System.Linq
                 }
                 finally
                 {
-                    allIndividualDispose.Dispose();
-
                     await cleanupLosers.ConfigureAwait(false);
                 }
             }
@@ -273,7 +251,7 @@ namespace System.Linq
             return Amb(sources.ToArray());
         }
 
-        private static async Task AwaitMoveNextAsyncAndDispose<T>(Task<bool>? moveNextAsync, IAsyncEnumerator<T>? enumerator)
+        private static async Task AwaitMoveNextAsyncAndDispose<T>(Task<bool>? moveNextAsync, IAsyncEnumerator<T>? enumerator, CancellationToken token)
         {
             if (enumerator != null)
             {
@@ -285,7 +263,7 @@ namespace System.Linq
                         {
                             await moveNextAsync.ConfigureAwait(false);
                         }
-                        catch (TaskCanceledException)
+                        catch (TaskCanceledException tce) // when (tce.CancellationToken == token)
                         {
                             // ignored because of cancelling the non-winners
                         }