瀏覽代碼

Adding IAsyncIListProvider<T> support for OrderBy.

Bart De Smet 6 年之前
父節點
當前提交
d904a174fa

+ 47 - 0
Ix.NET/Source/System.Linq.Async.Tests/System/Linq/Operators/OrderBy.cs

@@ -211,5 +211,52 @@ namespace Tests
 
             Assert.True(ress.SequenceEqual(resa.ToEnumerable()));
         }
+
+        [Fact]
+        public async Task OrderBy_Optimize_ToArray()
+        {
+            foreach (var seed in new[] { 1905, 1948, 1983 })
+            {
+                var rand = GetRandom(seed, 10_000);
+                var randAsync = rand.ToAsyncEnumerable();
+
+                var res = rand.OrderBy(x => x % 2).ThenBy(x => x % 3).ThenByDescending(x => x % 4);
+                var resAsync = randAsync.OrderBy(x => x % 2).ThenBy(x => x % 3).ThenByDescending(x => x % 4);
+
+                var lst = res.ToArray();
+                var lstAsync = await resAsync.ToArrayAsync();
+
+                Assert.True(lst.SequenceEqual(lstAsync));
+            }
+        }
+
+        [Fact]
+        public async Task OrderBy_Optimize_ToList()
+        {
+            foreach (var seed in new[] { 1905, 1948, 1983 })
+            {
+                var rand = GetRandom(seed, 10_000);
+                var randAsync = rand.ToAsyncEnumerable();
+
+                var res = rand.OrderBy(x => x % 2).ThenBy(x => x % 3).ThenByDescending(x => x % 4);
+                var resAsync = randAsync.OrderBy(x => x % 2).ThenBy(x => x % 3).ThenByDescending(x => x % 4);
+
+                var lst = res.ToList();
+                var lstAsync = await resAsync.ToListAsync();
+
+                Assert.True(lst.SequenceEqual(lstAsync));
+            }
+        }
+
+        private static IEnumerable<int> GetRandom(int seed, int count)
+        {
+            var rand = new Random(seed);
+
+            while (count > 0)
+            {
+                yield return rand.Next();
+                count--;
+            }
+        }
     }
 }

+ 73 - 3
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/OrderedAsyncEnumerable.cs

@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the Apache 2.0 License.
 // See the LICENSE file in the project root for more information. 
 
+using System.Collections;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Threading;
@@ -11,7 +12,7 @@ namespace System.Linq
 {
     // TODO: Add optimizations for First, Last, and ElementAt.
     
-    internal abstract class OrderedAsyncEnumerable<TElement> : AsyncIterator<TElement>, IOrderedAsyncEnumerable<TElement>
+    internal abstract class OrderedAsyncEnumerable<TElement> : AsyncIterator<TElement>, IOrderedAsyncEnumerable<TElement>, IAsyncIListProvider<TElement>
     {
         protected IAsyncEnumerable<TElement> _source;
         private TElement[] _buffer;
@@ -40,8 +41,6 @@ namespace System.Linq
                 case AsyncIteratorState.Allocated:
                     _buffer = await _source.ToArrayAsync(_cancellationToken).ConfigureAwait(false);
 
-                    // REVIEW: If we add selectors with CancellationToken support, we should feed the token to Sort.
-
                     AsyncEnumerableSorter<TElement> sorter = GetAsyncEnumerableSorter(next: null, _cancellationToken);
                     _indexes = await sorter.Sort(_buffer, _buffer.Length).ConfigureAwait(false);
                     _index = 0;
@@ -72,6 +71,77 @@ namespace System.Linq
         }
 
         internal abstract AsyncEnumerableSorter<TElement> GetAsyncEnumerableSorter(AsyncEnumerableSorter<TElement> next, CancellationToken cancellationToken);
+
+        public async ValueTask<TElement[]> ToArrayAsync(CancellationToken cancellationToken)
+        {
+            AsyncEnumerableHelpers.ArrayWithLength<TElement> elements = await AsyncEnumerableHelpers.ToArrayWithLength(_source, cancellationToken).ConfigureAwait(false);
+
+            int count = elements.Length;
+
+            if (count == 0)
+            {
+#if NO_ARRAY_EMPTY
+                return EmptyArray<TElement>.Value;
+#else
+                return Array.Empty<TElement>();
+#endif
+            }
+
+            TElement[] array = elements.Array;
+
+            int[] map = await SortedMap(array, count, cancellationToken).ConfigureAwait(false);
+
+            var result = new TElement[count];
+
+            for (int i = 0; i < result.Length; i++)
+            {
+                result[i] = array[map[i]];
+            }
+
+            return result;
+        }
+
+        public async ValueTask<List<TElement>> ToListAsync(CancellationToken cancellationToken)
+        {
+            AsyncEnumerableHelpers.ArrayWithLength<TElement> elements = await AsyncEnumerableHelpers.ToArrayWithLength(_source, cancellationToken).ConfigureAwait(false);
+
+            int count = elements.Length;
+
+            if (count == 0)
+            {
+                return new List<TElement>(capacity: 0);
+            }
+
+            TElement[] array = elements.Array;
+
+            int[] map = await SortedMap(array, count, cancellationToken).ConfigureAwait(false);
+
+            var result = new List<TElement>(count);
+
+            for (int i = 0; i < count; i++)
+            {
+                result.Add(array[map[i]]);
+            }
+
+            return result;
+        }
+
+        public async ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+        {
+            if (_source is IAsyncIListProvider<TElement> listProv)
+            {
+                int count = await listProv.GetCountAsync(onlyIfCheap, cancellationToken).ConfigureAwait(false);
+            }
+
+            return !onlyIfCheap || _source is ICollection<TElement> || _source is ICollection ? await _source.CountAsync(cancellationToken).ConfigureAwait(false) : -1;
+        }
+
+        private ValueTask<int[]> SortedMap(TElement[] elements, int count, CancellationToken cancellationToken)
+        {
+            AsyncEnumerableSorter<TElement> sorter = GetAsyncEnumerableSorter(next: null, cancellationToken);
+
+            return sorter.Sort(elements, count);
+        }
     }
 
     internal sealed class OrderedAsyncEnumerable<TElement, TKey> : OrderedAsyncEnumerable<TElement>