瀏覽代碼

Optimize OrderBy. We take leverage the non-async versions after the lazy tolist as we cannot order w/o the full sequence

Oren Novotny 9 年之前
父節點
當前提交
6e9e63754b

+ 2 - 117
Ix.NET/Source/System.Interactive.Async/OrderBy.cs

@@ -21,29 +21,7 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException(nameof(comparer));
 
-            return new OrderedAsyncEnumerable<TSource, TKey>(
-                CreateEnumerable(() =>
-                       {
-                           var current = default(IEnumerable<TSource>);
-
-                           return CreateEnumerator(
-                               async ct =>
-                               {
-                                   if (current == null)
-                                   {
-                                       current = await source.ToList(ct)
-                                                             .ConfigureAwait(false);
-                                       return true;
-                                   }
-                                   return false;
-                               },
-                               () => current,
-                               () => { }
-                           );
-                       }),
-                keySelector,
-                comparer
-            );
+            return new OrderedAsyncEnumerable<TSource, TKey>(source, keySelector, comparer, false, null);
         }
 
         public static IOrderedAsyncEnumerable<TSource> OrderBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
@@ -65,7 +43,7 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException(nameof(comparer));
 
-            return source.OrderBy(keySelector, new ReverseComparer<TKey>(comparer));
+            return new OrderedAsyncEnumerable<TSource, TKey>(source, keySelector, comparer, true, null);
         }
 
         public static IOrderedAsyncEnumerable<TSource> OrderByDescending<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
@@ -121,98 +99,5 @@ namespace System.Linq
 
             return source.CreateOrderedEnumerable(keySelector, comparer, true);
         }
-
-        private class OrderedAsyncEnumerable<T, K> : IOrderedAsyncEnumerable<T>
-        {
-            private readonly IComparer<K> comparer;
-            private readonly IAsyncEnumerable<IEnumerable<T>> equivalenceClasses;
-            private readonly Func<T, K> keySelector;
-
-            public OrderedAsyncEnumerable(IAsyncEnumerable<IEnumerable<T>> equivalenceClasses, Func<T, K> keySelector, IComparer<K> comparer)
-            {
-                this.equivalenceClasses = equivalenceClasses;
-                this.keySelector = keySelector;
-                this.comparer = comparer;
-            }
-
-            public IOrderedAsyncEnumerable<T> CreateOrderedEnumerable<TKey>(Func<T, TKey> keySelector, IComparer<TKey> comparer, bool descending)
-            {
-                if (descending)
-                    comparer = new ReverseComparer<TKey>(comparer);
-
-                return new OrderedAsyncEnumerable<T, TKey>(Classes(), keySelector, comparer);
-            }
-
-            public IAsyncEnumerator<T> GetEnumerator()
-            {
-                return Classes()
-                    .SelectMany(x => x.ToAsyncEnumerable())
-                    .GetEnumerator();
-            }
-
-            private IAsyncEnumerable<IEnumerable<T>> Classes()
-            {
-                return CreateEnumerable(() =>
-                              {
-                                  var e = equivalenceClasses.GetEnumerator();
-                                  var list = new List<IEnumerable<T>>();
-                                  var e1 = default(IEnumerator<IEnumerable<T>>);
-
-                                  var cts = new CancellationTokenDisposable();
-                                  var d1 = new AssignableDisposable();
-                                  var d = Disposable.Create(cts, e, d1);
-
-                                  var f = default(Func<CancellationToken, Task<bool>>);
-
-                                  f = async ct =>
-                                      {
-                                          if (await e.MoveNext(ct)
-                                                     .ConfigureAwait(false))
-                                          {
-                                              list.AddRange(e.Current.OrderBy(keySelector, comparer)
-                                                             .GroupUntil(keySelector, x => x, comparer));
-                                              return await f(ct)
-                                                         .ConfigureAwait(false);
-                                          }
-                                          e.Dispose();
-
-                                          e1 = list.GetEnumerator();
-                                          d1.Disposable = e1;
-
-                                          return e1.MoveNext();
-                                      };
-
-                                  return CreateEnumerator(
-                                      async ct =>
-                                      {
-                                          if (e1 != null)
-                                          {
-                                              return e1.MoveNext();
-                                          }
-                                          return await f(cts.Token)
-                                                     .ConfigureAwait(false);
-                                      },
-                                      () => e1.Current,
-                                      d.Dispose,
-                                      e
-                                  );
-                              });
-            }
-        }
-
-        private class ReverseComparer<T> : IComparer<T>
-        {
-            private readonly IComparer<T> comparer;
-
-            public ReverseComparer(IComparer<T> comparer)
-            {
-                this.comparer = comparer;
-            }
-
-            public int Compare(T x, T y)
-            {
-                return -comparer.Compare(x, y);
-            }
-        }
     }
 }

