TakeWhile.cs 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Threading;
  6. using System.Threading.Tasks;
  7. namespace System.Linq
  8. {
  9. public static partial class AsyncEnumerable
  10. {
  /// <summary>
  /// Yields the leading elements of an async-enumerable sequence for as long as a condition holds,
  /// ending the result at the first element that fails the condition.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  /// <param name="source">The sequence whose leading elements are returned.</param>
  /// <param name="predicate">The condition evaluated against each element.</param>
  /// <returns>An async-enumerable sequence holding the elements of the input sequence that precede the first element for which the condition is false.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
  {
      // Arguments are validated up front, before the iterator factory is handed to Create.
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return Create(Iterate);

      async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
      {
          // The token is forwarded to the source enumeration.
          await foreach (var item in source.WithCancellation(token).ConfigureAwait(false))
          {
              if (predicate(item))
              {
                  yield return item;
              }
              else
              {
                  // The first failing element terminates the result and is itself not yielded.
                  yield break;
              }
          }
      }
  }
  /// <summary>
  /// Yields the leading elements of an async-enumerable sequence for as long as a condition holds,
  /// ending the result at the first element that fails the condition.
  /// The zero-based position of each element is supplied to the condition.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  /// <param name="source">The sequence whose leading elements are returned.</param>
  /// <param name="predicate">The condition evaluated against each element; its second argument is the index of the element in the source sequence.</param>
  /// <returns>An async-enumerable sequence holding the elements of the input sequence that precede the first element for which the condition is false.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
  {
      // Arguments are validated up front, before the iterator factory is handed to Create.
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return Create(Iterate);

      async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
      {
          // Starts one below zero so the first element is observed at index 0.
          var position = -1;

          await foreach (var item in source.WithCancellation(token).ConfigureAwait(false))
          {
              // Checked arithmetic: advancing past int.MaxValue throws instead of wrapping.
              position = checked(position + 1);

              if (predicate(item, position))
              {
                  yield return item;
              }
              else
              {
                  // The first failing element terminates the result and is itself not yielded.
                  yield break;
              }
          }
      }
  }
  /// <summary>
  /// Yields the leading elements of an async-enumerable sequence for as long as an asynchronous condition holds,
  /// ending the result at the first element that fails the condition.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  /// <param name="source">The sequence whose leading elements are returned.</param>
  /// <param name="predicate">The asynchronous condition evaluated against each element; its result is awaited before the next element is requested.</param>
  /// <returns>An async-enumerable sequence holding the elements of the input sequence that precede the first element for which the condition is false.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  internal static IAsyncEnumerable<TSource> TakeWhileAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
  {
      // Arguments are validated up front, before the iterator factory is handed to Create.
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return Create(Iterate);

      async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
      {
          // The token is forwarded to the source enumeration only; this predicate shape takes no token.
          await foreach (var item in source.WithCancellation(token).ConfigureAwait(false))
          {
              var keep = await predicate(item).ConfigureAwait(false);

              if (!keep)
              {
                  // The first failing element terminates the result and is itself not yielded.
                  yield break;
              }

              yield return item;
          }
      }
  }
  90. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Yields the leading elements of an async-enumerable sequence for as long as an asynchronous, cancellable condition holds,
  /// ending the result at the first element that fails the condition.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  /// <param name="source">The sequence whose leading elements are returned.</param>
  /// <param name="predicate">The asynchronous condition evaluated against each element; it receives the same cancellation token that is passed to the source enumeration.</param>
  /// <returns>An async-enumerable sequence holding the elements of the input sequence that precede the first element for which the condition is false.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  internal static IAsyncEnumerable<TSource> TakeWhileAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
  {
      // Arguments are validated up front, before the iterator factory is handed to Create.
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return Create(Iterate);

      async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
      {
          await foreach (var item in source.WithCancellation(token).ConfigureAwait(false))
          {
              // The enumeration's token is passed through to the predicate as well.
              var keep = await predicate(item, token).ConfigureAwait(false);

              if (!keep)
              {
                  // The first failing element terminates the result and is itself not yielded.
                  yield break;
              }

              yield return item;
          }
      }
  }
  110. #endif
  /// <summary>
  /// Yields the leading elements of an async-enumerable sequence for as long as an asynchronous condition holds,
  /// ending the result at the first element that fails the condition.
  /// The zero-based position of each element is supplied to the condition.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  /// <param name="source">The sequence whose leading elements are returned.</param>
  /// <param name="predicate">The asynchronous condition evaluated against each element; its second argument is the index of the element in the source sequence.</param>
  /// <returns>An async-enumerable sequence holding the elements of the input sequence that precede the first element for which the condition is false.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  internal static IAsyncEnumerable<TSource> TakeWhileAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<bool>> predicate)
  {
      // Arguments are validated up front, before the iterator factory is handed to Create.
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return Create(Iterate);

      async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
      {
          // Starts one below zero so the first element is observed at index 0.
          var position = -1;

          await foreach (var item in source.WithCancellation(token).ConfigureAwait(false))
          {
              // Checked arithmetic: advancing past int.MaxValue throws instead of wrapping.
              position = checked(position + 1);

              var keep = await predicate(item, position).ConfigureAwait(false);

              if (!keep)
              {
                  // The first failing element terminates the result and is itself not yielded.
                  yield break;
              }

              yield return item;
          }
      }
  }
  135. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Yields the leading elements of an async-enumerable sequence for as long as an asynchronous, cancellable condition holds,
  /// ending the result at the first element that fails the condition.
  /// The zero-based position of each element is supplied to the condition.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  /// <param name="source">The sequence whose leading elements are returned.</param>
  /// <param name="predicate">The asynchronous condition evaluated against each element; it receives the element, its index in the source sequence, and the same cancellation token that is passed to the source enumeration.</param>
  /// <returns>An async-enumerable sequence holding the elements of the input sequence that precede the first element for which the condition is false.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  internal static IAsyncEnumerable<TSource> TakeWhileAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<bool>> predicate)
  {
      // Arguments are validated up front, before the iterator factory is handed to Create.
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return Create(Iterate);

      async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
      {
          // Starts one below zero so the first element is observed at index 0.
          var position = -1;

          await foreach (var item in source.WithCancellation(token).ConfigureAwait(false))
          {
              // Checked arithmetic: advancing past int.MaxValue throws instead of wrapping.
              position = checked(position + 1);

              // The enumeration's token is passed through to the predicate as well.
              var keep = await predicate(item, position, token).ConfigureAwait(false);

              if (!keep)
              {
                  // The first failing element terminates the result and is itself not yielded.
                  yield break;
              }

              yield return item;
          }
      }
  }
  160. #endif
  161. }
  162. }