Explorar o código

Fix for #183 - ForEachAsync possibly failing to dipose enumerator

Oren Novotny %!s(int64=9) %!d(string=hai) anos
pai
achega
7750c3ea84

+ 7 - 34
Ix.NET/Source/System.Interactive.Async/AsyncEnumerable.Single.cs

@@ -1654,49 +1654,22 @@ namespace System.Linq
             source.ForEachAsync(action, cancellationToken).Wait(cancellationToken);
         }
 
-        public static Task ForEachAsync<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource, int> action, CancellationToken cancellationToken)
+        public static async Task ForEachAsync<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource, int> action, CancellationToken cancellationToken)
         {
             if (source == null)
                 throw new ArgumentNullException("source");
             if (action == null)
                 throw new ArgumentNullException("action");
 
-            var tcs = new TaskCompletionSource<bool>();
 
-            var e = source.GetEnumerator();
-
-            var i = 0;
-
-            var f = default(Action<CancellationToken>);
-            f = ct =>
+            using (var asyncEnumerator = source.AsAsyncEnumerable().GetEnumerator())
             {
-                e.MoveNext(ct).Then(t =>
+                var i = 0;
+                while (await asyncEnumerator.MoveNext(cancellationToken))
                 {
-                    t.Handle(tcs, res =>
-                    {
-                        if (res)
-                        {
-                            try
-                            {
-                                action(e.Current, i++);
-                            }
-                            catch (Exception ex)
-                            {
-                                tcs.TrySetException(ex);
-                                return;
-                            }
-
-                            f(ct);
-                        }
-                        else
-                            tcs.TrySetResult(true);
-                    });
-                });
-            };
-
-            f(cancellationToken);
-
-            return tcs.Task.UsingEnumerator(e);
+                    action(asyncEnumerator.Current, i++);
+                }
+            }
         }
 
         public static IAsyncEnumerable<TSource> Repeat<TSource>(this IAsyncEnumerable<TSource> source, int count)

+ 10 - 9
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -9,6 +9,7 @@ using System.Linq;
 using System.Text;
 using Xunit;
 using System.Threading;
+using System.Threading.Tasks;
 
 namespace Tests
 {
@@ -568,17 +569,17 @@ namespace Tests
         }
 
         [Fact]
-        public void ForEachAsync_Null()
+        public async Task ForEachAsync_Null()
         {
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(null, x => { }));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(AsyncEnumerable.Return(42), default(Action<int>)));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(null, (x, i) => { }));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(AsyncEnumerable.Return(42), default(Action<int, int>)));
+            await AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(null, x => { }));
+            await AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(AsyncEnumerable.Return(42), default(Action<int>)));
+            await AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(null, (x, i) => { }));
+            await AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(AsyncEnumerable.Return(42), default(Action<int, int>)));
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(null, x => { }, CancellationToken.None));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(AsyncEnumerable.Return(42), default(Action<int>), CancellationToken.None));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(null, (x, i) => { }, CancellationToken.None));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(AsyncEnumerable.Return(42), default(Action<int, int>), CancellationToken.None));
+            await AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(null, x => { }, CancellationToken.None));
+            await AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(AsyncEnumerable.Return(42), default(Action<int>), CancellationToken.None));
+            await AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(null, (x, i) => { }, CancellationToken.None));
+            await AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ForEachAsync<int>(AsyncEnumerable.Return(42), default(Action<int, int>), CancellationToken.None));
         }
 
         [Fact]

+ 7 - 0
Ix.NET/Source/Tests/AsyncTests.cs

@@ -6,6 +6,7 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using System.Threading.Tasks;
 using Xunit;
 
 namespace Tests
@@ -19,6 +20,12 @@ namespace Tests
             Assert.Throws<E>(a);
         }
 
+        public Task AssertThrows<E>(Func<Task> func)
+            where E : Exception
+        {
+            return Assert.ThrowsAsync<E>(func);
+        }
+
         public void AssertThrows<E>(Action a, Func<E, bool> assert)
             where E : Exception
         {