OrderedAsyncEnumerable.cs 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. namespace System.Linq
  9. {
  /// <summary>
  /// Common base for the ordered async sequences in this file. It holds the async
  /// <see cref="source"/> and the synchronous <see cref="enumerable"/> that derived
  /// classes assign in <see cref="Initialize"/>, and it implements the
  /// <see cref="IOrderedAsyncEnumerable{TElement}"/> factory methods that chain a
  /// further ordering onto this one.
  /// </summary>
  /// <typeparam name="TElement">Type of the elements being ordered.</typeparam>
  internal abstract class OrderedAsyncEnumerable<TElement> : AsyncIterator<TElement>, IOrderedAsyncEnumerable<TElement>
  {
      // Sorted view over the buffered elements; assigned by Initialize in derived classes.
      internal IOrderedEnumerable<TElement> enumerable;

      // The async sequence being ordered; assigned by derived-class constructors.
      internal IAsyncEnumerable<TElement> source;

      /// <summary>
      /// Creates a subsequent ordering keyed by a synchronous selector, with this
      /// instance passed along as the new ordering's parent.
      /// </summary>
      IOrderedAsyncEnumerable<TElement> IOrderedAsyncEnumerable<TElement>.CreateOrderedEnumerable<TKey>(Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending) =>
          new OrderedAsyncEnumerable<TElement, TKey>(source, keySelector, comparer, descending, this);

      /// <summary>
      /// Creates a subsequent ordering keyed by a task-returning selector, with this
      /// instance passed along as the new ordering's parent.
      /// </summary>
      IOrderedAsyncEnumerable<TElement> IOrderedAsyncEnumerable<TElement>.CreateOrderedEnumerable<TKey>(Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending) =>
          new OrderedAsyncEnumerableWithTask<TElement, TKey>(source, keySelector, comparer, descending, this);

      /// <summary>
      /// Prepares this ordering for enumeration. The implementations in this file
      /// assign <see cref="enumerable"/> here.
      /// </summary>
      /// <param name="cancellationToken">Token forwarded to the asynchronous work.</param>
      internal abstract Task Initialize(CancellationToken cancellationToken);
  }
  /// <summary>
  /// Ordered async sequence whose sort key comes from a synchronous selector.
  /// With no parent ordering, the source is buffered with <c>ToList</c> and sorted
  /// with <c>OrderBy</c> / <c>OrderByDescending</c>. With a parent, this ordering is
  /// appended to the parent's sorted enumerable through <c>CreateOrderedEnumerable</c>.
  /// </summary>
  /// <typeparam name="TElement">Type of the elements being ordered.</typeparam>
  /// <typeparam name="TKey">Type of the sort key.</typeparam>
  internal sealed class OrderedAsyncEnumerable<TElement, TKey> : OrderedAsyncEnumerable<TElement>
  {
      private readonly IComparer<TKey> _comparer;
      private readonly bool _descending;
      private readonly Func<TElement, TKey> _keySelector;
      private readonly OrderedAsyncEnumerable<TElement> _parent;

      // Enumerator over the sorted results; created on the first MoveNextCore call.
      private IEnumerator<TElement> _enumerator;

      // Enumerator obtained from the parent during Initialize; held only so it can be disposed.
      private IAsyncEnumerator<TElement> _parentEnumerator;

      /// <summary>
      /// Captures the source sequence and the ordering parameters.
      /// </summary>
      /// <param name="source">Async sequence to order.</param>
      /// <param name="keySelector">Extracts the sort key from an element.</param>
      /// <param name="comparer">Compares sort keys.</param>
      /// <param name="descending">True to sort from high to low.</param>
      /// <param name="parent">Preceding ordering, or null when this is the primary ordering.</param>
      public OrderedAsyncEnumerable(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IComparer<TKey> comparer, bool descending, OrderedAsyncEnumerable<TElement> parent)
      {
          Debug.Assert(source != null);
          Debug.Assert(keySelector != null);
          Debug.Assert(comparer != null);

          _parent = parent;
          _descending = descending;
          _comparer = comparer;
          _keySelector = keySelector;
          this.source = source;
      }

      /// <summary>
      /// Returns a new instance built from the same constructor arguments.
      /// </summary>
      public override AsyncIterator<TElement> Clone() =>
          new OrderedAsyncEnumerable<TElement, TKey>(source, _keySelector, _comparer, _descending, _parent);

      /// <summary>
      /// Disposes the sorted-result enumerator and the parent enumerator (when present),
      /// clears both fields, and then runs the base disposal.
      /// </summary>
      public override async ValueTask DisposeAsync()
      {
          _enumerator?.Dispose();
          _enumerator = null;

          var upstream = _parentEnumerator;
          if (upstream != null)
          {
              await upstream.DisposeAsync().ConfigureAwait(false);
              _parentEnumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances to the next element of the sorted sequence. The first call, made in
      /// the Allocated state, runs <see cref="Initialize"/> and opens the enumerator
      /// over the sorted results.
      /// </summary>
      /// <param name="cancellationToken">Token forwarded to <see cref="Initialize"/>.</param>
      /// <returns>True when an element was produced; otherwise false.</returns>
      protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
      {
          if (state == AsyncIteratorState.Allocated)
          {
              await Initialize(cancellationToken).ConfigureAwait(false);
              _enumerator = enumerable.GetEnumerator();
              state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating yields nothing.
          if (state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          if (_enumerator.MoveNext())
          {
              current = _enumerator.Current;
              return true;
          }

          // Sorted results are exhausted: release resources before reporting the end.
          await DisposeAsync().ConfigureAwait(false);
          return false;
      }

      /// <summary>
      /// Assigns <c>enumerable</c>: either a fresh sort of the buffered source, or a
      /// subordinate ordering layered on the parent's enumerable.
      /// </summary>
      /// <param name="cancellationToken">Token forwarded to the buffering and parent calls.</param>
      internal override async Task Initialize(CancellationToken cancellationToken)
      {
          if (_parent != null)
          {
              // Acquire the parent's enumerator before initializing the parent itself.
              _parentEnumerator = _parent.GetAsyncEnumerator(cancellationToken);
              await _parent.Initialize(cancellationToken).ConfigureAwait(false);
              enumerable = _parent.enumerable.CreateOrderedEnumerable(_keySelector, _comparer, _descending);
              return;
          }

          var items = await source.ToList(cancellationToken).ConfigureAwait(false);
          if (_descending)
          {
              enumerable = items.OrderByDescending(_keySelector, _comparer);
          }
          else
          {
              enumerable = items.OrderBy(_keySelector, _comparer);
          }
      }
  }
  /// <summary>
  /// Ordered async sequence whose sort key comes from a task-returning selector.
  /// With no parent ordering, the source is buffered with <c>ToList</c> and sorted
  /// with <c>OrderByAsync</c> / <c>OrderByDescendingAsync</c>. With a parent, this
  /// ordering is appended to the parent's sorted enumerable through
  /// <c>CreateOrderedEnumerableAsync</c>.
  /// </summary>
  /// <typeparam name="TElement">Type of the elements being ordered.</typeparam>
  /// <typeparam name="TKey">Type of the sort key.</typeparam>
  internal sealed class OrderedAsyncEnumerableWithTask<TElement, TKey> : OrderedAsyncEnumerable<TElement>
  {
      private readonly IComparer<TKey> _comparer;
      private readonly bool _descending;
      private readonly Func<TElement, Task<TKey>> _keySelector;
      private readonly OrderedAsyncEnumerable<TElement> _parent;

      // Enumerator over the sorted results; created on the first MoveNextCore call.
      private IEnumerator<TElement> _enumerator;

      // Enumerator obtained from the parent during Initialize; held only so it can be disposed.
      private IAsyncEnumerator<TElement> _parentEnumerator;

      /// <summary>
      /// Captures the source sequence and the ordering parameters.
      /// </summary>
      /// <param name="source">Async sequence to order.</param>
      /// <param name="keySelector">Produces a task yielding the sort key of an element.</param>
      /// <param name="comparer">Compares sort keys.</param>
      /// <param name="descending">True to sort from high to low.</param>
      /// <param name="parent">Preceding ordering, or null when this is the primary ordering.</param>
      public OrderedAsyncEnumerableWithTask(IAsyncEnumerable<TElement> source, Func<TElement, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending, OrderedAsyncEnumerable<TElement> parent)
      {
          Debug.Assert(source != null);
          Debug.Assert(keySelector != null);
          Debug.Assert(comparer != null);

          _parent = parent;
          _descending = descending;
          _comparer = comparer;
          _keySelector = keySelector;
          this.source = source;
      }

      /// <summary>
      /// Returns a new instance built from the same constructor arguments.
      /// </summary>
      public override AsyncIterator<TElement> Clone() =>
          new OrderedAsyncEnumerableWithTask<TElement, TKey>(source, _keySelector, _comparer, _descending, _parent);

      /// <summary>
      /// Disposes the sorted-result enumerator and the parent enumerator (when present),
      /// clears both fields, and then runs the base disposal.
      /// </summary>
      public override async ValueTask DisposeAsync()
      {
          _enumerator?.Dispose();
          _enumerator = null;

          var upstream = _parentEnumerator;
          if (upstream != null)
          {
              await upstream.DisposeAsync().ConfigureAwait(false);
              _parentEnumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances to the next element of the sorted sequence. The first call, made in
      /// the Allocated state, runs <see cref="Initialize"/> and opens the enumerator
      /// over the sorted results.
      /// </summary>
      /// <param name="cancellationToken">Token forwarded to <see cref="Initialize"/>.</param>
      /// <returns>True when an element was produced; otherwise false.</returns>
      protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
      {
          if (state == AsyncIteratorState.Allocated)
          {
              await Initialize(cancellationToken).ConfigureAwait(false);
              _enumerator = enumerable.GetEnumerator();
              state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating yields nothing.
          if (state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          if (_enumerator.MoveNext())
          {
              current = _enumerator.Current;
              return true;
          }

          // Sorted results are exhausted: release resources before reporting the end.
          await DisposeAsync().ConfigureAwait(false);
          return false;
      }

      /// <summary>
      /// Assigns <c>enumerable</c>: either a fresh sort of the buffered source, or a
      /// subordinate ordering layered on the parent's enumerable.
      /// </summary>
      /// <param name="cancellationToken">Token forwarded to the buffering and parent calls.</param>
      internal override async Task Initialize(CancellationToken cancellationToken)
      {
          if (_parent != null)
          {
              // Acquire the parent's enumerator before initializing the parent itself.
              _parentEnumerator = _parent.GetAsyncEnumerator(cancellationToken);
              await _parent.Initialize(cancellationToken).ConfigureAwait(false);
              enumerable = _parent.enumerable.CreateOrderedEnumerableAsync(_keySelector, _comparer, _descending);
              return;
          }

          var items = await source.ToList(cancellationToken).ConfigureAwait(false);
          if (_descending)
          {
              enumerable = items.OrderByDescendingAsync(_keySelector, _comparer);
          }
          else
          {
              enumerable = items.OrderByAsync(_keySelector, _comparer);
          }
      }
  }
  /// <summary>
  /// Sorting helpers that accept task-returning key selectors but still produce a
  /// synchronous <see cref="IOrderedEnumerable{TSource}"/>. Each key task is waited
  /// on synchronously via <c>GetAwaiter().GetResult()</c> when the underlying sort
  /// invokes the selector.
  /// </summary>
  internal static class EnumerableSortingExtensions
  {
      // TODO: Implement async sorting.

      /// <summary>
      /// Sorts <paramref name="source"/> in ascending order of the keys produced by
      /// <paramref name="keySelector"/>.
      /// </summary>
      public static IOrderedEnumerable<TSource> OrderByAsync<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer) =>
          source.OrderBy(ToBlockingSelector(keySelector), comparer);

      /// <summary>
      /// Sorts <paramref name="source"/> in descending order of the keys produced by
      /// <paramref name="keySelector"/>.
      /// </summary>
      public static IOrderedEnumerable<TSource> OrderByDescendingAsync<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer) =>
          source.OrderByDescending(ToBlockingSelector(keySelector), comparer);

      /// <summary>
      /// Adds a subsequent ordering to an already ordered sequence, using the keys
      /// produced by <paramref name="keySelector"/>.
      /// </summary>
      public static IOrderedEnumerable<TSource> CreateOrderedEnumerableAsync<TSource, TKey>(this IOrderedEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IComparer<TKey> comparer, bool descending) =>
          source.CreateOrderedEnumerable(ToBlockingSelector(keySelector), comparer, descending);

      // Wraps a task-returning selector in a synchronous one that waits for the task's result.
      private static Func<TSource, TKey> ToBlockingSelector<TSource, TKey>(Func<TSource, Task<TKey>> keySelector) =>
          element => keySelector(element).GetAwaiter().GetResult();
  }
  184. }