Browse Source

Async variants of Any and All.

Bart De Smet 8 years ago
parent
commit
f085fba7f3

+ 82 - 0
Ix.NET/Source/System.Interactive.Async/AnyAll.cs

@@ -22,6 +22,16 @@ namespace System.Linq
             return All(source, predicate, CancellationToken.None);
         }
 
+        public static Task<bool> All<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return All(source, predicate, CancellationToken.None);
+        }
+
         public static Task<bool> All<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
         {
             if (source == null)
@@ -32,6 +42,16 @@ namespace System.Linq
             return All_(source, predicate, cancellationToken);
         }
 
+        public static Task<bool> All<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return All_(source, predicate, cancellationToken);
+        }
+
         public static Task<bool> Any<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
         {
             if (source == null)
@@ -42,6 +62,16 @@ namespace System.Linq
             return Any(source, predicate, CancellationToken.None);
         }
 
+        public static Task<bool> Any<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return Any(source, predicate, CancellationToken.None);
+        }
+
         public static Task<bool> Any<TSource>(this IAsyncEnumerable<TSource> source)
         {
             if (source == null)
@@ -60,6 +90,16 @@ namespace System.Linq
             return Any_(source, predicate, cancellationToken);
         }
 
+        public static Task<bool> Any<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return Any_(source, predicate, cancellationToken);
+        }
+
         public static async Task<bool> Any<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         {
             if (source == null)
@@ -98,6 +138,27 @@ namespace System.Linq
             return true;
         }
 
+        private static async Task<bool> All_<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        {
+            var e = source.GetAsyncEnumerator();
+
+            try
+            {
+                while (await e.MoveNextAsync(cancellationToken)
+                              .ConfigureAwait(false))
+                {
+                    if (!await predicate(e.Current).ConfigureAwait(false))
+                        return false;
+                }
+            }
+            finally
+            {
+                await e.DisposeAsync().ConfigureAwait(false);
+            }
+
+            return true;
+        }
+
         private static async Task<bool> Any_<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
         {
             var e = source.GetAsyncEnumerator();
@@ -118,5 +179,26 @@ namespace System.Linq
 
             return false;
         }
+
+        private static async Task<bool> Any_<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        {
+            var e = source.GetAsyncEnumerator();
+
+            try
+            {
+                while (await e.MoveNextAsync(cancellationToken)
+                              .ConfigureAwait(false))
+                {
+                    if (await predicate(e.Current).ConfigureAwait(false))
+                        return true;
+                }
+            }
+            finally
+            {
+                await e.DisposeAsync().ConfigureAwait(false);
+            }
+
+            return false;
+        }
     }
 }

+ 6 - 6
Ix.NET/Source/Tests/AsyncTests.Aggregates.cs

@@ -230,10 +230,10 @@ namespace Tests
         public async Task All_Null()
         {
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.All<int>(null, x => true));
-            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.All<int>(AsyncEnumerable.Return(42), null));
+            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.All<int>(AsyncEnumerable.Return(42), default(Func<int, bool>)));
 
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.All<int>(null, x => true, CancellationToken.None));
-            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.All<int>(AsyncEnumerable.Return(42), null, CancellationToken.None));
+            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.All<int>(AsyncEnumerable.Return(42), default(Func<int, bool>), CancellationToken.None));
         }
 
         [Fact]
@@ -262,7 +262,7 @@ namespace Tests
         public void All4()
         {
             var ex = new Exception("Bang!");
-            var res = new[] { 2, 8, 4 }.ToAsyncEnumerable().All(x => { throw ex; });
+            var res = new[] { 2, 8, 4 }.ToAsyncEnumerable().All(new Func<int, bool>(x => { throw ex; }));
             AssertThrows<Exception>(() => res.Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
         }
 
@@ -271,11 +271,11 @@ namespace Tests
         {
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Any<int>(null));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Any<int>(null, x => true));
-            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Any<int>(AsyncEnumerable.Return(42), null));
+            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Any<int>(AsyncEnumerable.Return(42), default(Func<int, bool>)));
 
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Any<int>(null, CancellationToken.None));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Any<int>(null, x => true, CancellationToken.None));
-            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Any<int>(AsyncEnumerable.Return(42), null, CancellationToken.None));
+            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Any<int>(AsyncEnumerable.Return(42), default(Func<int, bool>), CancellationToken.None));
         }
 
         [Fact]
@@ -304,7 +304,7 @@ namespace Tests
         public void Any4()
         {
             var ex = new Exception("Bang!");
-            var res = new[] { 2, 8, 4 }.ToAsyncEnumerable().Any(x => { throw ex; });
+            var res = new[] { 2, 8, 4 }.ToAsyncEnumerable().Any(new Func<int, bool>(x => { throw ex; }));
             AssertThrows<Exception>(() => res.Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
         }