1
0
Эх сурвалжийг харах

Add Amb tests, fix Amb not canceling the losers

akarnokd 6 жил өмнө
parent
commit
6843410fc2

+ 232 - 0
Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Amb.cs

@@ -0,0 +1,232 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Tests
+{
+    public class Amb : AsyncEnumerableExTests
+    {
+        [Fact]
+        public void Amb_Null()
+        {
+            Assert.Throws<ArgumentNullException>(() => AsyncEnumerableEx.Amb(default, Return42));
+            Assert.Throws<ArgumentNullException>(() => AsyncEnumerableEx.Amb(Return42, default));
+        }
+
+        [Fact]
+        public async Task Amb_First_Wins()
+        {
+            var source = AsyncEnumerable.Range(1, 5).Amb(AsyncEnumerableEx.Never<int>());
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_First_Wins_Alt()
+        {
+            var source = AsyncEnumerable.Range(1, 5).Amb(AsyncEnumerable.Range(1, 5).SelectAwait(async v =>
+            {
+                await Task.Delay(500);
+                return v;
+            }));
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Second_Wins()
+        {
+            var source = AsyncEnumerableEx.Never<int>().Amb(AsyncEnumerable.Range(1, 5));
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Second_Wins_Alt()
+        {
+            var source = AsyncEnumerable.Range(1, 5).SelectAwait(async v =>
+            {
+                await Task.Delay(500);
+                return v;
+            }).Amb(AsyncEnumerable.Range(6, 5));
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i + 5, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_First_Wins()
+        {
+            var source = AsyncEnumerableEx.Amb(
+                AsyncEnumerable.Range(1, 5),
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerableEx.Never<int>()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_Last_Wins()
+        {
+            var source = AsyncEnumerableEx.Amb(
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerable.Range(1, 5)
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_Enum_First_Wins()
+        {
+            var source = AsyncEnumerableEx.Amb(new[] {
+                    AsyncEnumerable.Range(1, 5),
+                    AsyncEnumerableEx.Never<int>(),
+                    AsyncEnumerableEx.Never<int>()
+                }.AsEnumerable()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_Enum_Last_Wins()
+        {
+            var source = AsyncEnumerableEx.Amb(new[] {
+                    AsyncEnumerableEx.Never<int>(),
+                    AsyncEnumerableEx.Never<int>(),
+                    AsyncEnumerable.Range(1, 5)
+                }.AsEnumerable()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+    }
+}

+ 56 - 4
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Amb.cs

@@ -27,6 +27,25 @@ namespace System.Linq
                 Task<bool>? firstMoveNext = null;
                 Task<bool>? secondMoveNext = null;
 
+                //
+                // We need separate tokens for each source so that the non-winner can get disposed and unblocked
+                // i.e., see Never()
+                //
+
+                var firstCancelToken = new CancellationTokenSource();
+                var secondCancelToken = new CancellationTokenSource();
+
+                //
+                // The incoming cancellationToken should still be able to cancel both
+                //
+
+                var bothRegistry = cancellationToken.Register(() =>
+                {
+                    firstCancelToken.Cancel();
+                    secondCancelToken.Cancel();
+                });
+
+
                 try
                 {
                     //
@@ -36,7 +55,7 @@ namespace System.Linq
                     //         adding a WhenAny combinator that does exactly that. We can even avoid calling AsTask.
                     //
 
-                    firstEnumerator = first.GetAsyncEnumerator(cancellationToken);
+                    firstEnumerator = first.GetAsyncEnumerator(firstCancelToken.Token);
                     firstMoveNext = firstEnumerator.MoveNextAsync().AsTask();
 
                     //
@@ -44,7 +63,7 @@ namespace System.Linq
                     //         overload which performs GetAsyncEnumerator/MoveNextAsync in pairs, rather than phased.
                     //
 
-                    secondEnumerator = second.GetAsyncEnumerator(cancellationToken);
+                    secondEnumerator = second.GetAsyncEnumerator(secondCancelToken.Token);
                     secondMoveNext = secondEnumerator.MoveNextAsync().AsTask();
                 }
                 catch
@@ -58,6 +77,8 @@ namespace System.Linq
                         AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator)
                     };
 
+                    bothRegistry.Dispose();
+
                     await Task.WhenAll(cleanup).ConfigureAwait(false);
 
                     throw;
@@ -83,11 +104,13 @@ namespace System.Linq
                 if (moveNextWinner == firstMoveNext)
                 {
                     winner = firstEnumerator;
+                    secondCancelToken.Cancel();
                     disposeLoser = AwaitMoveNextAsyncAndDispose(secondMoveNext, secondEnumerator);
                 }
                 else
                 {
                     winner = secondEnumerator;
+                    firstCancelToken.Cancel();
                     disposeLoser = AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator);
                 }
 
@@ -110,6 +133,8 @@ namespace System.Linq
                 }
                 finally
                 {
+                    bothRegistry.Dispose();
+
                     //
                     // REVIEW: This behavior differs from the original implementation in that we never discard any in flight
                     //         asynchronous operations. If an exception occurs while enumerating the winner, it can be
@@ -143,12 +168,24 @@ namespace System.Linq
 
                 var enumerators = new IAsyncEnumerator<TSource>[n];
                 var moveNexts = new Task<bool>[n];
+                var individualTokenSources = new CancellationTokenSource[n];
+                for (var i = 0; i < n; i++)
+                {
+                    individualTokenSources[i] = new CancellationTokenSource();
+                }
+                var allIndividualDispose = cancellationToken.Register(() =>
+                {
+                    foreach (var tokenSource in individualTokenSources)
+                    {
+                        tokenSource.Cancel();
+                    }
+                });
 
                 try
                 {
                     for (var i = 0; i < n; i++)
                     {
-                        var enumerator = sources[i].GetAsyncEnumerator(cancellationToken);
+                        var enumerator = sources[i].GetAsyncEnumerator(individualTokenSources[i].Token);
 
                         enumerators[i] = enumerator;
                         moveNexts[i] = enumerator.MoveNextAsync().AsTask();
@@ -161,9 +198,14 @@ namespace System.Linq
                     for (var i = 0; i < n; i++)
                     {
                         cleanup[i] = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i]);
+
+                        individualTokenSources[i].Cancel();
                     }
 
                     await Task.WhenAll(cleanup).ConfigureAwait(false);
+
+                    allIndividualDispose.Dispose();
+
                     throw;
                 }
 
@@ -189,6 +231,7 @@ namespace System.Linq
                 {
                     if (i != winnerIndex)
                     {
+                        individualTokenSources[i].Cancel();
                         var loserCleanupTask = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i]);
                         loserCleanupTasks.Add(loserCleanupTask);
                     }
@@ -215,6 +258,8 @@ namespace System.Linq
                 }
                 finally
                 {
+                    allIndividualDispose.Dispose();
+
                     await cleanupLosers.ConfigureAwait(false);
                 }
             }
@@ -236,7 +281,14 @@ namespace System.Linq
                 {
                     if (moveNextAsync != null)
                     {
-                        await moveNextAsync.ConfigureAwait(false);
+                        try
+                        {
+                            await moveNextAsync.ConfigureAwait(false);
+                        }
+                        catch (TaskCanceledException)
+                        {
+                            // ignored because of cancelling the non-winners
+                        }
                     }
                 }
             }