瀏覽代碼

Cleaning up Count and LongCount.

Bart De Smet 7 年之前
父節點
當前提交
a6e71784ca

+ 39 - 24
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/Count.cs

@@ -10,70 +10,85 @@ namespace System.Linq
 {
     public static partial class AsyncEnumerable
     {
-        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
+        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
-            if (source is ICollection<TSource> collection)
-            {
-                return Task.FromResult(collection.Count);
-            }
-
-            if (source is IAsyncIListProvider<TSource> listProv)
-            {
-                return listProv.GetCountAsync(false, cancellationToken);
-            }
-
-            return source.Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
+            return CountCore(source, CancellationToken.None);
         }
 
-        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
+        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
-            if (predicate == null)
-                throw new ArgumentNullException(nameof(predicate));
 
-            return source.Where(predicate).Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
+            return CountCore(source, cancellationToken);
         }
 
-        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return source.Where(predicate).Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
+            return CountCore(source, predicate, CancellationToken.None);
         }
 
-        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source)
+        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
 
-            return Count(source, CancellationToken.None);
+            return CountCore(source, predicate, cancellationToken);
         }
 
-        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
+        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return Count(source, predicate, CancellationToken.None);
+            return CountCore(source, predicate, CancellationToken.None);
         }
 
-        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
+        public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return Count(source, predicate, CancellationToken.None);
+            return CountCore(source, predicate, cancellationToken);
+        }
+
+        private static Task<int> CountCore<TSource>(IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
+        {
+            if (source is ICollection<TSource> collection)
+            {
+                return Task.FromResult(collection.Count);
+            }
+
+            if (source is IAsyncIListProvider<TSource> listProv)
+            {
+                return listProv.GetCountAsync(onlyIfCheap: false, cancellationToken);
+            }
+
+            return source.Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
+        }
+
+        private static Task<int> CountCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
+        {
+            return source.Where(predicate).Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
+        }
+
+        private static Task<int> CountCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        {
+            return source.Where(predicate).Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
         }
     }
 }

+ 29 - 14
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/LongCount.cs

@@ -10,60 +10,75 @@ namespace System.Linq
 {
     public static partial class AsyncEnumerable
     {
-        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
+        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
-            return source.Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
+            return LongCountCore(source, CancellationToken.None);
         }
 
-        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
+        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
-            if (predicate == null)
-                throw new ArgumentNullException(nameof(predicate));
 
-            return source.Where(predicate).Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
+            return LongCountCore(source, cancellationToken);
         }
 
-        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return source.Where(predicate).Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
+            return LongCountCore(source, predicate, CancellationToken.None);
         }
 
-        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source)
+        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
 
-            return LongCount(source, CancellationToken.None);
+            return LongCountCore(source, predicate, cancellationToken);
         }
 
-        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
+        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return LongCount(source, predicate, CancellationToken.None);
+            return LongCountCore(source, predicate, CancellationToken.None);
         }
 
-        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
+        public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return LongCount(source, predicate, CancellationToken.None);
+            return LongCountCore(source, predicate, cancellationToken);
+        }
+
+        private static Task<long> LongCountCore<TSource>(IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
+        {
+            return source.Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
+        }
+
+        private static Task<long> LongCountCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
+        {
+            return source.Where(predicate).Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
+        }
+
+        private static Task<long> LongCountCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate, CancellationToken cancellationToken)
+        {
+            return source.Where(predicate).Aggregate(0L, (c, _) => checked(c + 1), cancellationToken);
         }
     }
 }