Ver código fonte

Optimize SelectMany

Oren Novotny 9 anos atrás
pai
commit
77355d247b
1 arquivos alterados com 366 adições e 99 exclusões
  1. 366 99
      Ix.NET/Source/System.Interactive.Async/SelectMany.cs

+ 366 - 99
Ix.NET/Source/System.Interactive.Async/SelectMany.cs

@@ -30,137 +30,404 @@ namespace System.Linq
             if (selector == null)
                 throw new ArgumentNullException(nameof(selector));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = source.GetEnumerator();
-                    var ie = default(IAsyncEnumerator<TResult>);
-
-                    var innerDisposable = new AssignableDisposable();
+            return new SelectManyAsyncIterator<TSource, TResult>(source, selector);
+        }
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, innerDisposable, e);
+        public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TResult>> selector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (selector == null)
+                throw new ArgumentNullException(nameof(selector));
 
-                    var inner = default(Func<CancellationToken, Task<bool>>);
-                    var outer = default(Func<CancellationToken, Task<bool>>);
+            return new SelectManyWithIndexAsyncIterator<TSource, TResult>(source, selector);
+        }
 
-                    inner = async ct =>
-                            {
-                                if (await ie.MoveNext(ct)
-                                            .ConfigureAwait(false))
-                                {
-                                    return true;
-                                }
-                                innerDisposable.Disposable = null;
-                                return await outer(ct)
-                                           .ConfigureAwait(false);
-                            };
-
-                    outer = async ct =>
-                            {
-                                if (await e.MoveNext(ct)
-                                           .ConfigureAwait(false))
-                                {
-                                    var enumerable = selector(e.Current);
-                                    ie = enumerable.GetEnumerator();
-                                    innerDisposable.Disposable = ie;
+        public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TCollection>> selector, Func<TSource, TCollection, TResult> resultSelector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (selector == null)
+                throw new ArgumentNullException(nameof(selector));
+            if (resultSelector == null)
+                throw new ArgumentNullException(nameof(resultSelector));
 
-                                    return await inner(ct)
-                                               .ConfigureAwait(false);
-                                }
-                                return false;
-                            };
-
-                    return CreateEnumerator(ct => ie == null ? outer(cts.Token) : inner(cts.Token),
-                                            () => ie.Current,
-                                            d.Dispose,
-                                            e
-                    );
-                });
+            return new SelectManyAsyncIterator<TSource, TCollection, TResult>(source, selector, resultSelector);
         }
 
-        public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TResult>> selector)
+        public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TCollection>> selector, Func<TSource, TCollection, TResult> resultSelector)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
             if (selector == null)
                 throw new ArgumentNullException(nameof(selector));
+            if (resultSelector == null)
+                throw new ArgumentNullException(nameof(resultSelector));
 
-            return CreateEnumerable(
-                () =>
+            return new SelectManyWithIndexAsyncIterator<TSource, TCollection, TResult>(source, selector, resultSelector);
+        }
+
+        private sealed class SelectManyAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private const int State_Source = 1;
+            private const int State_Result = 2;
+            private readonly Func<TSource, IAsyncEnumerable<TResult>> selector;
+            private readonly IAsyncEnumerable<TSource> source;
+            private int mode;
+            private IAsyncEnumerator<TResult> resultEnumerator;
+
+            private IAsyncEnumerator<TSource> sourceEnumerator;
+
+            public SelectManyAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TResult>> selector)
+            {
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectManyAsyncIterator<TSource, TResult>(source, selector);
+            }
+
+            public override void Dispose()
+            {
+                if (sourceEnumerator != null)
                 {
-                    var e = source.GetEnumerator();
-                    var ie = default(IAsyncEnumerator<TResult>);
+                    sourceEnumerator.Dispose();
+                    sourceEnumerator = null;
+                }
 
-                    var index = 0;
+                if (resultEnumerator != null)
+                {
+                    resultEnumerator.Dispose();
+                    resultEnumerator = null;
+                }
 
-                    var innerDisposable = new AssignableDisposable();
+                base.Dispose();
+            }
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, innerDisposable, e);
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        sourceEnumerator = source.GetEnumerator();
+                        mode = State_Source;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
 
