Преглед на файлове

Async variants of Scan.

Bart De Smet преди 8 години
родител
ревизия
2456f5767b
променени са 2 файла, в които са добавени 164 реда и са изтрити 6 реда
  1. 160 2
      Ix.NET/Source/System.Interactive.Async/Scan.cs
  2. 4 4
      Ix.NET/Source/Tests/AsyncTests.Single.cs

+ 160 - 2
Ix.NET/Source/System.Interactive.Async/Scan.cs

@@ -30,6 +30,26 @@ namespace System.Linq
             return new ScanAsyncEnumerable<TSource, TAccumulate>(source, seed, accumulator);
         }
 
+        public static IAsyncEnumerable<TSource> Scan<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, TSource, Task<TSource>> accumulator)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (accumulator == null)
+                throw new ArgumentNullException(nameof(accumulator));
+
+            return new ScanAsyncEnumerableWithTask<TSource>(source, accumulator);
+        }
+
+        public static IAsyncEnumerable<TAccumulate> Scan<TSource, TAccumulate>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, Task<TAccumulate>> accumulator)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (accumulator == null)
+                throw new ArgumentNullException(nameof(accumulator));
+
+            return new ScanAsyncEnumerableWithTask<TSource, TAccumulate>(source, seed, accumulator);
+        }
+
         private sealed class ScanAsyncEnumerable<TSource> : AsyncIterator<TSource>
         {
             private readonly Func<TSource, TSource, TSource> accumulator;
@@ -152,8 +172,7 @@ namespace System.Linq
                         goto case AsyncIteratorState.Iterating;
 
                     case AsyncIteratorState.Iterating:
-                        if (await enumerator.MoveNextAsync()
-                                             .ConfigureAwait(false))
+                        if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                         {
                             var item = enumerator.Current;
                             accumulated = accumulator(accumulated, item);
@@ -169,5 +188,144 @@ namespace System.Linq
                 return false;
             }
         }
+
+        private sealed class ScanAsyncEnumerableWithTask<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Func<TSource, TSource, Task<TSource>> accumulator;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private TSource accumulated;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            private bool hasSeed;
+
+            public ScanAsyncEnumerableWithTask(IAsyncEnumerable<TSource> source, Func<TSource, TSource, Task<TSource>> accumulator)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(accumulator != null);
+
+                this.source = source;
+                this.accumulator = accumulator;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new ScanAsyncEnumerableWithTask<TSource>(source, accumulator);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                    accumulated = default(TSource);
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        hasSeed = false;
+                        accumulated = default(TSource);
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+
+                        while (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            var item = enumerator.Current;
+                            if (!hasSeed)
+                            {
+                                hasSeed = true;
+                                accumulated = item;
+                                continue; // loop
+                            }
+
+                            accumulated = await accumulator(accumulated, item).ConfigureAwait(false);
+                            current = accumulated;
+                            return true;
+                        }
+
+                        break; // case
+
+                }
+
+                await DisposeAsync().ConfigureAwait(false);
+                return false;
+            }
+        }
+
+        private sealed class ScanAsyncEnumerableWithTask<TSource, TAccumulate> : AsyncIterator<TAccumulate>
+        {
+            private readonly Func<TAccumulate, TSource, Task<TAccumulate>> accumulator;
+            private readonly TAccumulate seed;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private TAccumulate accumulated;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public ScanAsyncEnumerableWithTask(IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, Task<TAccumulate>> accumulator)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(accumulator != null);
+
+                this.source = source;
+                this.seed = seed;
+                this.accumulator = accumulator;
+            }
+
+            public override AsyncIterator<TAccumulate> Clone()
+            {
+                return new ScanAsyncEnumerableWithTask<TSource, TAccumulate>(source, seed, accumulator);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                    accumulated = default(TAccumulate);
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        accumulated = seed;
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            var item = enumerator.Current;
+                            accumulated = await accumulator(accumulated, item).ConfigureAwait(false);
+                            current = accumulated;
+                            return true;
+                        }
+
+                        break;
+
+                }
+
+                await DisposeAsync().ConfigureAwait(false);
+                return false;
+            }
+        }
     }
 }

+ 4 - 4
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -3191,10 +3191,10 @@ namespace Tests
         public void Scan_Null()
         {
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Scan(default(IAsyncEnumerable<int>), 3, (x, y) => x + y));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Scan(AsyncEnumerable.Return(42), 3, null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Scan(AsyncEnumerable.Return(42), 3, default(Func<int, int, int>)));
 
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Scan(default(IAsyncEnumerable<int>), (x, y) => x + y));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Scan(AsyncEnumerable.Return(42), null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Scan(AsyncEnumerable.Return(42), default(Func<int, int, int>)));
         }
 
         [Fact]
@@ -3224,7 +3224,7 @@ namespace Tests
         public void Scan3()
         {
             var ex = new Exception("Bang!");
-            var xs = new[] { 1, 2, 3 }.ToAsyncEnumerable().Scan(8, (x, y) => { throw ex; });
+            var xs = new[] { 1, 2, 3 }.ToAsyncEnumerable().Scan(8, new Func<int, int, int>((x, y) => { throw ex; }));
 
             var e = xs.GetAsyncEnumerator();
             AssertThrows(() => e.MoveNextAsync().Wait(WaitTimeoutMs), (Exception ex_) => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
@@ -3234,7 +3234,7 @@ namespace Tests
         public void Scan4()
         {
             var ex = new Exception("Bang!");
-            var xs = new[] { 1, 2, 3 }.ToAsyncEnumerable().Scan((x, y) => { throw ex; });
+            var xs = new[] { 1, 2, 3 }.ToAsyncEnumerable().Scan(new Func<int, int, int>((x, y) => { throw ex; }));
 
             var e = xs.GetAsyncEnumerator();
             AssertThrows(() => e.MoveNextAsync().Wait(WaitTimeoutMs), (Exception ex_) => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);