Browse Source

Initial work on async variants for OrderBy.

Bart De Smet 8 years ago
parent
commit
e280f57f0f

+ 2 - 0
Ix.NET/Source/System.Interactive.Async/IOrderedAsyncEnumerable.cs

@@ -3,11 +3,13 @@
 // See the LICENSE file in the project root for more information. 
 
 using System.Collections.Generic;
+using System.Threading.Tasks;
 
 namespace System.Linq
 {
     public interface IOrderedAsyncEnumerable<out TElement> : IAsyncEnumerable<TElement>
     {
         IOrderedAsyncEnumerable<TElement> CreateOrderedEnumerable<TKey>(Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending);
+        IOrderedAsyncEnumerable<TElement> CreateOrderedEnumerable<TKey>(Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending);
     }
 }

+ 88 - 0
Ix.NET/Source/System.Interactive.Async/OrderBy.cs

@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information. 
 
 using System.Collections.Generic;
+using System.Threading.Tasks;
 
 namespace System.Linq
 {
@@ -18,6 +19,16 @@ namespace System.Linq
             return source.OrderBy(keySelector, Comparer<TKey>.Default);
         }
 
+        public static IOrderedAsyncEnumerable<TSource> OrderBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+
+            return source.OrderBy(keySelector, Comparer<TKey>.Default);
+        }
+
         public static IOrderedAsyncEnumerable<TSource> OrderBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
         {
             if (source == null)
@@ -29,6 +40,17 @@ namespace System.Linq
 
             return new OrderedAsyncEnumerable<TSource, TKey>(source, keySelector, comparer, descending: false, parent: null);
         }
+        public static IOrderedAsyncEnumerable<TSource> OrderBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+            if (comparer == null)
+                throw new ArgumentNullException(nameof(comparer));
+
+            return new OrderedAsyncEnumerableWithTask<TSource, TKey>(source, keySelector, comparer, descending: false, parent: null);
+        }
 
         public static IOrderedAsyncEnumerable<TSource> OrderByDescending<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
         {
@@ -40,6 +62,16 @@ namespace System.Linq
             return source.OrderByDescending(keySelector, Comparer<TKey>.Default);
         }
 
+        public static IOrderedAsyncEnumerable<TSource> OrderByDescending<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+
+            return source.OrderByDescending(keySelector, Comparer<TKey>.Default);
+        }
+
         public static IOrderedAsyncEnumerable<TSource> OrderByDescending<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
         {
             if (source == null)
@@ -52,6 +84,18 @@ namespace System.Linq
             return new OrderedAsyncEnumerable<TSource, TKey>(source, keySelector, comparer, descending: true, parent: null);
         }
 
+        public static IOrderedAsyncEnumerable<TSource> OrderByDescending<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+            if (comparer == null)
+                throw new ArgumentNullException(nameof(comparer));
+
+            return new OrderedAsyncEnumerableWithTask<TSource, TKey>(source, keySelector, comparer, descending: true, parent: null);
+        }
+
         public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
         {
             if (source == null)
@@ -62,6 +106,16 @@ namespace System.Linq
             return source.ThenBy(keySelector, Comparer<TKey>.Default);
         }
 
+        public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+
+            return source.ThenBy(keySelector, Comparer<TKey>.Default);
+        }
+
         public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
         {
             if (source == null)
@@ -74,6 +128,18 @@ namespace System.Linq
             return source.CreateOrderedEnumerable(keySelector, comparer, descending: false);
         }
 
+        public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+            if (comparer == null)
+                throw new ArgumentNullException(nameof(comparer));
+
+            return source.CreateOrderedEnumerable(keySelector, comparer, descending: false);
+        }
+
         public static IOrderedAsyncEnumerable<TSource> ThenByDescending<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
         {
             if (source == null)
@@ -84,6 +150,16 @@ namespace System.Linq
             return source.ThenByDescending(keySelector, Comparer<TKey>.Default);
         }
 
+        public static IOrderedAsyncEnumerable<TSource> ThenByDescending<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+
+            return source.ThenByDescending(keySelector, Comparer<TKey>.Default);
+        }
+
         public static IOrderedAsyncEnumerable<TSource> ThenByDescending<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
         {
             if (source == null)
@@ -95,5 +171,17 @@ namespace System.Linq
 
             return source.CreateOrderedEnumerable(keySelector, comparer, descending: true);
         }
+
+        public static IOrderedAsyncEnumerable<TSource> ThenByDescending<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+            if (comparer == null)
+                throw new ArgumentNullException(nameof(comparer));
+
+            return source.CreateOrderedEnumerable(keySelector, comparer, descending: true);
+        }
     }
 }

+ 112 - 1
Ix.NET/Source/System.Interactive.Async/OrderedAsyncEnumerable.cs

