Take.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading.Tasks;
  7. namespace System.Linq
  8. {
  9. public static partial class AsyncEnumerable
  10. {
  /// <summary>
  /// Yields at most <paramref name="count"/> leading elements of <paramref name="source"/>.
  /// When <paramref name="count"/> is zero or negative, returns <c>Empty&lt;TSource&gt;()</c> without touching the source.
  /// </summary>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
  public static IAsyncEnumerable<TSource> Take<TSource>(this IAsyncEnumerable<TSource> source, int count)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (count > 0)
      {
          return new TakeAsyncIterator<TSource>(source, count);
      }

      return Empty<TSource>();
  }
  /// <summary>
  /// Yields the final <paramref name="count"/> elements of <paramref name="source"/>.
  /// When <paramref name="count"/> is zero or negative, returns <c>Empty&lt;TSource&gt;()</c> without touching the source.
  /// </summary>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
  public static IAsyncEnumerable<TSource> TakeLast<TSource>(this IAsyncEnumerable<TSource> source, int count)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (count > 0)
      {
          return new TakeLastAsyncIterator<TSource>(source, count);
      }

      return Empty<TSource>();
  }
  /// <summary>
  /// Yields elements of <paramref name="source"/> for as long as <paramref name="predicate"/> returns true.
  /// </summary>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null (source is checked first).</exception>
  public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (predicate == null)
      {
          throw new ArgumentNullException(nameof(predicate));
      }

      var iterator = new TakeWhileAsyncIterator<TSource>(source, predicate);
      return iterator;
  }
  /// <summary>
  /// Yields elements of <paramref name="source"/> for as long as <paramref name="predicate"/>,
  /// given each element and its zero-based position, returns true.
  /// </summary>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null (source is checked first).</exception>
  public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (predicate == null)
      {
          throw new ArgumentNullException(nameof(predicate));
      }

      var iterator = new TakeWhileWithIndexAsyncIterator<TSource>(source, predicate);
      return iterator;
  }
  /// <summary>
  /// Yields elements of <paramref name="source"/> for as long as the awaited result of
  /// <paramref name="predicate"/> is true.
  /// </summary>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null (source is checked first).</exception>
  public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (predicate == null)
      {
          throw new ArgumentNullException(nameof(predicate));
      }

      var iterator = new TakeWhileAsyncIteratorWithTask<TSource>(source, predicate);
      return iterator;
  }
  /// <summary>
  /// Yields elements of <paramref name="source"/> for as long as the awaited result of
  /// <paramref name="predicate"/>, given each element and its zero-based position, is true.
  /// </summary>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null (source is checked first).</exception>
  public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, Task<bool>> predicate)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (predicate == null)
      {
          throw new ArgumentNullException(nameof(predicate));
      }

      var iterator = new TakeWhileWithIndexAsyncIteratorWithTask<TSource>(source, predicate);
      return iterator;
  }
  /// <summary>
  /// Iterator behind <c>Take</c>: forwards source elements until a per-enumeration budget of
  /// <c>count</c> elements is used up, then disposes the source and reports completion.
  /// </summary>
  private sealed class TakeAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> source;
      private readonly int count;

      // Elements still allowed to be yielded; starts at count and counts down.
      private int currentCount;

      private IAsyncEnumerator<TSource> enumerator;

      public TakeAsyncIterator(IAsyncEnumerable<TSource> source, int count)
      {
          Debug.Assert(source != null);

          this.source = source;
          this.count = count;
          this.currentCount = count;
      }

      /// <summary>Creates a fresh iterator over the same source with the full budget.</summary>
      public override AsyncIterator<TSource> Clone() => new TakeAsyncIterator<TSource>(source, count);

      /// <summary>Disposes the source enumerator (if one is open), then the base iterator.</summary>
      public override async Task DisposeAsync()
      {
          var inner = enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async Task<bool> MoveNextCore()
      {
          // First call: open the source and fall straight into iteration.
          if (state == AsyncIteratorState.Allocated)
          {
              enumerator = source.GetAsyncEnumerator();
              state = AsyncIteratorState.Iterating;
          }

          // The budget is tested before MoveNextAsync so no extra element is pulled from the source.
          if (state == AsyncIteratorState.Iterating && currentCount > 0)
          {
              if (await enumerator.MoveNextAsync().ConfigureAwait(false))
              {
                  current = enumerator.Current;
                  currentCount--;
                  return true;
              }
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Iterator behind <c>TakeLast</c>: on the first iterating call it drains the entire source into a
  /// sliding window of at most <c>count</c> elements, disposes the source, then replays the window
  /// one element per call.
  /// </summary>
  private sealed class TakeLastAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> source;
      private readonly int count;
      private IAsyncEnumerator<TSource> enumerator;

      // Trailing window of source elements; created on the first MoveNext and released on dispose.
      private Queue<TSource> queue;

      // Set once the source has reported no more elements.
      private bool isDone;

      public TakeLastAsyncIterator(IAsyncEnumerable<TSource> source, int count)
      {
          Debug.Assert(source != null);

          this.source = source;
          this.count = count;
      }

      /// <summary>Creates a fresh iterator over the same source and window size.</summary>
      public override AsyncIterator<TSource> Clone() => new TakeLastAsyncIterator<TSource>(source, count);

      /// <summary>Disposes the source enumerator if still open, drops the buffer, then disposes the base iterator.</summary>
      public override async Task DisposeAsync()
      {
          var inner = enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }

          // Let the buffered elements be collected.
          queue = null;

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async Task<bool> MoveNextCore()
      {
          if (state == AsyncIteratorState.Allocated)
          {
              enumerator = source.GetAsyncEnumerator();
              queue = new Queue<TSource>();
              isDone = false;
              state = AsyncIteratorState.Iterating;
          }

          if (state == AsyncIteratorState.Iterating)
          {
              // Phase 1: consume the whole source, keeping only the newest `count` elements.
              while (!isDone)
              {
                  var hasItem = await enumerator.MoveNextAsync().ConfigureAwait(false);
                  if (!hasItem)
                  {
                      isDone = true;

                      // The source is exhausted; release it before replaying the buffer.
                      await enumerator.DisposeAsync().ConfigureAwait(false);
                      enumerator = null;
                  }
                  else if (count > 0)
                  {
                      var item = enumerator.Current;
                      if (queue.Count >= count)
                      {
                          // Window is full: evict the oldest element to make room.
                          queue.Dequeue();
                      }

                      queue.Enqueue(item);
                  }
              }

              // Phase 2: hand back the buffered elements one per call, oldest first.
              if (queue.Count > 0)
              {
                  current = queue.Dequeue();
                  return true;
              }
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Iterator behind <c>TakeWhile(Func&lt;TSource, bool&gt;)</c>: forwards source elements until the
  /// predicate first rejects one, at which point the source is disposed and iteration ends.
  /// </summary>
  private sealed class TakeWhileAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> source;
      private readonly Func<TSource, bool> predicate;
      private IAsyncEnumerator<TSource> enumerator;

      public TakeWhileAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
      {
          Debug.Assert(predicate != null);
          Debug.Assert(source != null);

          this.source = source;
          this.predicate = predicate;
      }

      /// <summary>Creates a fresh iterator over the same source and predicate.</summary>
      public override AsyncIterator<TSource> Clone() => new TakeWhileAsyncIterator<TSource>(source, predicate);

      /// <summary>Disposes the source enumerator (if one is open), then the base iterator.</summary>
      public override async Task DisposeAsync()
      {
          var inner = enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async Task<bool> MoveNextCore()
      {
          if (state == AsyncIteratorState.Allocated)
          {
              enumerator = source.GetAsyncEnumerator();
              state = AsyncIteratorState.Iterating;
          }

          if (state == AsyncIteratorState.Iterating && await enumerator.MoveNextAsync().ConfigureAwait(false))
          {
              var item = enumerator.Current;
              if (predicate(item))
              {
                  current = item;
                  return true;
              }
          }

          // Source exhausted, predicate rejected an element, or not in an iterating state.
          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Iterator behind <c>TakeWhile(Func&lt;TSource, int, bool&gt;)</c>: forwards source elements,
  /// passing each one's zero-based position to the predicate, until the predicate first rejects one.
  /// </summary>
  private sealed class TakeWhileWithIndexAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> source;
      private readonly Func<TSource, int, bool> predicate;
      private IAsyncEnumerator<TSource> enumerator;

      // Position of the element most recently read; reset to -1 when enumeration starts.
      private int index;

      public TakeWhileWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
      {
          Debug.Assert(predicate != null);
          Debug.Assert(source != null);

          this.source = source;
          this.predicate = predicate;
      }

      /// <summary>Creates a fresh iterator over the same source and predicate.</summary>
      public override AsyncIterator<TSource> Clone() => new TakeWhileWithIndexAsyncIterator<TSource>(source, predicate);

      /// <summary>Disposes the source enumerator (if one is open), then the base iterator.</summary>
      public override async Task DisposeAsync()
      {
          var inner = enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async Task<bool> MoveNextCore()
      {
          if (state == AsyncIteratorState.Allocated)
          {
              enumerator = source.GetAsyncEnumerator();
              index = -1;
              state = AsyncIteratorState.Iterating;
          }

          if (state == AsyncIteratorState.Iterating && await enumerator.MoveNextAsync().ConfigureAwait(false))
          {
              var item = enumerator.Current;

              // Overflow past int.MaxValue throws OverflowException instead of wrapping.
              index = checked(index + 1);

              if (predicate(item, index))
              {
                  current = item;
                  return true;
              }
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Iterator behind <c>TakeWhile(Func&lt;TSource, Task&lt;bool&gt;&gt;)</c>: forwards source elements
  /// until the awaited predicate first returns false, at which point the source is disposed.
  /// </summary>
  private sealed class TakeWhileAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> source;
      private readonly Func<TSource, Task<bool>> predicate;
      private IAsyncEnumerator<TSource> enumerator;

      public TakeWhileAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
      {
          Debug.Assert(predicate != null);
          Debug.Assert(source != null);

          this.source = source;
          this.predicate = predicate;
      }

      /// <summary>Creates a fresh iterator over the same source and predicate.</summary>
      public override AsyncIterator<TSource> Clone() => new TakeWhileAsyncIteratorWithTask<TSource>(source, predicate);

      /// <summary>Disposes the source enumerator (if one is open), then the base iterator.</summary>
      public override async Task DisposeAsync()
      {
          var inner = enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async Task<bool> MoveNextCore()
      {
          if (state == AsyncIteratorState.Allocated)
          {
              enumerator = source.GetAsyncEnumerator();
              state = AsyncIteratorState.Iterating;
          }

          if (state == AsyncIteratorState.Iterating && await enumerator.MoveNextAsync().ConfigureAwait(false))
          {
              var item = enumerator.Current;
              if (await predicate(item).ConfigureAwait(false))
              {
                  current = item;
                  return true;
              }
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Iterator behind <c>TakeWhile(Func&lt;TSource, int, Task&lt;bool&gt;&gt;)</c>: forwards source elements,
  /// passing each one's zero-based position to the awaited predicate, until it first returns false.
  /// </summary>
  private sealed class TakeWhileWithIndexAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> source;
      private readonly Func<TSource, int, Task<bool>> predicate;
      private IAsyncEnumerator<TSource> enumerator;

      // Position of the element most recently read; reset to -1 when enumeration starts.
      private int index;

      public TakeWhileWithIndexAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, int, Task<bool>> predicate)
      {
          Debug.Assert(predicate != null);
          Debug.Assert(source != null);

          this.source = source;
          this.predicate = predicate;
      }

      /// <summary>Creates a fresh iterator over the same source and predicate.</summary>
      public override AsyncIterator<TSource> Clone() => new TakeWhileWithIndexAsyncIteratorWithTask<TSource>(source, predicate);

      /// <summary>Disposes the source enumerator (if one is open), then the base iterator.</summary>
      public override async Task DisposeAsync()
      {
          var inner = enumerator;
          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async Task<bool> MoveNextCore()
      {
          if (state == AsyncIteratorState.Allocated)
          {
              enumerator = source.GetAsyncEnumerator();
              index = -1;
              state = AsyncIteratorState.Iterating;
          }

          if (state == AsyncIteratorState.Iterating && await enumerator.MoveNextAsync().ConfigureAwait(false))
          {
              var item = enumerator.Current;

              // Overflow past int.MaxValue throws OverflowException instead of wrapping.
              index = checked(index + 1);

              if (await predicate(item, index).ConfigureAwait(false))
              {
                  current = item;
                  return true;
              }
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  404. }
  405. }