Where.cs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading.Tasks;
  7. namespace System.Linq
  8. {
  9. public static partial class AsyncEnumerable
  10. {
  11. public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
  12. {
  13. if (source == null)
  14. {
  15. throw new ArgumentNullException(nameof(source));
  16. }
  17. if (predicate == null)
  18. {
  19. throw new ArgumentNullException(nameof(predicate));
  20. }
  21. if (source is AsyncIterator<TSource> iterator)
  22. {
  23. return iterator.Where(predicate);
  24. }
  25. // TODO: Can we add array/list optimizations here, does it make sense?
  26. return new WhereEnumerableAsyncIterator<TSource>(source, predicate);
  27. }
  28. public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
  29. {
  30. if (source == null)
  31. {
  32. throw new ArgumentNullException(nameof(source));
  33. }
  34. if (predicate == null)
  35. {
  36. throw new ArgumentNullException(nameof(predicate));
  37. }
  38. return new WhereEnumerableWithIndexAsyncIterator<TSource>(source, predicate);
  39. }
  40. public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
  41. {
  42. if (source == null)
  43. throw new ArgumentNullException(nameof(source));
  44. if (predicate == null)
  45. throw new ArgumentNullException(nameof(predicate));
  46. if (source is AsyncIterator<TSource> iterator)
  47. {
  48. return iterator.Where(predicate);
  49. }
  50. // TODO: Can we add array/list optimizations here, does it make sense?
  51. return new WhereEnumerableAsyncIteratorWithTask<TSource>(source, predicate);
  52. }
  53. public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, Task<bool>> predicate)
  54. {
  55. if (source == null)
  56. throw new ArgumentNullException(nameof(source));
  57. if (predicate == null)
  58. throw new ArgumentNullException(nameof(predicate));
  59. return new WhereEnumerableWithIndexAsyncIteratorWithTask<TSource>(source, predicate);
  60. }
  61. private static Func<TSource, bool> CombinePredicates<TSource>(Func<TSource, bool> predicate1, Func<TSource, bool> predicate2)
  62. {
  63. return x => predicate1(x) && predicate2(x);
  64. }
  65. private static Func<TSource, Task<bool>> CombinePredicates<TSource>(Func<TSource, Task<bool>> predicate1, Func<TSource, Task<bool>> predicate2)
  66. {
  67. return async x => await predicate1(x).ConfigureAwait(false) && await predicate2(x).ConfigureAwait(false);
  68. }
  69. internal sealed class WhereEnumerableAsyncIterator<TSource> : AsyncIterator<TSource>
  70. {
  71. private readonly Func<TSource, bool> predicate;
  72. private readonly IAsyncEnumerable<TSource> source;
  73. private IAsyncEnumerator<TSource> enumerator;
  74. public WhereEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
  75. {
  76. Debug.Assert(source != null);
  77. Debug.Assert(predicate != null);
  78. this.source = source;
  79. this.predicate = predicate;
  80. }
  81. public override AsyncIterator<TSource> Clone()
  82. {
  83. return new WhereEnumerableAsyncIterator<TSource>(source, predicate);
  84. }
  85. public override async Task DisposeAsync()
  86. {
  87. if (enumerator != null)
  88. {
  89. await enumerator.DisposeAsync().ConfigureAwait(false);
  90. enumerator = null;
  91. }
  92. await base.DisposeAsync().ConfigureAwait(false);
  93. }
  94. public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
  95. {
  96. return new WhereSelectEnumerableAsyncIterator<TSource, TResult>(source, predicate, selector);
  97. }
  98. public override IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate)
  99. {
  100. return new WhereEnumerableAsyncIterator<TSource>(source, CombinePredicates(this.predicate, predicate));
  101. }
  102. protected override async Task<bool> MoveNextCore()
  103. {
  104. switch (state)
  105. {
  106. case AsyncIteratorState.Allocated:
  107. enumerator = source.GetAsyncEnumerator();
  108. state = AsyncIteratorState.Iterating;
  109. goto case AsyncIteratorState.Iterating;
  110. case AsyncIteratorState.Iterating:
  111. while (await enumerator.MoveNextAsync().ConfigureAwait(false))
  112. {
  113. var item = enumerator.Current;
  114. if (predicate(item))
  115. {
  116. current = item;
  117. return true;
  118. }
  119. }
  120. await DisposeAsync().ConfigureAwait(false);
  121. break;
  122. }
  123. return false;
  124. }
  125. }
  126. internal sealed class WhereEnumerableWithIndexAsyncIterator<TSource> : AsyncIterator<TSource>
  127. {
  128. private readonly Func<TSource, int, bool> predicate;
  129. private readonly IAsyncEnumerable<TSource> source;
  130. private IAsyncEnumerator<TSource> enumerator;
  131. private int index;
  132. public WhereEnumerableWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
  133. {
  134. Debug.Assert(source != null);
  135. Debug.Assert(predicate != null);
  136. this.source = source;
  137. this.predicate = predicate;
  138. }
  139. public override AsyncIterator<TSource> Clone()
  140. {
  141. return new WhereEnumerableWithIndexAsyncIterator<TSource>(source, predicate);
  142. }
  143. public override async Task DisposeAsync()
  144. {
  145. if (enumerator != null)
  146. {
  147. await enumerator.DisposeAsync().ConfigureAwait(false);
  148. enumerator = null;
  149. }
  150. await base.DisposeAsync().ConfigureAwait(false);
  151. }
  152. protected override async Task<bool> MoveNextCore()
  153. {
  154. switch (state)
  155. {
  156. case AsyncIteratorState.Allocated:
  157. enumerator = source.GetAsyncEnumerator();
  158. index = -1;
  159. state = AsyncIteratorState.Iterating;
  160. goto case AsyncIteratorState.Iterating;
  161. case AsyncIteratorState.Iterating:
  162. while (await enumerator.MoveNextAsync().ConfigureAwait(false))
  163. {
  164. checked
  165. {
  166. index++;
  167. }
  168. var item = enumerator.Current;
  169. if (predicate(item, index))
  170. {
  171. current = item;
  172. return true;
  173. }
  174. }
  175. await DisposeAsync().ConfigureAwait(false);
  176. break;
  177. }
  178. return false;
  179. }
  180. }
  181. internal sealed class WhereEnumerableAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
  182. {
  183. private readonly Func<TSource, Task<bool>> predicate;
  184. private readonly IAsyncEnumerable<TSource> source;
  185. private IAsyncEnumerator<TSource> enumerator;
  186. public WhereEnumerableAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
  187. {
  188. Debug.Assert(source != null);
  189. Debug.Assert(predicate != null);
  190. this.source = source;
  191. this.predicate = predicate;
  192. }
  193. public override AsyncIterator<TSource> Clone()
  194. {
  195. return new WhereEnumerableAsyncIteratorWithTask<TSource>(source, predicate);
  196. }
  197. public override async Task DisposeAsync()
  198. {
  199. if (enumerator != null)
  200. {
  201. await enumerator.DisposeAsync().ConfigureAwait(false);
  202. enumerator = null;
  203. }
  204. await base.DisposeAsync().ConfigureAwait(false);
  205. }
  206. public override IAsyncEnumerable<TSource> Where(Func<TSource, Task<bool>> predicate)
  207. {
  208. return new WhereEnumerableAsyncIteratorWithTask<TSource>(source, CombinePredicates(this.predicate, predicate));
  209. }
  210. protected override async Task<bool> MoveNextCore()
  211. {
  212. switch (state)
  213. {
  214. case AsyncIteratorState.Allocated:
  215. enumerator = source.GetAsyncEnumerator();
  216. state = AsyncIteratorState.Iterating;
  217. goto case AsyncIteratorState.Iterating;
  218. case AsyncIteratorState.Iterating:
  219. while (await enumerator.MoveNextAsync().ConfigureAwait(false))
  220. {
  221. var item = enumerator.Current;
  222. if (await predicate(item).ConfigureAwait(false))
  223. {
  224. current = item;
  225. return true;
  226. }
  227. }
  228. await DisposeAsync().ConfigureAwait(false);
  229. break;
  230. }
  231. return false;
  232. }
  233. }
  234. internal sealed class WhereEnumerableWithIndexAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
  235. {
  236. private readonly Func<TSource, int, Task<bool>> predicate;
  237. private readonly IAsyncEnumerable<TSource> source;
  238. private IAsyncEnumerator<TSource> enumerator;
  239. private int index;
  240. public WhereEnumerableWithIndexAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, int, Task<bool>> predicate)
  241. {
  242. Debug.Assert(source != null);
  243. Debug.Assert(predicate != null);
  244. this.source = source;
  245. this.predicate = predicate;
  246. }
  247. public override AsyncIterator<TSource> Clone()
  248. {
  249. return new WhereEnumerableWithIndexAsyncIteratorWithTask<TSource>(source, predicate);
  250. }
  251. public override async Task DisposeAsync()
  252. {
  253. if (enumerator != null)
  254. {
  255. await enumerator.DisposeAsync().ConfigureAwait(false);
  256. enumerator = null;
  257. }
  258. await base.DisposeAsync().ConfigureAwait(false);
  259. }
  260. protected override async Task<bool> MoveNextCore()
  261. {
  262. switch (state)
  263. {
  264. case AsyncIteratorState.Allocated:
  265. enumerator = source.GetAsyncEnumerator();
  266. index = -1;
  267. state = AsyncIteratorState.Iterating;
  268. goto case AsyncIteratorState.Iterating;
  269. case AsyncIteratorState.Iterating:
  270. while (await enumerator.MoveNextAsync().ConfigureAwait(false))
  271. {
  272. checked
  273. {
  274. index++;
  275. }
  276. var item = enumerator.Current;
  277. if (await predicate(item, index).ConfigureAwait(false))
  278. {
  279. current = item;
  280. return true;
  281. }
  282. }
  283. await DisposeAsync().ConfigureAwait(false);
  284. break;
  285. }
  286. return false;
  287. }
  288. }
  289. internal sealed class WhereSelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
  290. {
  291. private readonly Func<TSource, bool> predicate;
  292. private readonly Func<TSource, TResult> selector;
  293. private readonly IAsyncEnumerable<TSource> source;
  294. private IAsyncEnumerator<TSource> enumerator;
  295. public WhereSelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, Func<TSource, TResult> selector)
  296. {
  297. Debug.Assert(source != null);
  298. Debug.Assert(predicate != null);
  299. Debug.Assert(selector != null);
  300. this.source = source;
  301. this.predicate = predicate;
  302. this.selector = selector;
  303. }
  304. public override AsyncIterator<TResult> Clone()
  305. {
  306. return new WhereSelectEnumerableAsyncIterator<TSource, TResult>(source, predicate, selector);
  307. }
  308. public override async Task DisposeAsync()
  309. {
  310. if (enumerator != null)
  311. {
  312. await enumerator.DisposeAsync().ConfigureAwait(false);
  313. enumerator = null;
  314. }
  315. await base.DisposeAsync().ConfigureAwait(false);
  316. }
  317. public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
  318. {
  319. return new WhereSelectEnumerableAsyncIterator<TSource, TResult1>(source, predicate, CombineSelectors(this.selector, selector));
  320. }
  321. protected override async Task<bool> MoveNextCore()
  322. {
  323. switch (state)
  324. {
  325. case AsyncIteratorState.Allocated:
  326. enumerator = source.GetAsyncEnumerator();
  327. state = AsyncIteratorState.Iterating;
  328. goto case AsyncIteratorState.Iterating;
  329. case AsyncIteratorState.Iterating:
  330. while (await enumerator.MoveNextAsync().ConfigureAwait(false))
  331. {
  332. var item = enumerator.Current;
  333. if (predicate(item))
  334. {
  335. current = selector(item);
  336. return true;
  337. }
  338. }
  339. await DisposeAsync().ConfigureAwait(false);
  340. break;
  341. }
  342. return false;
  343. }
  344. }
  345. }
  346. }