+ 117 - 0
Ix.NET/Source/System.Interactive.Async/OrderedAsyncEnumerable.cs

@@ -0,0 +1,117 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Linq
+{
+    internal abstract class OrderedAsyncEnumerable<TElement> : AsyncEnumerable.AsyncIterator<TElement>, IOrderedAsyncEnumerable<TElement>
+    {
+        internal IOrderedEnumerable<TElement> enumerable;
+        internal IAsyncEnumerable<TElement> source;
+
+        IOrderedAsyncEnumerable<TElement> IOrderedAsyncEnumerable<TElement>.CreateOrderedEnumerable<TKey>(Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending)
+        {
+            return new OrderedAsyncEnumerable<TElement, TKey>(source, keySelector, comparer, descending, this);
+        }
+
+        internal abstract Task Initialize(CancellationToken cancellationToken);
+    }
+
+    internal sealed class OrderedAsyncEnumerable<TElement, TKey> : OrderedAsyncEnumerable<TElement>
+    {
+        private readonly IComparer<TKey> comparer;
+        private readonly bool descending;
+        private readonly Func<TElement, TKey> keySelector;
+
+
+        private IEnumerator<TElement> enumerator;
+
+        private readonly OrderedAsyncEnumerable<TElement> parent;
+        private IAsyncEnumerator<TElement> parentEnumerator;
+
+
+        public OrderedAsyncEnumerable(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending, OrderedAsyncEnumerable<TElement> parent)
+        {
+            if (source == null) throw new ArgumentNullException(nameof(source));
+            if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));
+            this.source = source;
+            this.keySelector = keySelector;
+            this.comparer = comparer ?? Comparer<TKey>.Default;
+            this.descending = descending;
+            this.parent = parent;
+        }
+
+        public override AsyncEnumerable.AsyncIterator<TElement> Clone()
+        {
+            return new OrderedAsyncEnumerable<TElement, TKey>(source, keySelector, comparer, descending, parent);
+        }
+
+
+        public override void Dispose()
+        {
+            if (enumerator != null)
+            {
+                enumerator.Dispose();
+                enumerator = null;
+            }
+
+            if (parentEnumerator != null)
+            {
+                parentEnumerator.Dispose();
+                parentEnumerator = null;
+            }
+            base.Dispose();
+        }
+
+
+        protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+        {
+            switch (state)
+            {
+                case AsyncEnumerable.AsyncIteratorState.Allocated:
+
+                    await Initialize(cancellationToken)
+                        .ConfigureAwait(false);
+
+                    enumerator = enumerable.GetEnumerator();
+                    state = AsyncEnumerable.AsyncIteratorState.Iterating;
+                    goto case AsyncEnumerable.AsyncIteratorState.Iterating;
+
+                case AsyncEnumerable.AsyncIteratorState.Iterating:
+                    if (enumerator.MoveNext())
+                    {
+                        current = enumerator.Current;
+                        return true;
+                    }
+
+                    Dispose();
+                    break;
+            }
+
+            return false;
+        }
+
+        internal override async Task Initialize(CancellationToken cancellationToken)
+        {
+            if (parent == null)
+            {
+                var buffer = await source.ToList(cancellationToken)
+                                         .ConfigureAwait(false);
+                enumerable = (!@descending ? buffer.OrderBy(keySelector, comparer) : buffer.OrderByDescending(keySelector, comparer));
+            }
+            else
+            {
+                parentEnumerator = parent.GetEnumerator();
+                await parent.Initialize(cancellationToken)
+                            .ConfigureAwait(false);
+                enumerable = parent.enumerable.CreateOrderedEnumerable(keySelector, comparer, @descending);
+            }
+        }
+    }
+}

+ 0 - 2
Ix.NET/Source/Tests/AsyncTests.Bugs.cs

@@ -256,8 +256,6 @@ namespace Tests
         /// </summary>
         private sealed class CancellationTestEnumerable<T> : IEnumerable<T>
         {
-            private readonly CancellationToken cancellationToken;
-
             public CancellationTestEnumerable()
             {
             }