TakeWhile.cs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Threading;
  6. using System.Threading.Tasks;
  namespace System.Linq
  {
      public static partial class AsyncEnumerable
      {
          /// <summary>
          /// Returns the leading elements of an async-enumerable sequence for as long as each one satisfies a condition.
          /// </summary>
          /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
          /// <param name="source">A sequence to return elements from.</param>
          /// <param name="predicate">A function to test each element for a condition.</param>
          /// <returns>An async-enumerable sequence that contains the elements from the input sequence that occur before the element at which the test no longer passes.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
          /// <remarks>
          /// Arguments are validated when this method is called; <paramref name="source"/> is only enumerated once the
          /// returned sequence is enumerated. Enumeration stops at the first element for which <paramref name="predicate"/>
          /// returns <c>false</c>, and that element is not yielded.
          /// </remarks>
          public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
          {
              if (source is null)
              {
                  throw Error.ArgumentNull(nameof(source));
              }

              if (predicate is null)
              {
                  throw Error.ArgumentNull(nameof(predicate));
              }

              return Iterate(source, predicate);

              // The iterator lives in a separate local function so that the argument checks above
              // execute at call time instead of being deferred to the first MoveNextAsync.
              static async IAsyncEnumerable<TSource> Iterate(IAsyncEnumerable<TSource> sequence, Func<TSource, bool> condition, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
              {
                  var items = sequence.WithCancellation(cancellationToken).ConfigureAwait(false);

                  await foreach (var item in items)
                  {
                      if (condition(item))
                      {
                          yield return item;
                      }
                      else
                      {
                          yield break;
                      }
                  }
              }
          }

          /// <summary>
          /// Returns the leading elements of an async-enumerable sequence for as long as each one satisfies a condition.
          /// The zero-based position of each element is supplied to the predicate.
          /// </summary>
          /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
          /// <param name="source">A sequence to return elements from.</param>
          /// <param name="predicate">A function to test each element for a condition; the second parameter of the function represents the index of the source element.</param>
          /// <returns>An async-enumerable sequence that contains the elements from the input sequence that occur before the element at which the test no longer passes.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
          /// <remarks>
          /// The index is computed with checked arithmetic, so enumeration throws <see cref="OverflowException"/>
          /// if it would exceed <see cref="int.MaxValue"/>.
          /// </remarks>
          public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
          {
              if (source is null)
              {
                  throw Error.ArgumentNull(nameof(source));
              }

              if (predicate is null)
              {
                  throw Error.ArgumentNull(nameof(predicate));
              }

              return Iterate(source, predicate);

              static async IAsyncEnumerable<TSource> Iterate(IAsyncEnumerable<TSource> sequence, Func<TSource, int, bool> condition, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
              {
                  var items = sequence.WithCancellation(cancellationToken).ConfigureAwait(false);
                  var position = -1;

                  await foreach (var item in items)
                  {
                      // Advance before invoking the predicate so the first element sees index 0;
                      // checked so an over-long sequence fails loudly instead of wrapping negative.
                      position = checked(position + 1);

                      if (condition(item, position))
                      {
                          yield return item;
                      }
                      else
                      {
                          yield break;
                      }
                  }
              }
          }

          /// <summary>
          /// Returns the leading elements of an async-enumerable sequence for as long as an asynchronous condition holds.
          /// </summary>
          /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
          /// <param name="source">A sequence to return elements from.</param>
          /// <param name="predicate">An asynchronous predicate to test each element for a condition.</param>
          /// <returns>An async-enumerable sequence that contains the elements from the input sequence that occur before the element at which the test no longer passes.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
          [GenerateAsyncOverload]
          private static IAsyncEnumerable<TSource> TakeWhileAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
          {
              if (source is null)
              {
                  throw Error.ArgumentNull(nameof(source));
              }

              if (predicate is null)
              {
                  throw Error.ArgumentNull(nameof(predicate));
              }

              return Iterate(source, predicate);

              static async IAsyncEnumerable<TSource> Iterate(IAsyncEnumerable<TSource> sequence, Func<TSource, ValueTask<bool>> condition, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
              {
                  var items = sequence.WithCancellation(cancellationToken).ConfigureAwait(false);

                  await foreach (var item in items)
                  {
                      var keep = await condition(item).ConfigureAwait(false);
                      if (!keep)
                      {
                          yield break;
                      }

                      yield return item;
                  }
              }
          }

  #if !NO_DEEP_CANCELLATION
          /// <summary>
          /// Returns the leading elements of an async-enumerable sequence for as long as an asynchronous, cancellable condition holds.
          /// </summary>
          /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
          /// <param name="source">A sequence to return elements from.</param>
          /// <param name="predicate">An asynchronous predicate to test each element for a condition; its second parameter receives the cancellation token supplied when the returned sequence is enumerated.</param>
          /// <returns>An async-enumerable sequence that contains the elements from the input sequence that occur before the element at which the test no longer passes.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
          [GenerateAsyncOverload]
          private static IAsyncEnumerable<TSource> TakeWhileAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
          {
              if (source is null)
              {
                  throw Error.ArgumentNull(nameof(source));
              }

              if (predicate is null)
              {
                  throw Error.ArgumentNull(nameof(predicate));
              }

              return Iterate(source, predicate);

              static async IAsyncEnumerable<TSource> Iterate(IAsyncEnumerable<TSource> sequence, Func<TSource, CancellationToken, ValueTask<bool>> condition, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
              {
                  var items = sequence.WithCancellation(cancellationToken).ConfigureAwait(false);

                  await foreach (var item in items)
                  {
                      // The same token that cancels the source is forwarded to the predicate.
                      var keep = await condition(item, cancellationToken).ConfigureAwait(false);
                      if (!keep)
                      {
                          yield break;
                      }

                      yield return item;
                  }
              }
          }
  #endif

          /// <summary>
          /// Returns the leading elements of an async-enumerable sequence for as long as an asynchronous condition holds.
          /// The zero-based position of each element is supplied to the predicate.
          /// </summary>
          /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
          /// <param name="source">A sequence to return elements from.</param>
          /// <param name="predicate">An asynchronous function to test each element for a condition; the second parameter of the function represents the index of the source element.</param>
          /// <returns>An async-enumerable sequence that contains the elements from the input sequence that occur before the element at which the test no longer passes.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
          [GenerateAsyncOverload]
          private static IAsyncEnumerable<TSource> TakeWhileAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<bool>> predicate)
          {
              if (source is null)
              {
                  throw Error.ArgumentNull(nameof(source));
              }

              if (predicate is null)
              {
                  throw Error.ArgumentNull(nameof(predicate));
              }

              return Iterate(source, predicate);

              static async IAsyncEnumerable<TSource> Iterate(IAsyncEnumerable<TSource> sequence, Func<TSource, int, ValueTask<bool>> condition, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
              {
                  var items = sequence.WithCancellation(cancellationToken).ConfigureAwait(false);
                  var position = -1;

                  await foreach (var item in items)
                  {
                      position = checked(position + 1);

                      var keep = await condition(item, position).ConfigureAwait(false);
                      if (!keep)
                      {
                          yield break;
                      }

                      yield return item;
                  }
              }
          }

  #if !NO_DEEP_CANCELLATION
          /// <summary>
          /// Returns the leading elements of an async-enumerable sequence for as long as an asynchronous, cancellable condition holds.
          /// The zero-based position of each element is supplied to the predicate.
          /// </summary>
          /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
          /// <param name="source">A sequence to return elements from.</param>
          /// <param name="predicate">An asynchronous function to test each element for a condition; it receives the element, its index, and the cancellation token supplied when the returned sequence is enumerated.</param>
          /// <returns>An async-enumerable sequence that contains the elements from the input sequence that occur before the element at which the test no longer passes.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
          [GenerateAsyncOverload]
          private static IAsyncEnumerable<TSource> TakeWhileAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<bool>> predicate)
          {
              if (source is null)
              {
                  throw Error.ArgumentNull(nameof(source));
              }

              if (predicate is null)
              {
                  throw Error.ArgumentNull(nameof(predicate));
              }

              return Iterate(source, predicate);

              static async IAsyncEnumerable<TSource> Iterate(IAsyncEnumerable<TSource> sequence, Func<TSource, int, CancellationToken, ValueTask<bool>> condition, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
              {
                  var items = sequence.WithCancellation(cancellationToken).ConfigureAwait(false);
                  var position = -1;

                  await foreach (var item in items)
                  {
                      position = checked(position + 1);

                      var keep = await condition(item, position, cancellationToken).ConfigureAwait(false);
                      if (!keep)
                      {
                          yield break;
                      }

                      yield return item;
                  }
              }
          }
  #endif
      }
  }