ソースを参照

Adding first async variant of SelectMany.

Bart De Smet 8 年 前
コミット
5b41d51364

+ 103 - 3
Ix.NET/Source/System.Interactive.Async/SelectMany.cs

@@ -20,7 +20,6 @@ namespace System.Linq
             return source.SelectMany(_ => other);
         }
 
-
         public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TResult>> selector)
         {
             if (source == null)
@@ -31,6 +30,16 @@ namespace System.Linq
             return new SelectManyAsyncIterator<TSource, TResult>(source, selector);
         }
 
+        public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<IAsyncEnumerable<TResult>>> selector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (selector == null)
+                throw new ArgumentNullException(nameof(selector));
+
+            return new SelectManyAsyncIteratorWithTask<TSource, TResult>(source, selector);
+        }
+
         public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TResult>> selector)
         {
             if (source == null)
@@ -129,8 +138,99 @@ namespace System.Linq
                                         await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                                     }
 
-                                    resultEnumerator = selector(sourceEnumerator.Current)
-                                        .GetAsyncEnumerator();
+                                    var inner = selector(sourceEnumerator.Current);
+                                    resultEnumerator = inner.GetAsyncEnumerator();
+
+                                    mode = State_Result;
+                                    goto case State_Result;
+                                }
+                                break;
+
+                            case State_Result:
+                                if (await resultEnumerator.MoveNextAsync().ConfigureAwait(false))
+                                {
+                                    current = resultEnumerator.Current;
+                                    return true;
+                                }
+
+                                mode = State_Source;
+                                goto case State_Source; // loop
+                        }
+
+                        break;
+                }
+
+                await DisposeAsync().ConfigureAwait(false);
+                return false;
+            }
+        }
+
+        private sealed class SelectManyAsyncIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private const int State_Source = 1;
+            private const int State_Result = 2;
+
+            private readonly Func<TSource, Task<IAsyncEnumerable<TResult>>> selector;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private int mode;
+            private IAsyncEnumerator<TResult> resultEnumerator;
+            private IAsyncEnumerator<TSource> sourceEnumerator;
+
+            public SelectManyAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<IAsyncEnumerable<TResult>>> selector)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectManyAsyncIteratorWithTask<TSource, TResult>(source, selector);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (sourceEnumerator != null)
+                {
+                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
+                    sourceEnumerator = null;
+                }
+
+                if (resultEnumerator != null)
+                {
+                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
+                    resultEnumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        sourceEnumerator = source.GetAsyncEnumerator();
+                        mode = State_Source;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        switch (mode)
+                        {
+                            case State_Source:
+                                if (await sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
+                                {
+                                    if (resultEnumerator != null)
+                                    {
+                                        await resultEnumerator.DisposeAsync().ConfigureAwait(false);
+                                    }
+
+                                    var inner = await selector(sourceEnumerator.Current).ConfigureAwait(false);
+                                    resultEnumerator = inner.GetAsyncEnumerator();
 
                                     mode = State_Result;
                                     goto case State_Result;

+ 4 - 4
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -268,13 +268,13 @@ namespace Tests
         [Fact]
         public void SelectMany_Null()
         {
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int>(null, x => null));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int>(null, (x, i) => null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int>(null, default(Func<int, IAsyncEnumerable<int>>)));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int>(null, default(Func<int, int, IAsyncEnumerable<int>>)));
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int>(AsyncEnumerable.Return(42), default(Func<int, IAsyncEnumerable<int>>)));
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int>(AsyncEnumerable.Return(42), default(Func<int, int, IAsyncEnumerable<int>>)));
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int, int>(null, x => null, (x, y) => x));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int, int>(null, (x, i) => null, (x, y) => x));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int, int>(null, default(Func<int, IAsyncEnumerable<int>>), (x, y) => x));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int, int>(null, default(Func<int, int, IAsyncEnumerable<int>>), (x, y) => x));
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int, int>(AsyncEnumerable.Return(42), default(Func<int, IAsyncEnumerable<int>>), (x, y) => x));
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int, int>(AsyncEnumerable.Return(42), default(Func<int, int, IAsyncEnumerable<int>>), (x, y) => x));
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SelectMany<int, int, int>(AsyncEnumerable.Return(42), x => null, default(Func<int, int, int>)));