Where.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Threading;
  6. using System.Threading.Tasks;
  7. namespace System.Linq
  8. {
  9. public static partial class AsyncEnumerable
  10. {
  11. #if INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
  // https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.where?view=net-9.0-pp#system-linq-asyncenumerable-where-1(system-collections-generic-iasyncenumerable((-0))-system-func((-0-system-boolean)))
  /// <summary>
  /// Filters an async-enumerable sequence, keeping only the elements for which a synchronous predicate returns <see langword="true"/>.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
  /// <param name="predicate">A function to test each source element for a condition.</param>
  /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// When <paramref name="source"/> is an <c>AsyncIteratorBase&lt;TSource&gt;</c>, the filter is handed to that iterator's
  /// own <c>Where</c> implementation; otherwise the source is wrapped in a new <c>WhereEnumerableAsyncIterator&lt;TSource&gt;</c>.
  /// </remarks>
  public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      // Library iterators get the first chance to apply the filter themselves.
      // TODO: Can we add array/list optimizations here, does it make sense?
      return source is AsyncIteratorBase<TSource> iterator
          ? iterator.Where(predicate)
          : new WhereEnumerableAsyncIterator<TSource>(source, predicate);
  }
  // https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.where?view=net-9.0-pp#system-linq-asyncenumerable-where-1(system-collections-generic-iasyncenumerable((-0))-system-func((-0-system-int32-system-boolean)))
  /// <summary>
  /// Filters an async-enumerable sequence with a synchronous predicate that also receives each element's zero-based position.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
  /// <param name="predicate">A function to test each source element for a condition; the second parameter of the function represents the index of the source element.</param>
  /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// Arguments are validated immediately; <paramref name="source"/> is not enumerated until the result is enumerated.
  /// The position counter is advanced in a <see langword="checked"/> context, so an <see cref="OverflowException"/> is thrown
  /// if the source produces more than <see cref="int.MaxValue"/> + 1 elements.
  /// </remarks>
  public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return FilterWithIndex(source, predicate);

      // Async-iterator body kept in a local function so the argument checks above run eagerly.
      static async IAsyncEnumerable<TSource> FilterWithIndex(IAsyncEnumerable<TSource> items, Func<TSource, int, bool> test, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          // Starts at -1 so the first element is observed at position 0.
          var position = -1;

          await foreach (var item in items.WithCancellation(cancellationToken).ConfigureAwait(false))
          {
              position = checked(position + 1);

              if (!test(item, position))
              {
                  continue;
              }

              yield return item;
          }
      }
  }
  66. #endif // INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
  /// <summary>
  /// Filters an async-enumerable sequence, keeping only the elements for which an asynchronous predicate produces <see langword="true"/>.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
  /// <param name="predicate">An asynchronous predicate to test each source element for a condition.</param>
  /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// Obsolete: the same functionality is available as overloads of <c>Where</c>.
  /// When <paramref name="source"/> is an <c>AsyncIteratorBase&lt;TSource&gt;</c>, the filter is handed to that iterator's own <c>Where</c>.
  /// </remarks>
  [GenerateAsyncOverload]
  [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwait functionality now exists as overloads of Where.")]
  private static IAsyncEnumerable<TSource> WhereAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      // Library iterators get the first chance to apply the filter themselves.
      // TODO: Can we add array/list optimizations here, does it make sense?
      return source is AsyncIteratorBase<TSource> iterator
          ? iterator.Where(predicate)
          : new WhereEnumerableAsyncIteratorWithTask<TSource>(source, predicate);
  }
  90. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Filters an async-enumerable sequence using an asynchronous predicate that also receives a cancellation token.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
  /// <param name="predicate">An asynchronous predicate, given each element and a cancellation token, that decides whether the element is kept.</param>
  /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// Obsolete: the same functionality is available as overloads of <c>Where</c>.
  /// When <paramref name="source"/> is an <c>AsyncIteratorBase&lt;TSource&gt;</c>, the filter is handed to that iterator's own <c>Where</c>.
  /// </remarks>
  [GenerateAsyncOverload]
  [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwaitWithCancellation functionality now exists as overloads of Where.")]
  private static IAsyncEnumerable<TSource> WhereAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      // Library iterators get the first chance to apply the filter themselves.
      // TODO: Can we add array/list optimizations here, does it make sense?
      return source is AsyncIteratorBase<TSource> iterator
          ? iterator.Where(predicate)
          : new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(source, predicate);
  }
  106. #endif
  /// <summary>
  /// Filters an async-enumerable sequence with an asynchronous predicate that also receives each element's zero-based position.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
  /// <param name="predicate">An asynchronous predicate to test each source element for a condition; the second parameter of the function represents the index of the source element.</param>
  /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// Obsolete: the same functionality is available as overloads of <c>Where</c>.
  /// Arguments are validated immediately; <paramref name="source"/> is not enumerated until the result is enumerated.
  /// The position counter is advanced in a <see langword="checked"/> context and throws <see cref="OverflowException"/> on overflow.
  /// </remarks>
  [GenerateAsyncOverload]
  [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwait functionality now exists as overloads of Where.")]
  private static IAsyncEnumerable<TSource> WhereAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<bool>> predicate)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return FilterWithIndex(source, predicate);

      // Async-iterator body kept in a local function so the argument checks above run eagerly.
      static async IAsyncEnumerable<TSource> FilterWithIndex(IAsyncEnumerable<TSource> items, Func<TSource, int, ValueTask<bool>> test, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          // Starts at -1 so the first element is observed at position 0.
          var position = -1;

          await foreach (var item in items.WithCancellation(cancellationToken).ConfigureAwait(false))
          {
              position = checked(position + 1);

              if (!await test(item, position).ConfigureAwait(false))
              {
                  continue;
              }

              yield return item;
          }
      }
  }
  140. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Filters an async-enumerable sequence with an asynchronous predicate that receives each element, its zero-based position,
  /// and the cancellation token supplied when the result is enumerated.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
  /// <param name="predicate">An asynchronous predicate taking the element, its index, and a cancellation token.</param>
  /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
  /// <remarks>
  /// Obsolete: the same functionality is available as overloads of <c>Where</c>.
  /// The position counter is advanced in a <see langword="checked"/> context and throws <see cref="OverflowException"/> on overflow.
  /// </remarks>
  [GenerateAsyncOverload]
  [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwaitWithCancellation functionality now exists as overloads of Where.")]
  private static IAsyncEnumerable<TSource> WhereAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<bool>> predicate)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (predicate is null)
      {
          throw Error.ArgumentNull(nameof(predicate));
      }

      return FilterWithIndex(source, predicate);

      // Async-iterator body kept in a local function so the argument checks above run eagerly.
      static async IAsyncEnumerable<TSource> FilterWithIndex(IAsyncEnumerable<TSource> items, Func<TSource, int, CancellationToken, ValueTask<bool>> test, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
      {
          // Starts at -1 so the first element is observed at position 0.
          var position = -1;

          await foreach (var item in items.WithCancellation(cancellationToken).ConfigureAwait(false))
          {
              position = checked(position + 1);

              // The enumeration's token is forwarded so the predicate can observe cancellation.
              if (!await test(item, position, cancellationToken).ConfigureAwait(false))
              {
                  continue;
              }

              yield return item;
          }
      }
  }
  166. #endif
  /// <summary>
  /// Async iterator that yields the elements of a source sequence that satisfy a synchronous predicate.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  internal sealed class WhereEnumerableAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly Func<TSource, bool> _predicate;
      private readonly IAsyncEnumerable<TSource> _source;
      private IAsyncEnumerator<TSource>? _enumerator;

      public WhereEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
      {
          _predicate = predicate;
          _source = source;
      }

      /// <summary>Creates a new iterator over the same source and predicate.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new WhereEnumerableAsyncIterator<TSource>(_source, _predicate);

      /// <summary>Disposes the source enumerator, if one was obtained, then runs the base cleanup.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>Fuses a following projection into a single filter-then-project iterator over the original source.</summary>
      public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector) =>
          new WhereSelectEnumerableAsyncIterator<TSource, TResult>(_source, _predicate, selector);

      /// <summary>Fuses a following filter by combining both predicates over the original source.</summary>
      public override IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate) =>
          new WhereEnumerableAsyncIterator<TSource>(_source, CombinePredicates(_predicate, predicate));

      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the source enumerator and switch to the iterating state.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating produces no further elements.
          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
          {
              var candidate = _enumerator.Current;
              if (_predicate(candidate))
              {
                  _current = candidate;
                  return true;
              }
          }

          // Source exhausted: release the enumerator right away.
          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Async iterator that yields the elements of a source sequence for which an asynchronous predicate produces <see langword="true"/>.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  internal sealed class WhereEnumerableAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
  {
      private readonly Func<TSource, ValueTask<bool>> _predicate;
      private readonly IAsyncEnumerable<TSource> _source;
      private IAsyncEnumerator<TSource>? _enumerator;

      public WhereEnumerableAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
      {
          _predicate = predicate;
          _source = source;
      }

      /// <summary>Creates a new iterator over the same source and predicate.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new WhereEnumerableAsyncIteratorWithTask<TSource>(_source, _predicate);

      /// <summary>Disposes the source enumerator, if one was obtained, then runs the base cleanup.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>Fuses a following asynchronous filter by combining both predicates over the original source.</summary>
      public override IAsyncEnumerable<TSource> Where(Func<TSource, ValueTask<bool>> predicate) =>
          new WhereEnumerableAsyncIteratorWithTask<TSource>(_source, CombinePredicates(_predicate, predicate));

      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the source enumerator and switch to the iterating state.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating produces no further elements.
          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
          {
              var candidate = _enumerator.Current;
              if (await _predicate(candidate).ConfigureAwait(false))
              {
                  _current = candidate;
                  return true;
              }
          }

          // Source exhausted: release the enumerator right away.
          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  273. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Async iterator that yields the elements of a source sequence for which an asynchronous, cancellation-aware predicate
  /// produces <see langword="true"/>. The predicate is given the iterator's <c>_cancellationToken</c>.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  internal sealed class WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource> : AsyncIterator<TSource>
  {
      private readonly Func<TSource, CancellationToken, ValueTask<bool>> _predicate;
      private readonly IAsyncEnumerable<TSource> _source;
      private IAsyncEnumerator<TSource>? _enumerator;

      public WhereEnumerableAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
      {
          _predicate = predicate;
          _source = source;
      }

      /// <summary>Creates a new iterator over the same source and predicate.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(_source, _predicate);

      /// <summary>Disposes the source enumerator, if one was obtained, then runs the base cleanup.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>Fuses a following cancellation-aware filter by combining both predicates over the original source.</summary>
      public override IAsyncEnumerable<TSource> Where(Func<TSource, CancellationToken, ValueTask<bool>> predicate) =>
          new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(_source, CombinePredicates(_predicate, predicate));

      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the source enumerator and switch to the iterating state.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating produces no further elements.
          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
          {
              var candidate = _enumerator.Current;
              if (await _predicate(candidate, _cancellationToken).ConfigureAwait(false))
              {
                  _current = candidate;
                  return true;
              }
          }

          // Source exhausted: release the enumerator right away.
          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  325. #endif
  /// <summary>
  /// Async iterator that filters a source sequence with a synchronous predicate and projects each surviving element
  /// with a selector, in a single pass.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  /// <typeparam name="TResult">The type produced by the selector.</typeparam>
  private sealed class WhereSelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
  {
      private readonly Func<TSource, bool> _predicate;
      private readonly Func<TSource, TResult> _selector;
      private readonly IAsyncEnumerable<TSource> _source;
      private IAsyncEnumerator<TSource>? _enumerator;

      public WhereSelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, Func<TSource, TResult> selector)
      {
          _selector = selector;
          _predicate = predicate;
          _source = source;
      }

      /// <summary>Creates a new iterator over the same source, predicate, and selector.</summary>
      public override AsyncIteratorBase<TResult> Clone() =>
          new WhereSelectEnumerableAsyncIterator<TSource, TResult>(_source, _predicate, _selector);

      /// <summary>Disposes the source enumerator, if one was obtained, then runs the base cleanup.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>Fuses a following projection by composing it with the current selector.</summary>
      public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector) =>
          new WhereSelectEnumerableAsyncIterator<TSource, TResult1>(_source, _predicate, CombineSelectors(_selector, selector));

      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the source enumerator and switch to the iterating state.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating produces no further elements.
          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
          {
              var candidate = _enumerator.Current;
              if (_predicate(candidate))
              {
                  // Project only elements that passed the filter.
                  _current = _selector(candidate);
                  return true;
              }
          }

          // Source exhausted: release the enumerator right away.
          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  379. }
  380. }