@@ -18,6 +18,11 @@ namespace System.Linq
             return new OrderedAsyncEnumerable<TElement, TKey>(source, keySelector, comparer, descending, this);
         }
 
+        IOrderedAsyncEnumerable<TElement> IOrderedAsyncEnumerable<TElement>.CreateOrderedEnumerable<TKey>(Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending)
+        {
+            return new OrderedAsyncEnumerableWithTask<TElement, TKey>(source, keySelector, comparer, descending, this);
+        }
+
         internal abstract Task Initialize();
     }
 
@@ -49,7 +54,6 @@ namespace System.Linq
             return new OrderedAsyncEnumerable<TElement, TKey>(source, keySelector, comparer, descending, parent);
         }
 
-
         public override async Task DisposeAsync()
         {
             if (enumerator != null)
@@ -108,4 +112,111 @@ namespace System.Linq
             }
         }
     }
+
+    internal sealed class OrderedAsyncEnumerableWithTask<TElement, TKey> : OrderedAsyncEnumerable<TElement>
+    {
+        private readonly IComparer<TKey> comparer;
+        private readonly bool descending;
+        private readonly Func<TElement, Task<TKey>> keySelector;
+        private readonly OrderedAsyncEnumerable<TElement> parent;
+
+        private IEnumerator<TElement> enumerator;
+        private IAsyncEnumerator<TElement> parentEnumerator;
+
+        public OrderedAsyncEnumerableWithTask(IAsyncEnumerable<TElement> source, Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending, OrderedAsyncEnumerable<TElement> parent)
+        {
+            Debug.Assert(source != null);
+            Debug.Assert(keySelector != null);
+            Debug.Assert(comparer != null);
+
+            this.source = source;
+            this.keySelector = keySelector;
+            this.comparer = comparer;
+            this.descending = descending;
+            this.parent = parent;
+        }
+
+        public override AsyncEnumerable.AsyncIterator<TElement> Clone()
+        {
+            return new OrderedAsyncEnumerableWithTask<TElement, TKey>(source, keySelector, comparer, descending, parent);
+        }
+
+        public override async Task DisposeAsync()
+        {
+            if (enumerator != null)
+            {
+                enumerator.Dispose();
+                enumerator = null;
+            }
+
+            if (parentEnumerator != null)
+            {
+                await parentEnumerator.DisposeAsync().ConfigureAwait(false);
+                parentEnumerator = null;
+            }
+
+            await base.DisposeAsync().ConfigureAwait(false);
+        }
+
+        protected override async Task<bool> MoveNextCore()
+        {
+            switch (state)
+            {
+                case AsyncEnumerable.AsyncIteratorState.Allocated:
+
+                    await Initialize().ConfigureAwait(false);
+
+                    enumerator = enumerable.GetEnumerator();
+                    state = AsyncEnumerable.AsyncIteratorState.Iterating;
+                    goto case AsyncEnumerable.AsyncIteratorState.Iterating;
+
+                case AsyncEnumerable.AsyncIteratorState.Iterating:
+                    if (enumerator.MoveNext())
+                    {
+                        current = enumerator.Current;
+                        return true;
+                    }
+
+                    await DisposeAsync().ConfigureAwait(false);
+                    break;
+            }
+
+            return false;
+        }
+
+        internal override async Task Initialize()
+        {
+            if (parent == null)
+            {
+                var buffer = await source.ToList().ConfigureAwait(false);
+                enumerable = (!@descending ? buffer.OrderByAsync(keySelector, comparer) : buffer.OrderByDescendingAsync(keySelector, comparer));
+            }
+            else
+            {
+                parentEnumerator = parent.GetAsyncEnumerator();
+                await parent.Initialize().ConfigureAwait(false);
+                enumerable = parent.enumerable.CreateOrderedEnumerableAsync(keySelector, comparer, @descending);
+            }
+        }
+    }
+
+    internal static class EnumerableSortingExtensions
+    {
+        // TODO: Implement async sorting.
+
+        public static IOrderedEnumerable<TSource> OrderByAsync<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
+        {
+            return source.OrderBy(key => keySelector(key).GetAwaiter().GetResult(), comparer);
+        }
+
+        public static IOrderedEnumerable<TSource> OrderByDescendingAsync<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
+        {
+            return source.OrderByDescending(key => keySelector(key).GetAwaiter().GetResult(), comparer);
+        }
+
+        public static IOrderedEnumerable<TSource> CreateOrderedEnumerableAsync<TSource, TKey>(this IOrderedEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending)
+        {
+            return source.CreateOrderedEnumerable(key => keySelector(key).GetAwaiter().GetResult(), comparer, descending);
+        }
+    }
 }

