瀏覽代碼

Rewrite of OrderBy to properly support async.

Bart De Smet 7 年之前
父節點
當前提交
0bfbcd2dcb
共有 1 個文件被更改,包括 182 次插入97 次删除
  1. 182 97
      Ix.NET/Source/System.Linq.Async/System/Linq/Operators/OrderedAsyncEnumerable.cs

+ 182 - 97
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/OrderedAsyncEnumerable.cs

@@ -9,10 +9,14 @@ using System.Threading.Tasks;
 
 namespace System.Linq
 {
+    // TODO: Add optimizations for First, Last, and ElementAt.
+    
     internal abstract class OrderedAsyncEnumerable<TElement> : AsyncIterator<TElement>, IOrderedAsyncEnumerable<TElement>
     {
-        internal IOrderedEnumerable<TElement> _enumerable;
-        internal IAsyncEnumerable<TElement> _source;
+        protected IAsyncEnumerable<TElement> _source;
+        private TElement[] _buffer;
+        private int[] _indexes;
+        private int _index;
 
         IOrderedAsyncEnumerable<TElement> IOrderedAsyncEnumerable<TElement>.CreateOrderedEnumerable<TKey>(Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending)
         {
@@ -24,7 +28,43 @@ namespace System.Linq
             return new OrderedAsyncEnumerableWithTask<TElement, TKey>(_source, keySelector, comparer, descending, this);
         }
 
-        internal abstract Task Initialize(CancellationToken cancellationToken);
+        protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
+        {
+            switch (state)
+            {
+                case AsyncIteratorState.Allocated:
+                    _buffer = await _source.ToArray(cancellationToken).ConfigureAwait(false);
+
+                    var sorter = GetAsyncEnumerableSorter(next: null);
+                    _indexes = await sorter.Sort(_buffer, _buffer.Length).ConfigureAwait(false);
+                    _index = 0;
+
+                    state = AsyncIteratorState.Iterating;
+                    goto case AsyncIteratorState.Iterating;
+
+                case AsyncIteratorState.Iterating:
+                    if (_index < _buffer.Length)
+                    {
+                        current = _buffer[_indexes[_index++]];
+                        return true;
+                    }
+
+                    await DisposeAsync().ConfigureAwait(false);
+                    break;
+            }
+
+            return false;
+        }
+
+        public override async ValueTask DisposeAsync()
+        {
+            _buffer = null;
+            _indexes = null;
+
+            await base.DisposeAsync().ConfigureAwait(false);
+        }
+
+        internal abstract AsyncEnumerableSorter<TElement> GetAsyncEnumerableSorter(AsyncEnumerableSorter<TElement> next);
     }
 
     internal sealed class OrderedAsyncEnumerable<TElement, TKey> : OrderedAsyncEnumerable<TElement>
@@ -34,9 +74,6 @@ namespace System.Linq
         private readonly Func<TElement, TKey> _keySelector;
         private readonly OrderedAsyncEnumerable<TElement> _parent;
 
-        private IEnumerator<TElement> _enumerator;
-        private IAsyncEnumerator<TElement> _parentEnumerator;
-
         public OrderedAsyncEnumerable(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending, OrderedAsyncEnumerable<TElement> parent)
         {
             Debug.Assert(source != null);
@@ -55,62 +92,16 @@ namespace System.Linq
             return new OrderedAsyncEnumerable<TElement, TKey>(_source, _keySelector, _comparer, _descending, _parent);
         }
 
-        public override async ValueTask DisposeAsync()
+        internal override AsyncEnumerableSorter<TElement> GetAsyncEnumerableSorter(AsyncEnumerableSorter<TElement> next)
         {
-            if (_enumerator != null)
-            {
-                _enumerator.Dispose();
-                _enumerator = null;
-            }
+            var sorter = new SyncKeySelectorAsyncEnumerableSorter<TElement, TKey>(_keySelector, _comparer, _descending, next);
 
-            if (_parentEnumerator != null)
+            if (_parent != null)
             {
-                await _parentEnumerator.DisposeAsync().ConfigureAwait(false);
-                _parentEnumerator = null;
+                return _parent.GetAsyncEnumerableSorter(sorter);
             }
 
-            await base.DisposeAsync().ConfigureAwait(false);
-        }
-
-        protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
-        {
-            switch (state)
-            {
-                case AsyncIteratorState.Allocated:
-
-                    await Initialize(cancellationToken).ConfigureAwait(false);
-
-                    _enumerator = _enumerable.GetEnumerator();
-                    state = AsyncIteratorState.Iterating;
-                    goto case AsyncIteratorState.Iterating;
-
-                case AsyncIteratorState.Iterating:
-                    if (_enumerator.MoveNext())
-                    {
-                        current = _enumerator.Current;
-                        return true;
-                    }
-
-                    await DisposeAsync().ConfigureAwait(false);
-                    break;
-            }
-
-            return false;
-        }
-
-        internal override async Task Initialize(CancellationToken cancellationToken)
-        {
-            if (_parent == null)
-            {
-                var buffer = await _source.ToList(cancellationToken).ConfigureAwait(false);
-                _enumerable = (!_descending ? buffer.OrderBy(_keySelector, _comparer) : buffer.OrderByDescending(_keySelector, _comparer));
-            }
-            else
-            {
-                _parentEnumerator = _parent.GetAsyncEnumerator(cancellationToken);
-                await _parent.Initialize(cancellationToken).ConfigureAwait(false);
-                _enumerable = _parent._enumerable.CreateOrderedEnumerable(_keySelector, _comparer, _descending);
-            }
+            return sorter;
         }
     }
 
