浏览代码

Use async iterators in Scan.

Bart De Smet 6 年之前
父节点
当前提交
24c8bc9ebf
共有 1 个文件被更改,包括 126 次插入0 次删除
  1. 126 0
      Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Scan.cs

+ 126 - 0
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Scan.cs

@@ -11,6 +11,10 @@ namespace System.Linq
 {
     public static partial class AsyncEnumerableEx
     {
+        // NB: Implementations of Scan never yield the first element, unlike the behavior of Aggregate on a sequence with one
+        //     element, which returns the first element (or the seed if given an empty sequence). This is compatible with Rx
+        //     but one could argue whether it was the right default.
+
         public static IAsyncEnumerable<TSource> Scan<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, TSource, TSource> accumulator)
         {
             if (source == null)
@@ -18,7 +22,31 @@ namespace System.Linq
             if (accumulator == null)
                 throw Error.ArgumentNull(nameof(accumulator));
 
+#if USE_ASYNC_ITERATOR
+            return AsyncEnumerable.Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                await using (var e = source.GetAsyncEnumerator(cancellationToken).ConfigureAwait(false))
+                {
+                    if (!await e.MoveNextAsync())
+                    {
+                        yield break;
+                    }
+
+                    TSource res = e.Current;
+
+                    while (await e.MoveNextAsync())
+                    {
+                        res = accumulator(res, e.Current);
+
+                        yield return res;
+                    }
+                }
+            }
+#else
             return new ScanAsyncEnumerable<TSource>(source, accumulator);
+#endif
         }
 
         public static IAsyncEnumerable<TAccumulate> Scan<TSource, TAccumulate>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> accumulator)
@@ -28,7 +56,23 @@ namespace System.Linq
             if (accumulator == null)
                 throw Error.ArgumentNull(nameof(accumulator));
 
+#if USE_ASYNC_ITERATOR
+            return AsyncEnumerable.Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                TAccumulate res = seed;
+
+                await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
+                {
+                    res = accumulator(res, item);
+
+                    yield return res;
+                }
+            }
+#else
             return new ScanAsyncEnumerable<TSource, TAccumulate>(source, seed, accumulator);
+#endif
         }
 
         public static IAsyncEnumerable<TSource> Scan<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, TSource, ValueTask<TSource>> accumulator)
@@ -38,7 +82,31 @@ namespace System.Linq
             if (accumulator == null)
                 throw Error.ArgumentNull(nameof(accumulator));
 
+#if USE_ASYNC_ITERATOR
+            return AsyncEnumerable.Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                await using (var e = source.GetAsyncEnumerator(cancellationToken).ConfigureAwait(false))
+                {
+                    if (!await e.MoveNextAsync())
+                    {
+                        yield break;
+                    }
+
+                    TSource res = e.Current;
+
+                    while (await e.MoveNextAsync())
+                    {
+                        res = await accumulator(res, e.Current).ConfigureAwait(false);
+
+                        yield return res;
+                    }
+                }
+            }
+#else
             return new ScanAsyncEnumerableWithTask<TSource>(source, accumulator);
+#endif
         }
 
 #if !NO_DEEP_CANCELLATION
@@ -49,7 +117,31 @@ namespace System.Linq
             if (accumulator == null)
                 throw Error.ArgumentNull(nameof(accumulator));
 
+#if USE_ASYNC_ITERATOR
+            return AsyncEnumerable.Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                await using (var e = source.GetAsyncEnumerator(cancellationToken).ConfigureAwait(false))
+                {
+                    if (!await e.MoveNextAsync())
+                    {
+                        yield break;
+                    }
+
+                    TSource res = e.Current;
+
+                    while (await e.MoveNextAsync())
+                    {
+                        res = await accumulator(res, e.Current, cancellationToken).ConfigureAwait(false);
+
+                        yield return res;
+                    }
+                }
+            }
+#else
             return new ScanAsyncEnumerableWithTaskAndCancellation<TSource>(source, accumulator);
+#endif
         }
 #endif
 
@@ -60,7 +152,23 @@ namespace System.Linq
             if (accumulator == null)
                 throw Error.ArgumentNull(nameof(accumulator));
 
+#if USE_ASYNC_ITERATOR
+            return AsyncEnumerable.Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                TAccumulate res = seed;
+
+                await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
+                {
+                    res = await accumulator(res, item).ConfigureAwait(false);
+
+                    yield return res;
+                }
+            }
+#else
             return new ScanAsyncEnumerableWithTask<TSource, TAccumulate>(source, seed, accumulator);
+#endif
         }
 
 #if !NO_DEEP_CANCELLATION
@@ -71,10 +179,27 @@ namespace System.Linq
             if (accumulator == null)
                 throw Error.ArgumentNull(nameof(accumulator));
 
+#if USE_ASYNC_ITERATOR
+            return AsyncEnumerable.Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                TAccumulate res = seed;
+
+                await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
+                {
+                    res = await accumulator(res, item, cancellationToken).ConfigureAwait(false);
+
+                    yield return res;
+                }
+            }
+#else
             return new ScanAsyncEnumerableWithTaskAndCancellation<TSource, TAccumulate>(source, seed, accumulator);
+#endif
         }
 #endif
 
+#if !USE_ASYNC_ITERATOR
         private sealed class ScanAsyncEnumerable<TSource> : AsyncIterator<TSource>
         {
             private readonly Func<TSource, TSource, TSource> _accumulator;
@@ -496,6 +621,7 @@ namespace System.Linq
                 return false;
             }
         }
+#endif
 #endif
     }
 }