// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading.Tasks;

namespace System.Linq
{
    /// <summary>
    /// Common base for the asynchronous OrderBy/ThenBy iterators. Each instance is one link of an
    /// ordering chain: the root link (no parent) buffers the source sequence and sorts it, and every
    /// other link adds a further ordering on top of its parent's ordered view.
    /// </summary>
    /// <typeparam name="TElement">Type of the elements being ordered.</typeparam>
    internal abstract class OrderedAsyncEnumerable<TElement> : AsyncIterator<TElement>, IOrderedAsyncEnumerable<TElement>
    {
        // Synchronous ordered view produced by Initialize(); enumerated by MoveNextCore().
        internal IOrderedEnumerable<TElement> enumerable;

        // Sequence being ordered; CreateOrderedEnumerable passes the same source to every child link.
        internal IAsyncEnumerable<TElement> source;

        // Enumerator over 'enumerable' while this iterator is being iterated.
        private IEnumerator<TElement> enumerator;

        // Enumerator acquired from the parent link by GetParentOrderingAsync. This instance owns it
        // and disposes it in DisposeAsync, which in turn releases the rest of the parent chain.
        private IAsyncEnumerator<TElement> parentEnumerator;

        IOrderedAsyncEnumerable<TElement> IOrderedAsyncEnumerable<TElement>.CreateOrderedEnumerable<TKey>(Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending)
        {
            return new OrderedAsyncEnumerable<TElement, TKey>(source, keySelector, comparer, descending, this);
        }

        IOrderedAsyncEnumerable<TElement> IOrderedAsyncEnumerable<TElement>.CreateOrderedEnumerable<TKey>(Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending)
        {
            return new OrderedAsyncEnumerableWithTask<TElement, TKey>(source, keySelector, comparer, descending, this);
        }

        /// <summary>
        /// Releases the synchronous enumerator and the parent-chain enumerator (if any), then the
        /// base iterator state.
        /// </summary>
        public override async Task DisposeAsync()
        {
            if (enumerator != null)
            {
                enumerator.Dispose();
                enumerator = null;
            }

            if (parentEnumerator != null)
            {
                await parentEnumerator.DisposeAsync().ConfigureAwait(false);
                parentEnumerator = null;
            }

            await base.DisposeAsync().ConfigureAwait(false);
        }

        /// <summary>
        /// On the first call, builds the ordered view via Initialize() and starts enumerating it.
        /// Subsequent calls advance the view. The iterator disposes itself once the view is exhausted.
        /// </summary>
        /// <returns><c>true</c> if an element was produced into <c>current</c>; otherwise <c>false</c>.</returns>
        protected override async Task<bool> MoveNextCore()
        {
            switch (state)
            {
                case AsyncIteratorState.Allocated:
                    await Initialize().ConfigureAwait(false);
                    enumerator = enumerable.GetEnumerator();
                    state = AsyncIteratorState.Iterating;
                    goto case AsyncIteratorState.Iterating;

                case AsyncIteratorState.Iterating:
                    if (enumerator.MoveNext())
                    {
                        current = enumerator.Current;
                        return true;
                    }

                    await DisposeAsync().ConfigureAwait(false);
                    break;
            }

            return false;
        }

        /// <summary>
        /// Populates <c>enumerable</c> with this link's ordered view of the source.
        /// </summary>
        internal abstract Task Initialize();

        /// <summary>
        /// Acquires an enumerator from <paramref name="parent"/>, initializes that enumerator and
        /// returns its ordered view.
        /// </summary>
        /// <remarks>
        /// Initialization deliberately runs on the acquired enumerator rather than on the shared
        /// <paramref name="parent"/> object. Any enumerators it acquires further up the chain are
        /// then owned by an instance that this link disposes. Sibling chains over the same parent
        /// also do not overwrite each other's <c>enumerable</c> field. If the acquired enumerator
        /// is not an ordering link, <paramref name="parent"/> itself is initialized, as before.
        /// </remarks>
        /// <param name="parent">The parent link of the ordering chain; must not be null.</param>
        /// <returns>The parent's ordered view, on which this link adds its own ordering.</returns>
        protected async Task<IOrderedEnumerable<TElement>> GetParentOrderingAsync(OrderedAsyncEnumerable<TElement> parent)
        {
            Debug.Assert(parent != null);

            var acquired = parent.GetAsyncEnumerator();
            parentEnumerator = acquired;

            var link = acquired as OrderedAsyncEnumerable<TElement> ?? parent;
            await link.Initialize().ConfigureAwait(false);
            return link.enumerable;
        }
    }

