Oren Novotny 9 years ago
parent
commit
fbb9d0d832
1 changed files with 128 additions and 73 deletions
  1. 128 73
      Ix.NET/Source/System.Interactive.Async/Scan.cs

+ 128 - 73
Ix.NET/Source/System.Interactive.Async/Scan.cs

@@ -19,40 +19,7 @@ namespace System.Linq
             if (accumulator == null)
                 throw new ArgumentNullException(nameof(accumulator));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = source.GetEnumerator();
-
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
-
-                    var acc = seed;
-                    var current = default(TAccumulate);
-
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (!await e.MoveNext(ct)
-                                        .ConfigureAwait(false))
-                            {
-                                return false;
-                            }
-
-                            var item = e.Current;
-                            acc = accumulator(acc, item);
-
-                            current = acc;
-                            return true;
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            return new ScanAsyncEnumerable<TSource, TAccumulate>(source, seed, accumulator);
         }
 
         public static IAsyncEnumerable<TSource> Scan<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, TSource, TSource> accumulator)
@@ -62,50 +29,138 @@ namespace System.Linq
             if (accumulator == null)
                 throw new ArgumentNullException(nameof(accumulator));
 
-            return CreateEnumerable(
-                () =>
+            return new ScanAsyncEnumerable<TSource>(source, accumulator);
+        }
+
+        private sealed class ScanAsyncEnumerable<TSource, TAccumulate> : AsyncIterator<TAccumulate>
+        {
+            private readonly Func<TAccumulate, TSource, TAccumulate> accumulator;
+            private readonly TAccumulate seed;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private TAccumulate accumulated;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public ScanAsyncEnumerable(IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> accumulator)
+            {
+                this.source = source;
+                this.seed = seed;
+                this.accumulator = accumulator;
+            }
+
+            public override AsyncIterator<TAccumulate> Clone()
+            {
+                return new ScanAsyncEnumerable<TSource, TAccumulate>(source, seed, accumulator);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                    accumulated = default(TAccumulate);
+                }
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
                 {
-                    var e = source.GetEnumerator();
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        accumulated = seed;
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
 
-                    var hasSeed = false;
-                    var acc = default(TSource);
-                    var current = default(TSource);
+                    case AsyncIteratorState.Iterating:
+                        if (!await enumerator.MoveNext(cancellationToken)
+                                             .ConfigureAwait(false))
+                        {
+                            break;
+                        }
+
+                        var item = enumerator.Current;
+                        accumulated = accumulator(accumulated, item);
+                        current = accumulated;
+                        return true;
+                }
+
+                Dispose();
+                return false;
+            }
+        }
+
+        private sealed class ScanAsyncEnumerable<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Func<TSource, TSource, TSource> accumulator;
+            private readonly IAsyncEnumerable<TSource> source;
+            private TSource accumulated;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            private bool hasSeed;
+
+            public ScanAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TSource, TSource> accumulator)
+            {
+                this.source = source;
+                this.accumulator = accumulator;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new ScanAsyncEnumerable<TSource>(source, accumulator);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                    accumulated = default(TSource);
+                }
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        hasSeed = false;
+                        accumulated = default(TSource);
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (!await enumerator.MoveNext(cancellationToken)
+                                             .ConfigureAwait(false))
+                        {
+                            break;
+                        }
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+                        var item = enumerator.Current;
+                        if (!hasSeed)
                         {
-                            if (!await e.MoveNext(ct)
-                                        .ConfigureAwait(false))
-                            {
-                                return false;
-                            }
-
-                            var item = e.Current;
-
-                            if (!hasSeed)
-                            {
-                                hasSeed = true;
-                                acc = item;
-                                return await f(ct)
-                                           .ConfigureAwait(false);
-                            }
-
-                            acc = accumulator(acc, item);
-
-                            current = acc;
-                            return true;
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+                            hasSeed = true;
+                            accumulated = item;
+                            goto case AsyncIteratorState.Iterating; // loop
+                        }
+
+                        accumulated = accumulator(accumulated, item);
+                        current = accumulated;
+                        return true;
+                }
+
+                Dispose();
+                return false;
+            }
         }
     }
 }