Where.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Threading;
  6. using System.Threading.Tasks;
  7. namespace System.Linq
  8. {
  9. public static partial class AsyncEnumerable
  10. {
  11. #if INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
  // https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.where?view=net-9.0-pp#system-linq-asyncenumerable-where-1(system-collections-generic-iasyncenumerable((-0))-system-func((-0-system-boolean)))
  /// <summary>
  /// Filters an async-enumerable sequence, keeping only the elements for which <paramref name="predicate"/> returns <c>true</c>.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">The async-enumerable sequence to filter.</param>
  /// <param name="predicate">A synchronous test applied to each source element.</param>
  /// <returns>An async-enumerable sequence containing the source elements that satisfy <paramref name="predicate"/>.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// If <paramref name="source"/> is an <c>AsyncIteratorBase&lt;TSource&gt;</c>, the call is forwarded to that
  /// iterator's own <c>Where</c> overload; otherwise a new <c>WhereEnumerableAsyncIterator&lt;TSource&gt;</c> wraps the source.
  /// </remarks>
  public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
  {
      if (source is null)
          throw Error.ArgumentNull(nameof(source));
      if (predicate is null)
          throw Error.ArgumentNull(nameof(predicate));

      // No array/list fast paths have been evaluated yet; every non-iterator source uses the generic wrapper.
      return source is AsyncIteratorBase<TSource> iterator
          ? iterator.Where(predicate)
          : new WhereEnumerableAsyncIterator<TSource>(source, predicate);
  }
  // https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.where?view=net-9.0-pp#system-linq-asyncenumerable-where-1(system-collections-generic-iasyncenumerable((-0))-system-func((-0-system-int32-system-boolean)))
  /// <summary>
  /// Filters an async-enumerable sequence using a predicate that also receives each element's zero-based position.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">The async-enumerable sequence to filter.</param>
  /// <param name="predicate">A test applied to each source element; its second argument is the element's index in the source.</param>
  /// <returns>An async-enumerable sequence containing the source elements that satisfy <paramref name="predicate"/>.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// Arguments are validated as soon as this method is called; the filtering itself is deferred until the result is
  /// enumerated. The index is advanced with checked arithmetic, so enumeration fails with an
  /// <see cref="OverflowException"/> if the source yields more elements than an <see cref="int"/> index can address.
  /// </remarks>
  public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
  {
      if (source is null)
          throw Error.ArgumentNull(nameof(source));
      if (predicate is null)
          throw Error.ArgumentNull(nameof(predicate));

      return FilterWithIndex(source, predicate);

      static async IAsyncEnumerable<TSource> FilterWithIndex(IAsyncEnumerable<TSource> items, Func<TSource, int, bool> test, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          // Starts one below zero so the first element is tested with index 0.
          var position = -1;

          await foreach (var item in items.WithCancellation(cancellationToken).ConfigureAwait(false))
          {
              // Advance before testing, so overflow is detected before the predicate sees a wrapped index.
              position = checked(position + 1);

              if (!test(item, position))
                  continue;

              yield return item;
          }
      }
  }
  66. #endif // INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
  /// <summary>
  /// Filters an async-enumerable sequence using an asynchronous predicate.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">The async-enumerable sequence to filter.</param>
  /// <param name="predicate">An asynchronous test applied to each source element; elements whose test completes with <c>true</c> are kept.</param>
  /// <returns>An async-enumerable sequence containing the source elements that satisfy <paramref name="predicate"/>.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  [GenerateAsyncOverload]
  [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwait functionality now exists as overloads of Where. You will need to modify your callback to take an additional CancellationToken argument.")]
  private static IAsyncEnumerable<TSource> WhereAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
  {
      if (source is null)
          throw Error.ArgumentNull(nameof(source));
      if (predicate is null)
          throw Error.ArgumentNull(nameof(predicate));

      // Library iterators get to handle the filter themselves; anything else is wrapped.
      // Array/list fast paths have not been evaluated yet.
      return source is AsyncIteratorBase<TSource> iterator
          ? iterator.Where(predicate)
          : new WhereEnumerableAsyncIteratorWithTask<TSource>(source, predicate);
  }
  90. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Filters an async-enumerable sequence using an asynchronous predicate that is also handed a cancellation token.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">The async-enumerable sequence to filter.</param>
  /// <param name="predicate">An asynchronous test applied to each source element; its second argument is a <see cref="CancellationToken"/> the test may observe.</param>
  /// <returns>An async-enumerable sequence containing the source elements that satisfy <paramref name="predicate"/>.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  [GenerateAsyncOverload]
  [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwaitWithCancellation functionality now exists as overloads of Where.")]
  private static IAsyncEnumerable<TSource> WhereAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
  {
      if (source is null)
          throw Error.ArgumentNull(nameof(source));
      if (predicate is null)
          throw Error.ArgumentNull(nameof(predicate));

      // Library iterators get to handle the filter themselves; anything else is wrapped.
      // Array/list fast paths have not been evaluated yet.
      return source is AsyncIteratorBase<TSource> iterator
          ? iterator.Where(predicate)
          : new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(source, predicate);
  }
  106. #endif
  /// <summary>
  /// Filters an async-enumerable sequence using an asynchronous predicate that also receives each element's zero-based position.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">The async-enumerable sequence to filter.</param>
  /// <param name="predicate">An asynchronous test applied to each source element; its second argument is the element's index in the source.</param>
  /// <returns>An async-enumerable sequence containing the source elements that satisfy <paramref name="predicate"/>.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// Arguments are validated eagerly; filtering runs only when the result is enumerated. The index is advanced
  /// with checked arithmetic, so an <see cref="OverflowException"/> surfaces during enumeration if it would exceed <see cref="int.MaxValue"/>.
  /// </remarks>
  [GenerateAsyncOverload]
  [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwait functionality now exists as overloads of Where. You will need to modify your callback to take an additional CancellationToken argument.")]
  private static IAsyncEnumerable<TSource> WhereAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<bool>> predicate)
  {
      if (source is null)
          throw Error.ArgumentNull(nameof(source));
      if (predicate is null)
          throw Error.ArgumentNull(nameof(predicate));

      return FilterWithIndexAsync(source, predicate);

      static async IAsyncEnumerable<TSource> FilterWithIndexAsync(IAsyncEnumerable<TSource> items, Func<TSource, int, ValueTask<bool>> test, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          // Starts one below zero so the first element is tested with index 0.
          var position = -1;

          await foreach (var item in items.WithCancellation(cancellationToken).ConfigureAwait(false))
          {
              position = checked(position + 1);

              var keep = await test(item, position).ConfigureAwait(false);
              if (keep)
              {
                  yield return item;
              }
          }
      }
  }
  140. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Filters an async-enumerable sequence using an asynchronous, index-aware predicate that is also handed the enumeration's cancellation token.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">The async-enumerable sequence to filter.</param>
  /// <param name="predicate">
  /// An asynchronous test applied to each source element. Its second argument is the element's index in the source;
  /// its third is the cancellation token the result sequence is being enumerated with.
  /// </param>
  /// <returns>An async-enumerable sequence containing the source elements that satisfy <paramref name="predicate"/>.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// Arguments are validated eagerly; filtering runs only when the result is enumerated. The index is advanced
  /// with checked arithmetic, so an <see cref="OverflowException"/> surfaces during enumeration if it would exceed <see cref="int.MaxValue"/>.
  /// </remarks>
  [GenerateAsyncOverload]
  [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwaitWithCancellation functionality now exists as overloads of Where.")]
  private static IAsyncEnumerable<TSource> WhereAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<bool>> predicate)
  {
      if (source is null)
          throw Error.ArgumentNull(nameof(source));
      if (predicate is null)
          throw Error.ArgumentNull(nameof(predicate));

      return FilterWithIndexAsync(source, predicate);

      static async IAsyncEnumerable<TSource> FilterWithIndexAsync(IAsyncEnumerable<TSource> items, Func<TSource, int, CancellationToken, ValueTask<bool>> test, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          // Starts one below zero so the first element is tested with index 0.
          var position = -1;

          await foreach (var item in items.WithCancellation(cancellationToken).ConfigureAwait(false))
          {
              position = checked(position + 1);

              // The same token that drives the source enumeration is forwarded to the predicate.
              var keep = await test(item, position, cancellationToken).ConfigureAwait(false);
              if (!keep)
                  continue;

              yield return item;
          }
      }
  }
  166. #endif
  /// <summary>
  /// Iterator behind <c>Where</c> with a synchronous predicate: yields each source element for which the predicate returns <c>true</c>.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  internal sealed class WhereEnumerableAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly Func<TSource, bool> _predicate;
      private readonly IAsyncEnumerable<TSource> _source;

      // Opened on the first MoveNextCore call from the Allocated state; reset to null by DisposeAsync.
      private IAsyncEnumerator<TSource>? _enumerator;

      public WhereEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
      {
          _source = source;
          _predicate = predicate;
      }

      /// <summary>Creates a new iterator over the same source and predicate.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new WhereEnumerableAsyncIterator<TSource>(_source, _predicate);

      /// <summary>Disposes the underlying source enumerator (if one was opened), then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      // A Select applied after this Where becomes one filter-then-project iterator over the original source.
      public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector) =>
          new WhereSelectEnumerableAsyncIterator<TSource, TResult>(_source, _predicate, selector);

      // A second Where becomes one iterator over the original source whose predicate is built by CombinePredicates.
      public override IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate) =>
          new WhereEnumerableAsyncIterator<TSource>(_source, CombinePredicates(_predicate, predicate));

      protected override async ValueTask<bool> MoveNextCore()
      {
          if (_state == AsyncIteratorState.Allocated)
          {
              // First pull: open the source enumerator with this iterator's cancellation token.
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          if (_state == AsyncIteratorState.Iterating)
          {
              while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
              {
                  var candidate = _enumerator.Current;
                  if (!_predicate(candidate))
                      continue;

                  _current = candidate;
                  return true;
              }

              // Source exhausted: release the source enumerator right away.
              await DisposeAsync().ConfigureAwait(false);
          }

          return false;
      }
  }
  /// <summary>
  /// Iterator behind <c>WhereAwait</c>: yields each source element whose asynchronous predicate completes with <c>true</c>.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  internal sealed class WhereEnumerableAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
  {
      private readonly Func<TSource, ValueTask<bool>> _predicate;
      private readonly IAsyncEnumerable<TSource> _source;

      // Opened on the first MoveNextCore call from the Allocated state; reset to null by DisposeAsync.
      private IAsyncEnumerator<TSource>? _enumerator;

      public WhereEnumerableAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
      {
          _source = source;
          _predicate = predicate;
      }

      /// <summary>Creates a new iterator over the same source and predicate.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new WhereEnumerableAsyncIteratorWithTask<TSource>(_source, _predicate);

      /// <summary>Disposes the underlying source enumerator (if one was opened), then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      // A second asynchronous Where becomes one iterator over the original source whose predicate is built by CombinePredicates.
      public override IAsyncEnumerable<TSource> Where(Func<TSource, ValueTask<bool>> predicate) =>
          new WhereEnumerableAsyncIteratorWithTask<TSource>(_source, CombinePredicates(_predicate, predicate));

      protected override async ValueTask<bool> MoveNextCore()
      {
          if (_state == AsyncIteratorState.Allocated)
          {
              // First pull: open the source enumerator with this iterator's cancellation token.
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          if (_state == AsyncIteratorState.Iterating)
          {
              while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
              {
                  var candidate = _enumerator.Current;
                  if (!await _predicate(candidate).ConfigureAwait(false))
                      continue;

                  _current = candidate;
                  return true;
              }

              // Source exhausted: release the source enumerator right away.
              await DisposeAsync().ConfigureAwait(false);
          }

          return false;
      }
  }
  273. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Iterator behind <c>WhereAwaitWithCancellation</c>: yields each source element whose asynchronous predicate completes
  /// with <c>true</c>. The predicate receives the same cancellation token that is passed to the source's enumerator.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  internal sealed class WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource> : AsyncIterator<TSource>
  {
      private readonly Func<TSource, CancellationToken, ValueTask<bool>> _predicate;
      private readonly IAsyncEnumerable<TSource> _source;

      // Opened on the first MoveNextCore call from the Allocated state; reset to null by DisposeAsync.
      private IAsyncEnumerator<TSource>? _enumerator;

      public WhereEnumerableAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
      {
          _source = source;
          _predicate = predicate;
      }

      /// <summary>Creates a new iterator over the same source and predicate.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(_source, _predicate);

      /// <summary>Disposes the underlying source enumerator (if one was opened), then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      // A second cancellable Where becomes one iterator over the original source whose predicate is built by CombinePredicates.
      public override IAsyncEnumerable<TSource> Where(Func<TSource, CancellationToken, ValueTask<bool>> predicate) =>
          new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(_source, CombinePredicates(_predicate, predicate));

      protected override async ValueTask<bool> MoveNextCore()
      {
          if (_state == AsyncIteratorState.Allocated)
          {
              // First pull: open the source enumerator with this iterator's cancellation token.
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          if (_state == AsyncIteratorState.Iterating)
          {
              while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
              {
                  var candidate = _enumerator.Current;
                  if (!await _predicate(candidate, _cancellationToken).ConfigureAwait(false))
                      continue;

                  _current = candidate;
                  return true;
              }

              // Source exhausted: release the source enumerator right away.
              await DisposeAsync().ConfigureAwait(false);
          }

          return false;
      }
  }
  325. #endif
  /// <summary>
  /// Fused <c>Where(...).Select(...)</c> iterator: tests each source element with a synchronous predicate and yields
  /// the selector's projection of every element that passes.
  /// </summary>
  /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
  /// <typeparam name="TResult">The element type produced by the selector.</typeparam>
  private sealed class WhereSelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
  {
      private readonly Func<TSource, bool> _predicate;
      private readonly Func<TSource, TResult> _selector;
      private readonly IAsyncEnumerable<TSource> _source;

      // Opened on the first MoveNextCore call from the Allocated state; reset to null by DisposeAsync.
      private IAsyncEnumerator<TSource>? _enumerator;

      public WhereSelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, Func<TSource, TResult> selector)
      {
          _source = source;
          _predicate = predicate;
          _selector = selector;
      }

      /// <summary>Creates a new iterator over the same source, predicate, and selector.</summary>
      public override AsyncIteratorBase<TResult> Clone() =>
          new WhereSelectEnumerableAsyncIterator<TSource, TResult>(_source, _predicate, _selector);

      /// <summary>Disposes the underlying source enumerator (if one was opened), then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      // A further Select keeps the same predicate and chains the selectors via CombineSelectors.
      public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector) =>
          new WhereSelectEnumerableAsyncIterator<TSource, TResult1>(_source, _predicate, CombineSelectors(_selector, selector));

      protected override async ValueTask<bool> MoveNextCore()
      {
          if (_state == AsyncIteratorState.Allocated)
          {
              // First pull: open the source enumerator with this iterator's cancellation token.
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          if (_state == AsyncIteratorState.Iterating)
          {
              while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
              {
                  var candidate = _enumerator.Current;
                  if (!_predicate(candidate))
                      continue;

                  // Only elements that pass the filter are projected.
                  _current = _selector(candidate);
                  return true;
              }

              // Source exhausted: release the source enumerator right away.
              await DisposeAsync().ConfigureAwait(false);
          }

          return false;
      }
  }
  379. }
  380. }