    /// <summary>
    /// Ordering link whose sort key is computed by a synchronous key selector.
    /// </summary>
    /// <typeparam name="TElement">Type of the elements being ordered.</typeparam>
    /// <typeparam name="TKey">Type of the sort key.</typeparam>
    internal sealed class OrderedAsyncEnumerable<TElement, TKey> : OrderedAsyncEnumerable<TElement>
    {
        private readonly IComparer<TKey> comparer;
        private readonly bool descending;
        private readonly Func<TElement, TKey> keySelector;

        // Null for the root link (OrderBy); otherwise the link this one refines (ThenBy).
        private readonly OrderedAsyncEnumerable<TElement> parent;

        public OrderedAsyncEnumerable(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending, OrderedAsyncEnumerable<TElement> parent)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);
            Debug.Assert(comparer != null);

            this.source = source;
            this.keySelector = keySelector;
            this.comparer = comparer;
            this.descending = descending;
            this.parent = parent;
        }

        public override AsyncIterator<TElement> Clone()
        {
            return new OrderedAsyncEnumerable<TElement, TKey>(source, keySelector, comparer, descending, parent);
        }

        /// <summary>
        /// The root link buffers the whole source and applies OrderBy/OrderByDescending. A child
        /// link adds its key to the parent's ordering via CreateOrderedEnumerable.
        /// </summary>
        internal override async Task Initialize()
        {
            if (parent == null)
            {
                var buffer = await source.ToList().ConfigureAwait(false);
                enumerable = !@descending ? buffer.OrderBy(keySelector, comparer) : buffer.OrderByDescending(keySelector, comparer);
            }
            else
            {
                var parentOrdering = await GetParentOrderingAsync(parent).ConfigureAwait(false);
                enumerable = parentOrdering.CreateOrderedEnumerable(keySelector, comparer, @descending);
            }
        }
    }

    /// <summary>
    /// Ordering link whose sort key is produced by an asynchronous (Task-returning) key selector.
    /// </summary>
    /// <typeparam name="TElement">Type of the elements being ordered.</typeparam>
    /// <typeparam name="TKey">Type of the sort key.</typeparam>
    internal sealed class OrderedAsyncEnumerableWithTask<TElement, TKey> : OrderedAsyncEnumerable<TElement>
    {
        private readonly IComparer<TKey> comparer;
        private readonly bool descending;
        private readonly Func<TElement, Task<TKey>> keySelector;

        // Null for the root link (OrderBy); otherwise the link this one refines (ThenBy).
        private readonly OrderedAsyncEnumerable<TElement> parent;

        public OrderedAsyncEnumerableWithTask(IAsyncEnumerable<TElement> source, Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending, OrderedAsyncEnumerable<TElement> parent)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);
            Debug.Assert(comparer != null);

            this.source = source;
            this.keySelector = keySelector;
            this.comparer = comparer;
            this.descending = descending;
            this.parent = parent;
        }

        public override AsyncIterator<TElement> Clone()
        {
            return new OrderedAsyncEnumerableWithTask<TElement, TKey>(source, keySelector, comparer, descending, parent);
        }

        /// <summary>
        /// Same structure as the synchronous variant, but keys come from the Task-returning selector
        /// through the blocking helpers in EnumerableSortingExtensions.
        /// </summary>
        internal override async Task Initialize()
        {
            if (parent == null)
            {
                var buffer = await source.ToList().ConfigureAwait(false);
                enumerable = !@descending ? buffer.OrderByAsync(keySelector, comparer) : buffer.OrderByDescendingAsync(keySelector, comparer);
            }
            else
            {
                var parentOrdering = await GetParentOrderingAsync(parent).ConfigureAwait(false);
                enumerable = parentOrdering.CreateOrderedEnumerableAsync(keySelector, comparer, @descending);
            }
        }
    }

    /// <summary>
    /// Adapters that let a Task-returning key selector drive the synchronous LINQ sorting operators.
    /// </summary>
    internal static class EnumerableSortingExtensions
    {
        // TODO: Implement async sorting.
        // NOTE: these adapters block on each key task with GetAwaiter().GetResult() (sync-over-async).
        // The blocking happens when the returned ordering is enumerated, not when these methods are
        // called. A key selector whose task needs the caller's synchronization context to complete
        // can therefore deadlock here.

        public static IOrderedEnumerable<TSource> OrderByAsync<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
        {
            return source.OrderBy(element => keySelector(element).GetAwaiter().GetResult(), comparer);
        }

        public static IOrderedEnumerable<TSource> OrderByDescendingAsync<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer)
        {
            return source.OrderByDescending(element => keySelector(element).GetAwaiter().GetResult(), comparer);
        }

        public static IOrderedEnumerable<TSource> CreateOrderedEnumerableAsync<TSource, TKey>(this IOrderedEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending)
        {
            return source.CreateOrderedEnumerable(element => keySelector(element).GetAwaiter().GetResult(), comparer, descending);
        }
    }
}