-                    var inner = default(Func<CancellationToken, Task<bool>>);
-                    var outer = default(Func<CancellationToken, Task<bool>>);
+                    case AsyncIteratorState.Iterating:
+                        switch (mode)
+                        {
+                            case State_Source:
+                                if (await sourceEnumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
+                                {
+                                    resultEnumerator?.Dispose();
+                                    resultEnumerator = selector(sourceEnumerator.Current)
+                                        .GetEnumerator();
 
-                    inner = async ct =>
-                            {
-                                if (await ie.MoveNext(ct)
-                                            .ConfigureAwait(false))
+                                    mode = State_Result;
+                                    goto case State_Result;
+                                }
+                                break;
+
+                            case State_Result:
+                                if (await resultEnumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
                                 {
+                                    current = resultEnumerator.Current;
                                     return true;
                                 }
-                                innerDisposable.Disposable = null;
-                                return await outer(ct)
-                                           .ConfigureAwait(false);
-                            };
-
-                    outer = async ct =>
-                            {
-                                if (await e.MoveNext(ct)
-                                           .ConfigureAwait(false))
+
+                                mode = State_Source;
+                                goto case State_Source; // loop
+                        }
+
+                        break;
+                }
+
+                Dispose();
+                return false;
+            }
+        }
+
+        private sealed class SelectManyAsyncIterator<TSource, TCollection, TResult> : AsyncIterator<TResult>
+        {
+            private const int State_Source = 1;
+            private const int State_Result = 2;
+            private readonly Func<TSource, IAsyncEnumerable<TCollection>> collectionSelector;
+            private readonly Func<TSource, TCollection, TResult> resultSelector;
+            private readonly IAsyncEnumerable<TSource> source;
+            private TSource currentSource;
+            private int mode;
+            private IAsyncEnumerator<TCollection> resultEnumerator;
+            private IAsyncEnumerator<TSource> sourceEnumerator;
+
+            public SelectManyAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
+            {
+                this.source = source;
+                this.collectionSelector = collectionSelector;
+                this.resultSelector = resultSelector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectManyAsyncIterator<TSource, TCollection, TResult>(source, collectionSelector, resultSelector);
+            }
+
+            public override void Dispose()
+            {
+                if (sourceEnumerator != null)
+                {
+                    sourceEnumerator.Dispose();
+                    sourceEnumerator = null;
+                }
+
+                if (resultEnumerator != null)
+                {
+                    resultEnumerator.Dispose();
+                    resultEnumerator = null;
+                }
+
+                currentSource = default(TSource);
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        sourceEnumerator = source.GetEnumerator();
+                        mode = State_Source;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        switch (mode)
+                        {
+                            case State_Source:
+                                if (await sourceEnumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
                                 {
-                                    var enumerable = selector(e.Current, checked(index++));
-                                    ie = enumerable.GetEnumerator();
-                                    innerDisposable.Disposable = ie;
+                                    resultEnumerator?.Dispose();
+                                    currentSource = sourceEnumerator.Current;
+                                    resultEnumerator = collectionSelector(currentSource)
+                                        .GetEnumerator();
 
-                                    return await inner(ct)
-                                               .ConfigureAwait(false);
+                                    mode = State_Result;
+                                    goto case State_Result;
+                                }
+                                break;
+
+                            case State_Result:
+                                if (await resultEnumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
+                                {
+                                    current = resultSelector(currentSource, resultEnumerator.Current);
+                                    return true;
                                 }
-                                return false;
-                            };
-
-                    return CreateEnumerator(ct => ie == null ? outer(cts.Token) : inner(cts.Token),
-                                            () => ie.Current,
-                                            d.Dispose,
-                                            e
-                    );
-                });
+
+                                mode = State_Source;
+                                goto case State_Source; // loop
+                        }
+
+                        break;
+                }
+
+                Dispose();
+                return false;
+            }
         }
 