@@ -121,9 +112,6 @@ namespace System.Linq
         private readonly Func<TElement, Task<TKey>> _keySelector;
         private readonly OrderedAsyncEnumerable<TElement> _parent;
 
-        private IEnumerator<TElement> _enumerator;
-        private IAsyncEnumerator<TElement> _parentEnumerator;
-
         public OrderedAsyncEnumerableWithTask(IAsyncEnumerable<TElement> source, Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending, OrderedAsyncEnumerable<TElement> parent)
         {
             Debug.Assert(source != null);
@@ -142,82 +130,179 @@ namespace System.Linq
             return new OrderedAsyncEnumerableWithTask<TElement, TKey>(_source, _keySelector, _comparer, _descending, _parent);
         }
 
-        public override async ValueTask DisposeAsync()
+        internal override AsyncEnumerableSorter<TElement> GetAsyncEnumerableSorter(AsyncEnumerableSorter<TElement> next)
         {
-            if (_enumerator != null)
+            var sorter = new AsyncKeySelectorAsyncEnumerableSorter<TElement, TKey>(_keySelector, _comparer, _descending, next);
+
+            if (_parent != null)
             {
-                _enumerator.Dispose();
-                _enumerator = null;
+                return _parent.GetAsyncEnumerableSorter(sorter);
             }
 
-            if (_parentEnumerator != null)
+            return sorter;
+        }
+    }
+
+    internal abstract class AsyncEnumerableSorter<TElement>
+    {
+        internal abstract ValueTask ComputeKeys(TElement[] elements, int count);
+
+        internal abstract int CompareKeys(int index1, int index2);
+
+        public async ValueTask<int[]> Sort(TElement[] elements, int count)
+        {
+            await ComputeKeys(elements, count).ConfigureAwait(false);
+
+            var map = new int[count];
+
+            for (var i = 0; i < count; i++)
             {
-                await _parentEnumerator.DisposeAsync().ConfigureAwait(false);
-                _parentEnumerator = null;
+                map[i] = i;
             }
 
-            await base.DisposeAsync().ConfigureAwait(false);
+            QuickSort(map, 0, count - 1);
+
+            return map;
         }
 
-        protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
+        private void QuickSort(int[] map, int left, int right)
         {
-            switch (state)
+            do
             {
-                case AsyncIteratorState.Allocated:
+                var i = left;
+                var j = right;
+                var x = map[i + (j - i >> 1)];
 
-                    await Initialize(cancellationToken).ConfigureAwait(false);
+                do
+                {
+                    while (i < map.Length && CompareKeys(x, map[i]) > 0)
+                    {
+                        i++;
+                    }
 
-                    _enumerator = _enumerable.GetEnumerator();
-                    state = AsyncIteratorState.Iterating;
-                    goto case AsyncIteratorState.Iterating;
+                    while (j >= 0 && CompareKeys(x, map[j]) < 0)
+                    {
+                        j--;
+                    }
 
-                case AsyncIteratorState.Iterating:
-                    if (_enumerator.MoveNext())
+                    if (i > j)
                     {
-                        current = _enumerator.Current;
-                        return true;
+                        break;
                     }
 
-                    await DisposeAsync().ConfigureAwait(false);
-                    break;
+                    if (i < j)
+                    {
+                        var temp = map[i];
+                        map[i] = map[j];
+                        map[j] = temp;
+                    }
+
+                    i++;
+                    j--;
+                }
+                while (i <= j);
+
+                if (j - left <= right - i)
+                {
+                    if (left < j)
+                    {
+                        QuickSort(map, left, j);
+                    }
+
+                    left = i;
+                }
+                else
+                {
+                    if (i < right)
+                    {
+                        QuickSort(map, i, right);
+                    }
+
+                    right = j;
+                }
             }
+            while (left < right);
+        }
+    }
 