+ 24 - 24
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -1668,35 +1668,35 @@ namespace Tests
         [Fact]
         public void OrderBy_Null()
         {
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(null, x => x));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(AsyncEnumerable.Return(42), null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(default(IAsyncEnumerable<int>), x => x));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(AsyncEnumerable.Return(42), default(Func<int, int>)));
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(null, x => x, Comparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(AsyncEnumerable.Return(42), null, Comparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(AsyncEnumerable.Return(42), x => x, null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(default(IAsyncEnumerable<int>), x => x, Comparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(AsyncEnumerable.Return(42), default(Func<int, int>), Comparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderBy<int, int>(AsyncEnumerable.Return(42), x => x, default(IComparer<int>)));
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(null, x => x));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(AsyncEnumerable.Return(42), null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(default(IAsyncEnumerable<int>), x => x));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(AsyncEnumerable.Return(42), default(Func<int, int>)));
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(null, x => x, Comparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(AsyncEnumerable.Return(42), null, Comparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(AsyncEnumerable.Return(42), x => x, null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(default(IAsyncEnumerable<int>), x => x, Comparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(AsyncEnumerable.Return(42), default(Func<int, int>), Comparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.OrderByDescending<int, int>(AsyncEnumerable.Return(42), x => x, default(IComparer<int>)));
 
             var xs = AsyncEnumerable.Return(42).OrderBy(x => x);
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(null, x => x));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(xs, null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(default(IOrderedAsyncEnumerable<int>), x => x));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(xs, default(Func<int, int>)));
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(null, x => x, Comparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(xs, null, Comparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(xs, x => x, null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(default(IOrderedAsyncEnumerable<int>), x => x, Comparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(xs, default(Func<int, int>), Comparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenBy<int, int>(xs, x => x, default(IComparer<int>)));
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(null, x => x));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(xs, null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(default(IOrderedAsyncEnumerable<int>), x => x));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(xs, default(Func<int, int>)));
 
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(null, x => x, Comparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(xs, null, Comparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(xs, x => x, null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(default(IOrderedAsyncEnumerable<int>), x => x, Comparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(xs, default(Func<int, int>), Comparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.ThenByDescending<int, int>(xs, x => x, default(IComparer<int>)));
         }
 
         [Fact]
@@ -1716,7 +1716,7 @@ namespace Tests
         {
             var ex = new Exception("Bang!");
             var xs = new[] { 2, 6, 1, 5, 7, 8, 9, 3, 4, 0 }.ToAsyncEnumerable();
-            var ys = xs.OrderBy<int, int>(x => { throw ex; });
+            var ys = xs.OrderBy<int, int>(new Func<int, int>(x => { throw ex; }));
 
             var e = ys.GetAsyncEnumerator();
             AssertThrows(() => e.MoveNextAsync().Wait(WaitTimeoutMs), (Exception ex_) => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
@@ -1736,7 +1736,7 @@ namespace Tests
         {
             var ex = new Exception("Bang!");
             var xs = new[] { 2, 6, 1, 5, 7, 8, 9, 3, 4, 0 }.ToAsyncEnumerable();
-            var ys = xs.OrderBy(x => x).ThenBy<int, int>(x => { throw ex; });
+            var ys = xs.OrderBy(x => x).ThenBy<int, int>(new Func<int, int>(x => { throw ex; }));
 
             var e = ys.GetAsyncEnumerator();
             AssertThrows(() => e.MoveNextAsync().Wait(WaitTimeoutMs), (Exception ex_) => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
@@ -1759,7 +1759,7 @@ namespace Tests
         {
             var ex = new Exception("Bang!");
             var xs = new[] { 2, 6, 1, 5, 7, 8, 9, 3, 4, 0 }.ToAsyncEnumerable();
-            var ys = xs.OrderByDescending<int, int>(x => { throw ex; });
+            var ys = xs.OrderByDescending<int, int>(new Func<int, int>(x => { throw ex; }));
 
             var e = ys.GetAsyncEnumerator();
             AssertThrows(() => e.MoveNextAsync().Wait(WaitTimeoutMs), (Exception ex_) => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
@@ -1779,7 +1779,7 @@ namespace Tests
         {
             var ex = new Exception("Bang!");
             var xs = new[] { 2, 6, 1, 5, 7, 8, 9, 3, 4, 0 }.ToAsyncEnumerable();
-            var ys = xs.OrderBy<int, int>(x => x).ThenByDescending<int, int>(x => { throw ex; });
+            var ys = xs.OrderBy<int, int>(x => x).ThenByDescending<int, int>(new Func<int, int>(x => { throw ex; }));
 
             var e = ys.GetAsyncEnumerator();
             AssertThrows(() => e.MoveNextAsync().Wait(WaitTimeoutMs), (Exception ex_) => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);