Explorar el Código

Async variant of Expand.

Bart De Smet hace 8 años
padre
commit
a663af4864

+ 100 - 0
Ix.NET/Source/System.Interactive.Async/Expand.cs

@@ -20,6 +20,16 @@ namespace System.Linq
             return new ExpandAsyncIterator<TSource>(source, selector);
         }
 
+        public static IAsyncEnumerable<TSource> Expand<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<IAsyncEnumerable<TSource>>> selector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (selector == null)
+                throw new ArgumentNullException(nameof(selector));
+
+            return new ExpandAsyncIteratorWithTask<TSource>(source, selector);
+        }
+
         private sealed class ExpandAsyncIterator<TSource> : AsyncIterator<TSource>
         {
             private readonly Func<TSource, IAsyncEnumerable<TSource>> selector;
@@ -109,5 +119,95 @@ namespace System.Linq
                 return false;
             }
         }
+
+        private sealed class ExpandAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Func<TSource, Task<IAsyncEnumerable<TSource>>> selector;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private IAsyncEnumerator<TSource> enumerator;
+
+            private Queue<IAsyncEnumerable<TSource>> queue;
+
+            public ExpandAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<IAsyncEnumerable<TSource>>> selector)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new ExpandAsyncIteratorWithTask<TSource>(source, selector);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                }
+
+                queue = null;
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        queue = new Queue<IAsyncEnumerable<TSource>>();
+                        queue.Enqueue(source);
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        while (true)
+                        {
+                            if (enumerator == null)
+                            {
+                                if (queue.Count > 0)
+                                {
+                                    var src = queue.Dequeue();
+
+                                    if (enumerator != null)
+                                    {
+                                        await enumerator.DisposeAsync().ConfigureAwait(false);
+                                    }
+
+                                    enumerator = src.GetAsyncEnumerator();
+
+                                    continue; // loop
+                                }
+
+                                break; // while
+                            }
+
+                            if (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                            {
+                                var item = enumerator.Current;
+                                var next = await selector(item).ConfigureAwait(false);
+                                queue.Enqueue(next);
+                                current = item;
+                                return true;
+                            }
+
+                            await enumerator.DisposeAsync().ConfigureAwait(false);
+                            enumerator = null;
+                        }
+
+                        break; // case
+                }
+
+                await DisposeAsync().ConfigureAwait(false);
+                return false;
+            }
+        }
     }
 }

+ 4 - 4
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -3138,8 +3138,8 @@ namespace Tests
         [Fact]
         public void Expand_Null()
         {
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Expand(default(IAsyncEnumerable<int>), x => null));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Expand(AsyncEnumerable.Return(42), null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Expand(default(IAsyncEnumerable<int>), x => default(IAsyncEnumerable<int>)));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Expand(AsyncEnumerable.Return(42), default(Func<int, IAsyncEnumerable<int>>)));
         }
 
         [Fact]
@@ -3162,7 +3162,7 @@ namespace Tests
         public void Expand2()
         {
             var ex = new Exception("Bang!");
-            var xs = new[] { 2, 3 }.ToAsyncEnumerable().Expand(x => { throw ex; });
+            var xs = new[] { 2, 3 }.ToAsyncEnumerable().Expand(new Func<int, IAsyncEnumerable<int>>(x => { throw ex; }));
 
             var e = xs.GetAsyncEnumerator();
             AssertThrows(() => e.MoveNextAsync().Wait(WaitTimeoutMs), (Exception ex_) => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
@@ -3171,7 +3171,7 @@ namespace Tests
         [Fact]
         public void Expand3()
         {
-            var xs = new[] { 2, 3 }.ToAsyncEnumerable().Expand(x => null);
+            var xs = new[] { 2, 3 }.ToAsyncEnumerable().Expand(x => default(IAsyncEnumerable<int>));
 
             var e = xs.GetAsyncEnumerator();
             HasNext(e, 2);