Take.cs 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. namespace System.Linq
  10. {
  11. public static partial class AsyncEnumerable
  12. {
  13. public static IAsyncEnumerable<TSource> Take<TSource>(this IAsyncEnumerable<TSource> source, int count)
  14. {
  15. if (source == null)
  16. throw new ArgumentNullException(nameof(source));
  17. if (count < 0)
  18. throw new ArgumentOutOfRangeException(nameof(count));
  19. return CreateEnumerable(
  20. () =>
  21. {
  22. var e = source.GetEnumerator();
  23. var n = count;
  24. var cts = new CancellationTokenDisposable();
  25. var d = Disposable.Create(cts, e);
  26. return CreateEnumerator(
  27. async ct =>
  28. {
  29. if (n == 0)
  30. return false;
  31. var result = await e.MoveNext(cts.Token)
  32. .ConfigureAwait(false);
  33. --n;
  34. if (n == 0)
  35. e.Dispose();
  36. return result;
  37. },
  38. () => e.Current,
  39. d.Dispose,
  40. e
  41. );
  42. });
  43. }
  44. public static IAsyncEnumerable<TSource> TakeLast<TSource>(this IAsyncEnumerable<TSource> source, int count)
  45. {
  46. if (source == null)
  47. throw new ArgumentNullException(nameof(source));
  48. if (count < 0)
  49. throw new ArgumentOutOfRangeException(nameof(count));
  50. return CreateEnumerable(
  51. () =>
  52. {
  53. var e = source.GetEnumerator();
  54. var cts = new CancellationTokenDisposable();
  55. var d = Disposable.Create(cts, e);
  56. var q = new Queue<TSource>(count);
  57. var done = false;
  58. var current = default(TSource);
  59. var f = default(Func<CancellationToken, Task<bool>>);
  60. f = async ct =>
  61. {
  62. if (!done)
  63. {
  64. if (await e.MoveNext(ct)
  65. .ConfigureAwait(false))
  66. {
  67. if (count > 0)
  68. {
  69. var item = e.Current;
  70. if (q.Count >= count)
  71. q.Dequeue();
  72. q.Enqueue(item);
  73. }
  74. }
  75. else
  76. {
  77. done = true;
  78. e.Dispose();
  79. }
  80. return await f(ct)
  81. .ConfigureAwait(false);
  82. }
  83. if (q.Count > 0)
  84. {
  85. current = q.Dequeue();
  86. return true;
  87. }
  88. return false;
  89. };
  90. return CreateEnumerator(
  91. f,
  92. () => current,
  93. d.Dispose,
  94. e
  95. );
  96. });
  97. }
  98. public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
  99. {
  100. if (source == null)
  101. throw new ArgumentNullException(nameof(source));
  102. if (predicate == null)
  103. throw new ArgumentNullException(nameof(predicate));
  104. return CreateEnumerable(
  105. () =>
  106. {
  107. var e = source.GetEnumerator();
  108. var cts = new CancellationTokenDisposable();
  109. var d = Disposable.Create(cts, e);
  110. return CreateEnumerator(
  111. async ct =>
  112. {
  113. if (await e.MoveNext(cts.Token)
  114. .ConfigureAwait(false))
  115. {
  116. return predicate(e.Current);
  117. }
  118. return false;
  119. },
  120. () => e.Current,
  121. d.Dispose,
  122. e
  123. );
  124. });
  125. }
  126. public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
  127. {
  128. if (source == null)
  129. throw new ArgumentNullException(nameof(source));
  130. if (predicate == null)
  131. throw new ArgumentNullException(nameof(predicate));
  132. return CreateEnumerable(
  133. () =>
  134. {
  135. var e = source.GetEnumerator();
  136. var index = 0;
  137. var cts = new CancellationTokenDisposable();
  138. var d = Disposable.Create(cts, e);
  139. return CreateEnumerator(
  140. async ct =>
  141. {
  142. if (await e.MoveNext(cts.Token)
  143. .ConfigureAwait(false))
  144. {
  145. return predicate(e.Current, checked(index++));
  146. }
  147. return false;
  148. },
  149. () => e.Current,
  150. d.Dispose,
  151. e
  152. );
  153. });
  154. }
  155. }
  156. }