// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Sorts the elements of an async sequence in ascending order of the keys produced by
        /// <paramref name="keySelector"/>, compared with <paramref name="comparer"/>.
        /// </summary>
        /// <remarks>
        /// The whole <paramref name="source"/> is buffered into a list (via <c>ToList</c>) on the
        /// first <c>MoveNext</c> of the result. That list becomes the single initial
        /// equivalence class that the ordering splits by key.
        /// </remarks>
        /// <param name="source">Sequence to sort.</param>
        /// <param name="keySelector">Extracts the sort key from each element.</param>
        /// <param name="comparer">Comparer used to order the keys.</param>
        /// <returns>An ordered async sequence that can be refined further with ThenBy/ThenByDescending.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/> or <paramref name="comparer"/> is null.
        /// </exception>
        public static IOrderedAsyncEnumerable<TSource> OrderBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return new OrderedAsyncEnumerable<TSource, TKey>(
                CreateEnumerable(() =>
                {
                    // This enumerator yields exactly one item: the fully buffered source.
                    var current = default(IEnumerable<TSource>);

                    return CreateEnumerator(
                        async ct =>
                        {
                            if (current == null)
                            {
                                current = await source.ToList(ct)
                                                      .ConfigureAwait(false);
                                return true;
                            }

                            return false;
                        },
                        () => current,
                        () => { }
                    );
                }),
                keySelector,
                comparer
            );
        }

        /// <summary>
        /// Sorts the elements of an async sequence in ascending order of the keys produced by
        /// <paramref name="keySelector"/>, using <see cref="Comparer{T}.Default"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="keySelector"/> is null.</exception>
        public static IOrderedAsyncEnumerable<TSource> OrderBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));

            return source.OrderBy(keySelector, Comparer<TKey>.Default);
        }

        /// <summary>
        /// Sorts the elements of an async sequence in descending order of the keys produced by
        /// <paramref name="keySelector"/>. Implemented as an ascending sort with
        /// <paramref name="comparer"/> wrapped in a reversing comparer.
        /// </summary>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/> or <paramref name="comparer"/> is null.
        /// </exception>
        public static IOrderedAsyncEnumerable<TSource> OrderByDescending<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return source.OrderBy(keySelector, new ReverseComparer<TKey>(comparer));
        }

        /// <summary>
        /// Sorts the elements of an async sequence in descending order of the keys produced by
        /// <paramref name="keySelector"/>, using <see cref="Comparer{T}.Default"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="keySelector"/> is null.</exception>
        public static IOrderedAsyncEnumerable<TSource> OrderByDescending<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));

            return source.OrderByDescending(keySelector, Comparer<TKey>.Default);
        }

        /// <summary>
        /// Performs a subsequent ascending ordering of an already ordered async sequence, using
        /// <see cref="Comparer{T}.Default"/> for the new key.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="keySelector"/> is null.</exception>
        public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));

            return source.ThenBy(keySelector, Comparer<TKey>.Default);
        }

        /// <summary>
        /// Performs a subsequent ascending ordering of an already ordered async sequence, using
        /// <paramref name="comparer"/> for the new key.
        /// </summary>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/> or <paramref name="comparer"/> is null.
        /// </exception>
        public static IOrderedAsyncEnumerable<TSource> ThenBy<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return source.CreateOrderedEnumerable(keySelector, comparer, false);
        }

        /// <summary>
        /// Performs a subsequent descending ordering of an already ordered async sequence, using
        /// <see cref="Comparer{T}.Default"/> for the new key.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="keySelector"/> is null.</exception>
        public static IOrderedAsyncEnumerable<TSource> ThenByDescending<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));

            return source.ThenByDescending(keySelector, Comparer<TKey>.Default);
        }

        /// <summary>
        /// Performs a subsequent descending ordering of an already ordered async sequence, using
        /// <paramref name="comparer"/> for the new key.
        /// </summary>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/> or <paramref name="comparer"/> is null.
        /// </exception>
        public static IOrderedAsyncEnumerable<TSource> ThenByDescending<TSource, TKey>(this IOrderedAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return source.CreateOrderedEnumerable(keySelector, comparer, true);
        }

        /// <summary>
        /// One level of an ordering chain. Holds the equivalence classes produced by the previous
        /// level and splits each one into finer classes by <typeparamref name="TCurrentKey"/>.
        /// </summary>
        /// <typeparam name="TElement">Element type of the sequence being ordered.</typeparam>
        /// <typeparam name="TCurrentKey">Key type used at this ordering level.</typeparam>
        private sealed class OrderedAsyncEnumerable<TElement, TCurrentKey> : IOrderedAsyncEnumerable<TElement>
        {
            private readonly IComparer<TCurrentKey> comparer;
            private readonly IAsyncEnumerable<IEnumerable<TElement>> equivalenceClasses;
            private readonly Func<TElement, TCurrentKey> keySelector;

            public OrderedAsyncEnumerable(IAsyncEnumerable<IEnumerable<TElement>> equivalenceClasses, Func<TElement, TCurrentKey> keySelector, IComparer<TCurrentKey> comparer)
            {
                this.equivalenceClasses = equivalenceClasses;
                this.keySelector = keySelector;
                this.comparer = comparer;
            }

            /// <summary>
            /// Adds another ordering level on top of this one. The new level's equivalence classes
            /// are the groups produced by this level.
            /// </summary>
            public IOrderedAsyncEnumerable<TElement> CreateOrderedEnumerable<TKey>(Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending)
            {
                if (descending)
                    comparer = new ReverseComparer<TKey>(comparer);

                return new OrderedAsyncEnumerable<TElement, TKey>(Classes(), keySelector, comparer);
            }

            /// <summary>
            /// Enumerates the ordered elements by flattening this level's equivalence classes in order.
            /// </summary>
            public IAsyncEnumerator<TElement> GetEnumerator()
            {
                return Classes()
                    .SelectMany(x => x.ToAsyncEnumerable())
                    .GetEnumerator();
            }

            /// <summary>
            /// Produces this level's equivalence classes. On first <c>MoveNext</c>, all parent classes
            /// are drained. Each one is sorted by this level's key and split into consecutive groups
            /// of equal keys via <c>GroupUntil</c>. The buffered groups are then replayed one by one.
            /// </summary>
            private IAsyncEnumerable<IEnumerable<TElement>> Classes()
            {
                return CreateEnumerable(() =>
                {
                    var e = equivalenceClasses.GetEnumerator();
                    var list = new List<IEnumerable<TElement>>();
                    var e1 = default(IEnumerator<IEnumerable<TElement>>);

                    var cts = new CancellationTokenDisposable();
                    var d1 = new AssignableDisposable();
                    var d = Disposable.Create(cts, e, d1);

                    // Drain the parent classes with a loop, not recursion. Recursing once per parent class
                    // nests a stack frame for every synchronously completed MoveNext. That happens when the
                    // parent replays a buffered list, so a ThenBy over many distinct parent keys could
                    // otherwise overflow the stack.
                    Func<CancellationToken, Task<bool>> f = async ct =>
                    {
                        while (await e.MoveNext(ct)
                                      .ConfigureAwait(false))
                        {
                            list.AddRange(e.Current.OrderBy(keySelector, comparer)
                                                   .GroupUntil(keySelector, x => x, comparer));
                        }

                        e.Dispose();
                        e1 = list.GetEnumerator();
                        d1.Disposable = e1;
                        return e1.MoveNext();
                    };

                    return CreateEnumerator(
                        async ct =>
                        {
                            // After the drain, just replay the buffered groups.
                            if (e1 != null)
                            {
                                return e1.MoveNext();
                            }

                            // NOTE(review): the drain runs on the enumerator-owned token (cancelled when `d` is
                            // disposed) rather than the caller's `ct`. Presumably CreateEnumerator ties caller
                            // cancellation to disposal; confirm against its implementation.
                            return await f(cts.Token)
                                .ConfigureAwait(false);
                        },
                        () => e1.Current,
                        d.Dispose,
                        e
                    );
                });
            }
        }

        /// <summary>
        /// Comparer that inverts the order defined by a wrapped comparer.
        /// </summary>
        private sealed class ReverseComparer<T> : IComparer<T>
        {
            private readonly IComparer<T> comparer;

            public ReverseComparer(IComparer<T> comparer)
            {
                this.comparer = comparer;
            }

            public int Compare(T x, T y)
            {
                // Swap the operands instead of negating the result. -int.MinValue overflows back to
                // int.MinValue, so negation would fail to reverse a comparer that returns int.MinValue.
                return comparer.Compare(y, x);
            }
        }
    }
}