Browse Source

Minor optimizations in Aggregate.

Bart De Smet 7 years ago
parent
commit
955246cfb3
1 changed files with 29 additions and 8 deletions
  1. 29 8
      Ix.NET/Source/System.Linq.Async/System/Linq/Operators/Aggregate.cs

+ 29 - 8
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/Aggregate.cs

@@ -17,7 +17,7 @@ namespace System.Linq
             if (accumulator == null)
                 throw new ArgumentNullException(nameof(accumulator));
 
-            return Aggregate(source, accumulator, CancellationToken.None);
+            return AggregateCore(source, accumulator, CancellationToken.None);
         }
 
         public static Task<TSource> Aggregate<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, TSource, TSource> accumulator, CancellationToken cancellationToken)
@@ -37,7 +37,7 @@ namespace System.Linq
             if (accumulator == null)
                 throw new ArgumentNullException(nameof(accumulator));
 
-            return Aggregate(source, accumulator, CancellationToken.None);
+            return AggregateCore(source, accumulator, CancellationToken.None);
         }
 
         public static Task<TSource> Aggregate<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, TSource, Task<TSource>> accumulator, CancellationToken cancellationToken)
@@ -57,7 +57,7 @@ namespace System.Linq
             if (accumulator == null)
                 throw new ArgumentNullException(nameof(accumulator));
 
-            return Aggregate(source, seed, accumulator, CancellationToken.None);
+            return AggregateCore(source, seed, accumulator, x => x, CancellationToken.None);
         }
 
         public static Task<TAccumulate> Aggregate<TSource, TAccumulate>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> accumulator, CancellationToken cancellationToken)
@@ -67,7 +67,7 @@ namespace System.Linq
             if (accumulator == null)
                 throw new ArgumentNullException(nameof(accumulator));
 
-            return source.Aggregate(seed, accumulator, x => x, cancellationToken);
+            return AggregateCore(source, seed, accumulator, x => x, cancellationToken);
         }
 
         public static Task<TAccumulate> Aggregate<TSource, TAccumulate>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, Task<TAccumulate>> accumulator)
@@ -77,7 +77,7 @@ namespace System.Linq
             if (accumulator == null)
                 throw new ArgumentNullException(nameof(accumulator));
 
-            return Aggregate(source, seed, accumulator, CancellationToken.None);
+            return AggregateCore(source, seed, accumulator, CancellationToken.None);
         }
 
         public static Task<TAccumulate> Aggregate<TSource, TAccumulate>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, Task<TAccumulate>> accumulator, CancellationToken cancellationToken)
@@ -87,7 +87,7 @@ namespace System.Linq
             if (accumulator == null)
                 throw new ArgumentNullException(nameof(accumulator));
 
-            return source.Aggregate(seed, accumulator, x => Task.FromResult(x), cancellationToken);
+            return AggregateCore(source, seed, accumulator, cancellationToken);
         }
 
         public static Task<TResult> Aggregate<TSource, TAccumulate, TResult>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> accumulator, Func<TAccumulate, TResult> resultSelector)
@@ -99,7 +99,7 @@ namespace System.Linq
             if (resultSelector == null)
                 throw new ArgumentNullException(nameof(resultSelector));
 
-            return Aggregate(source, seed, accumulator, resultSelector, CancellationToken.None);
+            return AggregateCore(source, seed, accumulator, resultSelector, CancellationToken.None);
         }
 
         public static Task<TResult> Aggregate<TSource, TAccumulate, TResult>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> accumulator, Func<TAccumulate, TResult> resultSelector, CancellationToken cancellationToken)
@@ -123,7 +123,7 @@ namespace System.Linq
             if (resultSelector == null)
                 throw new ArgumentNullException(nameof(resultSelector));
 
-            return Aggregate(source, seed, accumulator, resultSelector, CancellationToken.None);
+            return AggregateCore(source, seed, accumulator, resultSelector, CancellationToken.None);
         }
 
         public static Task<TResult> Aggregate<TSource, TAccumulate, TResult>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, Task<TAccumulate>> accumulator, Func<TAccumulate, Task<TResult>> resultSelector, CancellationToken cancellationToken)
@@ -185,6 +185,27 @@ namespace System.Linq
             return acc;
         }
 
+        private static async Task<TResult> AggregateCore<TSource, TResult>(IAsyncEnumerable<TSource> source, TResult seed, Func<TResult, TSource, Task<TResult>> accumulator, CancellationToken cancellationToken)
+        {
+            var acc = seed;
+
+            var e = source.GetAsyncEnumerator(cancellationToken);
+
+            try
+            {
+                while (await e.MoveNextAsync().ConfigureAwait(false))
+                {
+                    acc = await accumulator(acc, e.Current).ConfigureAwait(false);
+                }
+            }
+            finally
+            {
+                await e.DisposeAsync().ConfigureAwait(false);
+            }
+
+            return acc;
+        }
+
         private static async Task<TResult> AggregateCore<TSource, TAccumulate, TResult>(IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, Task<TAccumulate>> accumulator, Func<TAccumulate, Task<TResult>> resultSelector, CancellationToken cancellationToken)
         {
             var acc = seed;