浏览代码

Async variants of Count and LongCount.

Bart De Smet 8 年之前
父节点
当前提交
07c8d8c8cf
共有 2 个文件被更改,包括 48 次插入6 次删除
  1. 42 0
      Ix.NET/Source/System.Interactive.Async/Count.cs
  2. 6 6
      Ix.NET/Source/Tests/AsyncTests.Aggregates.cs

+ 42 - 0
Ix.NET/Source/System.Interactive.Async/Count.cs

@@ -41,6 +41,17 @@ namespace System.Linq
                          .Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
                          .Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
         }
         }
 
 
+        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return source.Where(predicate)
+                         .Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
+        }
+
         public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source)
         public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source)
         {
         {
             if (source == null)
             if (source == null)
@@ -59,6 +70,16 @@ namespace System.Linq
             return Count(source, predicate, CancellationToken.None);
             return Count(source, predicate, CancellationToken.None);
         }
         }
 
 
+        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return Count(source, predicate, CancellationToken.None);
+        }
+
         public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         {
         {
             if (source == null)
             if (source == null)
@@ -78,6 +99,17 @@ namespace System.Linq
                          .Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
                          .Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
         }
         }
 
 
+        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return source.Where(predicate)
+                         .Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
+        }
+
         public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source)
         public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source)
         {
         {
             if (source == null)
             if (source == null)
@@ -95,5 +127,15 @@ namespace System.Linq
 
 
             return LongCount(source, predicate, CancellationToken.None);
             return LongCount(source, predicate, CancellationToken.None);
         }
         }
+
+        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return LongCount(source, predicate, CancellationToken.None);
+        }
     }
     }
 }
 }

+ 6 - 6
Ix.NET/Source/Tests/AsyncTests.Aggregates.cs

@@ -159,11 +159,11 @@ namespace Tests
         {
         {
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(null));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(null));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(null, x => true));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(null, x => true));
-            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(AsyncEnumerable.Return(42), null));
+            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(AsyncEnumerable.Return(42), default(Func<int, bool>)));
 
 
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(null, CancellationToken.None));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(null, CancellationToken.None));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(null, x => true, CancellationToken.None));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(null, x => true, CancellationToken.None));
-            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(AsyncEnumerable.Return(42), null, CancellationToken.None));
+            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.Count<int>(AsyncEnumerable.Return(42), default(Func<int, bool>), CancellationToken.None));
         }
         }
 
 
         [Fact]
         [Fact]
@@ -186,7 +186,7 @@ namespace Tests
         public void Count3()
         public void Count3()
         {
         {
             var ex = new Exception("Bang!");
             var ex = new Exception("Bang!");
-            var ys = new[] { 1, 2, 3 }.ToAsyncEnumerable().Count(x => { throw ex; });
+            var ys = new[] { 1, 2, 3 }.ToAsyncEnumerable().Count(new Func<int, bool>(x => { throw ex; }));
             AssertThrows<Exception>(() => ys.Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
             AssertThrows<Exception>(() => ys.Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
         }
         }
 
 
@@ -195,11 +195,11 @@ namespace Tests
         {
         {
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(null));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(null));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(null, x => true));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(null, x => true));
-            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(AsyncEnumerable.Return(42), null));
+            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(AsyncEnumerable.Return(42), default(Func<int, bool>)));
 
 
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(null, CancellationToken.None));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(null, CancellationToken.None));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(null, x => true, CancellationToken.None));
             await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(null, x => true, CancellationToken.None));
-            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(AsyncEnumerable.Return(42), null, CancellationToken.None));
+            await Assert.ThrowsAsync<ArgumentNullException>(() => AsyncEnumerable.LongCount<int>(AsyncEnumerable.Return(42), default(Func<int, bool>), CancellationToken.None));
         }
         }
 
 
         [Fact]
         [Fact]
@@ -222,7 +222,7 @@ namespace Tests
         public void LongCount3()
         public void LongCount3()
         {
         {
             var ex = new Exception("Bang!");
             var ex = new Exception("Bang!");
-            var ys = new[] { 1, 2, 3 }.ToAsyncEnumerable().LongCount(x => { throw ex; });
+            var ys = new[] { 1, 2, 3 }.ToAsyncEnumerable().LongCount(new Func<int, bool>(x => { throw ex; }));
             AssertThrows<Exception>(() => ys.Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
             AssertThrows<Exception>(() => ys.Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
         }
         }