-        public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TCollection>> selector, Func<TSource, TCollection, TResult> resultSelector)
+        private sealed class SelectManyWithIndexAsyncIterator<TSource, TCollection, TResult> : AsyncIterator<TResult>
         {
-            if (source == null)
-                throw new ArgumentNullException(nameof(source));
-            if (selector == null)
-                throw new ArgumentNullException(nameof(selector));
-            if (resultSelector == null)
-                throw new ArgumentNullException(nameof(resultSelector));
+            private const int State_Source = 1;
+            private const int State_Result = 2;
+            private readonly Func<TSource, int, IAsyncEnumerable<TCollection>> collectionSelector;
+            private readonly Func<TSource, TCollection, TResult> resultSelector;
+            private readonly IAsyncEnumerable<TSource> source;
+            private TSource currentSource;
+            private int index;
+            private int mode;
+            private IAsyncEnumerator<TCollection> resultEnumerator;
+            private IAsyncEnumerator<TSource> sourceEnumerator;
+
+            public SelectManyWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
+            {
+                this.source = source;
+                this.collectionSelector = collectionSelector;
+                this.resultSelector = resultSelector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectManyWithIndexAsyncIterator<TSource, TCollection, TResult>(source, collectionSelector, resultSelector);
+            }
+
+            public override void Dispose()
+            {
+                if (sourceEnumerator != null)
+                {
+                    sourceEnumerator.Dispose();
+                    sourceEnumerator = null;
+                }
 
-            return source.SelectMany(x => selector(x)
-                                         .Select(y => resultSelector(x, y)));
+                if (resultEnumerator != null)
+                {
+                    resultEnumerator.Dispose();
+                    resultEnumerator = null;
+                }
+
+                currentSource = default(TSource);
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        sourceEnumerator = source.GetEnumerator();
+                        index = -1;
+                        mode = State_Source;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        switch (mode)
+                        {
+                            case State_Source:
+                                if (await sourceEnumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
+                                {
+                                    resultEnumerator?.Dispose();
+                                    currentSource = sourceEnumerator.Current;
+
+                                    checked
+                                    {
+                                        index++;
+                                    }
+
+                                    resultEnumerator = collectionSelector(currentSource, index)
+                                        .GetEnumerator();
+
+                                    mode = State_Result;
+                                    goto case State_Result;
+                                }
+                                break;
+
+                            case State_Result:
+                                if (await resultEnumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
+                                {
+                                    current = resultSelector(currentSource, resultEnumerator.Current);
+                                    return true;
+                                }
+
+                                mode = State_Source;
+                                goto case State_Source; // loop
+                        }
+
+                        break;
+                }
+
+                Dispose();
+                return false;
+            }
         }
 
-        public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TCollection>> selector, Func<TSource, TCollection, TResult> resultSelector)
+        private sealed class SelectManyWithIndexAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
         {
-            if (source == null)
-                throw new ArgumentNullException(nameof(source));
-            if (selector == null)
-                throw new ArgumentNullException(nameof(selector));
-            if (resultSelector == null)
-                throw new ArgumentNullException(nameof(resultSelector));
+            private const int State_Source = 1;
+            private const int State_Result = 2;
+            private readonly Func<TSource, int, IAsyncEnumerable<TResult>> selector;
+            private readonly IAsyncEnumerable<TSource> source;
+            private int index;
+            private int mode;
+            private IAsyncEnumerator<TResult> resultEnumerator;
+            private IAsyncEnumerator<TSource> sourceEnumerator;
+
+            public SelectManyWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TResult>> selector)
+            {
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectManyWithIndexAsyncIterator<TSource, TResult>(source, selector);
+            }
+
+            public override void Dispose()
+            {
+                if (sourceEnumerator != null)
+                {
+                    sourceEnumerator.Dispose();
+                    sourceEnumerator = null;
+                }
+
+                if (resultEnumerator != null)
+                {
+                    resultEnumerator.Dispose();
+                    resultEnumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        sourceEnumerator = source.GetEnumerator();
+                        index = -1;
+                        mode = State_Source;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        switch (mode)
+                        {
+                            case State_Source:
+                                if (await sourceEnumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
+                                {
+                                    resultEnumerator?.Dispose();
+                                    checked
+                                    {
+                                        index++;
+                                    }
+                                    resultEnumerator = selector(sourceEnumerator.Current, index)
+                                        .GetEnumerator();
+
+                                    mode = State_Result;
+                                    goto case State_Result;
+                                }
+                                break;
+
+                            case State_Result:
+                                if (await resultEnumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
+                                {
+                                    current = resultEnumerator.Current;
+                                    return true;
+                                }
+
+                                mode = State_Source;
+                                goto case State_Source; // loop
+                        }
+
+                        break;
+                }
 
-            return source.SelectMany((x, i) => selector(x, i)
-                                         .Select(y => resultSelector(x, y)));
+                Dispose();
+                return false;
+            }
         }
     }
 }