-            return false;
+    internal abstract class AsyncEnumerableSorterBase<TElement, TKey> : AsyncEnumerableSorter<TElement>
+    {
+        private readonly IComparer<TKey> _comparer;
+        private readonly bool _descending;
+        protected readonly AsyncEnumerableSorter<TElement> _next;
+        protected TKey[] _keys;
+
+        public AsyncEnumerableSorterBase(IComparer<TKey> comparer, bool descending, AsyncEnumerableSorter<TElement> next)
+        {
+            _comparer = comparer;
+            _descending = descending;
+            _next = next;
         }
 
-        internal override async Task Initialize(CancellationToken cancellationToken)
+        internal override int CompareKeys(int index1, int index2)
         {
-            if (_parent == null)
+            var c = _comparer.Compare(_keys[index1], _keys[index2]);
+
+            if (c == 0)
             {
-                var buffer = await _source.ToList(cancellationToken).ConfigureAwait(false);
-                _enumerable = (!_descending ? buffer.OrderByAsync(_keySelector, _comparer) : buffer.OrderByDescendingAsync(_keySelector, _comparer));
+                return _next == null ? index1 - index2 : _next.CompareKeys(index1, index2);
             }
             else
             {
-                _parentEnumerator = _parent.GetAsyncEnumerator(cancellationToken);
-                await _parent.Initialize(cancellationToken).ConfigureAwait(false);
-                _enumerable = _parent._enumerable.CreateOrderedEnumerableAsync(_keySelector, _comparer, _descending);
+                return (_descending != (c > 0)) ? 1 : -1;
             }
         }
     }
 
-    internal static class EnumerableSortingExtensions
+    internal sealed class SyncKeySelectorAsyncEnumerableSorter<TElement, TKey> : AsyncEnumerableSorterBase<TElement, TKey>
     {
-        // TODO: Implement async sorting.
+        private readonly Func<TElement, TKey> _keySelector;
 
-        public static IOrderedEnumerable<TSource> OrderByAsync<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
+        public SyncKeySelectorAsyncEnumerableSorter(Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending, AsyncEnumerableSorter<TElement> next)
+            : base(comparer, descending, next)
         {
-            return source.OrderBy(key => keySelector(key).GetAwaiter().GetResult(), comparer);
+            _keySelector = keySelector;
         }
 
-        public static IOrderedEnumerable<TSource> OrderByDescendingAsync<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
+        internal override async ValueTask ComputeKeys(TElement[] elements, int count)
         {
-            return source.OrderByDescending(key => keySelector(key).GetAwaiter().GetResult(), comparer);
+            _keys = new TKey[count];
+
+            for (var i = 0; i < count; i++)
+            {
+                _keys[i] = _keySelector(elements[i]);
+            }
+
+            if (_next != null)
+            {
+                await _next.ComputeKeys(elements, count).ConfigureAwait(false);
+            }
+        }
+    }
+
+    internal sealed class AsyncKeySelectorAsyncEnumerableSorter<TElement, TKey> : AsyncEnumerableSorterBase<TElement, TKey>
+    {
+        private readonly Func<TElement, Task<TKey>> _keySelector;
+
+        public AsyncKeySelectorAsyncEnumerableSorter(Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending, AsyncEnumerableSorter<TElement> next)
+            : base(comparer, descending, next)
+        {
+            _keySelector = keySelector;
         }
 
-        public static IOrderedEnumerable<TSource> CreateOrderedEnumerableAsync<TSource, TKey>(this IOrderedEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending)
+        internal override async ValueTask ComputeKeys(TElement[] elements, int count)
         {
-            return source.CreateOrderedEnumerable(key => keySelector(key).GetAwaiter().GetResult(), comparer, descending);
+            _keys = new TKey[count];
+
+            for (var i = 0; i < count; i++)
+            {
+                _keys[i] = await _keySelector(elements[i]).ConfigureAwait(false);
+            }
+
+            if (_next != null)
+            {
+                await _next.ComputeKeys(elements, count).ConfigureAwait(false);
+            }
         }
     }
 }