Take.cs 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  namespace System.Linq
  {
      public static partial class AsyncEnumerable
      {
          /// <summary>
          /// Returns a sequence that yields at most the first <paramref name="count"/> elements of
          /// <paramref name="source"/>.
          /// </summary>
          /// <typeparam name="TSource">The element type.</typeparam>
          /// <param name="source">The sequence to take elements from.</param>
          /// <param name="count">Maximum number of elements to yield; zero yields an empty sequence.</param>
          /// <returns>A sequence over the leading elements of <paramref name="source"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
          public static IAsyncEnumerable<TSource> Take<TSource>(this IAsyncEnumerable<TSource> source, int count)
          {
              if (source == null)
                  throw new ArgumentNullException(nameof(source));
              if (count < 0)
                  throw new ArgumentOutOfRangeException(nameof(count));

              return CreateEnumerable(
                  () =>
                  {
                      var e = source.GetEnumerator();
                      var remaining = count;
                      var cts = new CancellationTokenDisposable();
                      var d = Disposable.Create(cts, e);
                      var current = default(TSource);

                      return CreateEnumerator(
                          async ct =>
                          {
                              // Quota used up, or the source already ran dry: never touch the source again.
                              if (remaining == 0)
                                  return false;

                              // NOTE(review): cts.Token (not ct) is passed so that disposing this enumerator
                              // cancels an in-flight source MoveNext; presumably CreateEnumerator links ct
                              // to d.Dispose — confirm.
                              var hasItem = await e.MoveNext(cts.Token)
                                                   .ConfigureAwait(false);

                              if (!hasItem)
                              {
                                  // The source was shorter than count. Zero the quota so later MoveNext calls
                                  // short-circuit instead of calling MoveNext on the exhausted, disposed source.
                                  remaining = 0;
                                  e.Dispose();
                                  return false;
                              }

                              current = e.Current;

                              // Release the source as soon as the last requested element has been read,
                              // rather than waiting for the consumer to dispose this enumerator.
                              if (--remaining == 0)
                                  e.Dispose();

                              return true;
                          },
                          () => current,
                          d.Dispose,
                          e
                      );
                  });
          }

          /// <summary>
          /// Returns a sequence that yields the last <paramref name="count"/> elements of
          /// <paramref name="source"/>, or all of them if the source has fewer.
          /// </summary>
          /// <remarks>
          /// The first MoveNext call reads the source to its end before anything is yielded; only the most
          /// recent <paramref name="count"/> elements are buffered while doing so.
          /// </remarks>
          /// <typeparam name="TSource">The element type.</typeparam>
          /// <param name="source">The sequence to take elements from.</param>
          /// <param name="count">Maximum number of trailing elements to yield; zero yields an empty sequence.</param>
          /// <returns>A sequence over the trailing elements of <paramref name="source"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
          public static IAsyncEnumerable<TSource> TakeLast<TSource>(this IAsyncEnumerable<TSource> source, int count)
          {
              if (source == null)
                  throw new ArgumentNullException(nameof(source));
              if (count < 0)
                  throw new ArgumentOutOfRangeException(nameof(count));

              return CreateEnumerable(
                  () =>
                  {
                      var e = source.GetEnumerator();
                      var cts = new CancellationTokenDisposable();
                      var d = Disposable.Create(cts, e);

                      // Deliberately not pre-sized: new Queue<TSource>(count) allocates count slots up
                      // front, which fails for large counts such as int.MaxValue even on a short source.
                      var window = new Queue<TSource>();
                      var drained = false;
                      var current = default(TSource);

                      return CreateEnumerator(
                          async ct =>
                          {
                              // Drain the source with a loop rather than recursion so stack depth does not
                              // grow with source length when MoveNext completes synchronously.
                              while (!drained)
                              {
                                  // NOTE(review): uses the caller's ct, whereas the sibling operators pass cts.Token.
                                  if (await e.MoveNext(ct)
                                             .ConfigureAwait(false))
                                  {
                                      if (count > 0)
                                      {
                                          // Sliding window: keep only the most recent count elements.
                                          if (window.Count >= count)
                                              window.Dequeue();
                                          window.Enqueue(e.Current);
                                      }
                                  }
                                  else
                                  {
                                      drained = true;
                                      e.Dispose();
                                  }
                              }

                              if (window.Count > 0)
                              {
                                  current = window.Dequeue();
                                  return true;
                              }

                              return false;
                          },
                          () => current,
                          d.Dispose,
                          e
                      );
                  });
          }

          /// <summary>
          /// Returns a sequence that yields elements of <paramref name="source"/> while
          /// <paramref name="predicate"/> returns true, and ends at the first element for which it returns false.
          /// </summary>
          /// <typeparam name="TSource">The element type.</typeparam>
          /// <param name="source">The sequence to take elements from.</param>
          /// <param name="predicate">Test applied to each element.</param>
          /// <returns>A sequence over the leading run of elements satisfying <paramref name="predicate"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
          public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
          {
              if (source == null)
                  throw new ArgumentNullException(nameof(source));
              if (predicate == null)
                  throw new ArgumentNullException(nameof(predicate));

              return CreateEnumerable(
                  () =>
                  {
                      var e = source.GetEnumerator();
                      var cts = new CancellationTokenDisposable();
                      var d = Disposable.Create(cts, e);
                      var stopped = false;

                      return CreateEnumerator(
                          async ct =>
                          {
                              // Once the predicate has failed (or the source ended) the sequence is over.
                              // Without this latch a later MoveNext would pull further elements and could
                              // yield ones that satisfy the predicate again.
                              if (stopped)
                                  return false;

                              if (await e.MoveNext(cts.Token)
                                         .ConfigureAwait(false)
                                  && predicate(e.Current))
                              {
                                  return true;
                              }

                              stopped = true;
                              return false;
                          },
                          () => e.Current,
                          d.Dispose,
                          e
                      );
                  });
          }

          /// <summary>
          /// Returns a sequence that yields elements of <paramref name="source"/> while
          /// <paramref name="predicate"/> returns true, and ends at the first element for which it returns false.
          /// The predicate also receives the zero-based index of each element.
          /// </summary>
          /// <typeparam name="TSource">The element type.</typeparam>
          /// <param name="source">The sequence to take elements from.</param>
          /// <param name="predicate">Test applied to each element and its index.</param>
          /// <returns>A sequence over the leading run of elements satisfying <paramref name="predicate"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
          /// <exception cref="OverflowException">The index exceeds <see cref="int.MaxValue"/>.</exception>
          public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
          {
              if (source == null)
                  throw new ArgumentNullException(nameof(source));
              if (predicate == null)
                  throw new ArgumentNullException(nameof(predicate));

              return CreateEnumerable(
                  () =>
                  {
                      var e = source.GetEnumerator();
                      var index = 0;
                      var cts = new CancellationTokenDisposable();
                      var d = Disposable.Create(cts, e);
                      var stopped = false;

                      return CreateEnumerator(
                          async ct =>
                          {
                              // Latch the end of the sequence; see the non-indexed overload for rationale.
                              if (stopped)
                                  return false;

                              // The index only advances when an element was actually read (&& short-circuits).
                              if (await e.MoveNext(cts.Token)
                                         .ConfigureAwait(false)
                                  && predicate(e.Current, checked(index++)))
                              {
                                  return true;
                              }

                              stopped = true;
                              return false;
                          },
                          () => e.Current,
                          d.Dispose,
                          e
                      );
                  });
          }
      }
  }