AsyncEnumerablePartition.cs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. // Copied from https://github.com/dotnet/corefx/blob/5f1dd8298e4355b63bb760d88d437a91b3ca808c/src/System.Linq/src/System/Linq/Partition.cs
  5. using System.Collections.Generic;
  6. using System.Diagnostics;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. namespace System.Linq
  10. {
  11. /// <summary>
  12. /// An iterator that yields the items of part of an <see cref="IAsyncEnumerable{TSource}"/>.
  13. /// </summary>
  14. /// <typeparam name="TSource">The type of the source enumerable.</typeparam>
  15. internal sealed class AsyncEnumerablePartition<TSource> : AsyncIterator<TSource>, IAsyncPartition<TSource>
  16. {
  17. private readonly IAsyncEnumerable<TSource> _source;
  18. private readonly int _minIndexInclusive;
  19. private readonly int _maxIndexInclusive; // -1 if we want everything past _minIndexInclusive.
  20. // If this is -1, it's impossible to set a limit on the count.
  21. private IAsyncEnumerator<TSource> _enumerator;
  22. internal AsyncEnumerablePartition(IAsyncEnumerable<TSource> source, int minIndexInclusive, int maxIndexInclusive)
  23. {
  24. Debug.Assert(source != null);
  25. Debug.Assert(!(source is IList<TSource>), $"The caller needs to check for {nameof(IList<TSource>)}.");
  26. Debug.Assert(minIndexInclusive >= 0);
  27. Debug.Assert(maxIndexInclusive >= -1);
  28. // Note that although maxIndexInclusive can't grow, it can still be int.MaxValue.
  29. // We support partitioning enumerables with > 2B elements. For example, e.Skip(1).Take(int.MaxValue) should work.
  30. // But if it is int.MaxValue, then minIndexInclusive must != 0. Otherwise, our count may overflow.
  31. Debug.Assert(maxIndexInclusive == -1 || (maxIndexInclusive - minIndexInclusive < int.MaxValue), $"{nameof(Limit)} will overflow!");
  32. Debug.Assert(maxIndexInclusive == -1 || minIndexInclusive <= maxIndexInclusive);
  33. _source = source;
  34. _minIndexInclusive = minIndexInclusive;
  35. _maxIndexInclusive = maxIndexInclusive;
  36. }
  37. // If this is true (e.g. at least one Take call was made), then we have an upper bound
  38. // on how many elements we can have.
  39. private bool HasLimit => _maxIndexInclusive != -1;
  40. private int Limit => (_maxIndexInclusive + 1) - _minIndexInclusive; // This is that upper bound.
  41. public override AsyncIterator<TSource> Clone()
  42. {
  43. return new AsyncEnumerablePartition<TSource>(_source, _minIndexInclusive, _maxIndexInclusive);
  44. }
  45. public override async ValueTask DisposeAsync()
  46. {
  47. if (_enumerator != null)
  48. {
  49. await _enumerator.DisposeAsync().ConfigureAwait(false);
  50. _enumerator = null;
  51. }
  52. await base.DisposeAsync().ConfigureAwait(false);
  53. }
  54. public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
  55. {
  56. if (onlyIfCheap)
  57. {
  58. return -1;
  59. }
  60. if (!HasLimit)
  61. {
  62. // If HasLimit is false, we contain everything past _minIndexInclusive.
  63. // Therefore, we have to iterate the whole enumerable.
  64. return Math.Max(await _source.Count(cancellationToken).ConfigureAwait(false) - _minIndexInclusive, 0);
  65. }
  66. var en = _source.GetAsyncEnumerator(cancellationToken);
  67. try
  68. {
  69. // We only want to iterate up to _maxIndexInclusive + 1.
  70. // Past that, we know the enumerable will be able to fit this partition,
  71. // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.
  72. // Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
  73. // so + 1 may result in signed integer overflow. We need to handle this.
  74. // At the same time, however, we are guaranteed that our max count can fit
  75. // in an int because if that is true, then _minIndexInclusive must > 0.
  76. var count = await SkipAndCountAsync((uint)_maxIndexInclusive + 1, en, cancellationToken).ConfigureAwait(false);
  77. Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
  78. return Math.Max((int)count - _minIndexInclusive, 0);
  79. }
  80. finally
  81. {
  82. await en.DisposeAsync().ConfigureAwait(false);
  83. }
  84. }
  85. private bool _hasSkipped;
  86. private int _taken;
  87. protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
  88. {
  89. switch (state)
  90. {
  91. case AsyncIteratorState.Allocated:
  92. _enumerator = _source.GetAsyncEnumerator(cancellationToken);
  93. _hasSkipped = false;
  94. _taken = 0;
  95. state = AsyncIteratorState.Iterating;
  96. goto case AsyncIteratorState.Iterating;
  97. case AsyncIteratorState.Iterating:
  98. if (!_hasSkipped)
  99. {
  100. if (!await SkipBeforeFirstAsync(_enumerator, CancellationToken.None).ConfigureAwait(false))
  101. {
  102. // Reached the end before we finished skipping.
  103. break;
  104. }
  105. _hasSkipped = true;
  106. }
  107. if ((!HasLimit || _taken < Limit) && await _enumerator.MoveNextAsync().ConfigureAwait(false))
  108. {
  109. if (HasLimit)
  110. {
  111. // If we are taking an unknown number of elements, it's important not to increment _state.
  112. // _state - 3 may eventually end up overflowing & we'll hit the Dispose branch even though
  113. // we haven't finished enumerating.
  114. _taken++;
  115. }
  116. current = _enumerator.Current;
  117. return true;
  118. }
  119. break;
  120. }
  121. await DisposeAsync().ConfigureAwait(false);
  122. return false;
  123. }
  124. #if NOTYET
  125. public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
  126. {
  127. return new SelectIPartitionIterator<TSource, TResult>(this, selector);
  128. }
  129. public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, Task<TResult>> selector)
  130. {
  131. return new SelectIPartitionIterator<TSource, TResult>(this, selector);
  132. }
  133. #endif
  134. public IAsyncPartition<TSource> Skip(int count)
  135. {
  136. var minIndex = _minIndexInclusive + count;
  137. if (!HasLimit)
  138. {
  139. if (minIndex < 0)
  140. {
  141. // If we don't know our max count and minIndex can no longer fit in a positive int,
  142. // then we will need to wrap ourselves in another iterator.
  143. // This can happen, for example, during e.Skip(int.MaxValue).Skip(int.MaxValue).
  144. return new AsyncEnumerablePartition<TSource>(this, count, -1);
  145. }
  146. }
  147. else if ((uint)minIndex > (uint)_maxIndexInclusive)
  148. {
  149. // If minIndex overflows and we have an upper bound, we will go down this branch.
  150. // We know our upper bound must be smaller than minIndex, since our upper bound fits in an int.
  151. // This branch should not be taken if we don't have a bound.
  152. return AsyncEnumerable.EmptyAsyncIterator<TSource>.Instance;
  153. }
  154. Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows.");
  155. return new AsyncEnumerablePartition<TSource>(_source, minIndex, _maxIndexInclusive);
  156. }
  157. public IAsyncPartition<TSource> Take(int count)
  158. {
  159. var maxIndex = _minIndexInclusive + count - 1;
  160. if (!HasLimit)
  161. {
  162. if (maxIndex < 0)
  163. {
  164. // If we don't know our max count and maxIndex can no longer fit in a positive int,
  165. // then we will need to wrap ourselves in another iterator.
  166. // Note that although maxIndex may be too large, the difference between it and
  167. // _minIndexInclusive (which is count - 1) must fit in an int.
  168. // Example: e.Skip(50).Take(int.MaxValue).
  169. return new AsyncEnumerablePartition<TSource>(this, 0, count - 1);
  170. }
  171. }
  172. else if ((uint)maxIndex >= (uint)_maxIndexInclusive)
  173. {
  174. // If we don't know our max count, we can't go down this branch.
  175. // It's always possible for us to contain more than count items, as the rest
  176. // of the enumerable past _minIndexInclusive can be arbitrarily long.
  177. return this;
  178. }
  179. Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows.");
  180. return new AsyncEnumerablePartition<TSource>(_source, _minIndexInclusive, maxIndex);
  181. }
  182. public async Task<Maybe<TSource>> TryGetElementAsync(int index, CancellationToken cancellationToken)
  183. {
  184. // If the index is negative or >= our max count, return early.
  185. if (index >= 0 && (!HasLimit || index < Limit))
  186. {
  187. var en = _source.GetAsyncEnumerator(cancellationToken);
  188. try
  189. {
  190. Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow.");
  191. if (await SkipBeforeAsync(_minIndexInclusive + index, en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
  192. {
  193. return new Maybe<TSource>(en.Current);
  194. }
  195. }
  196. finally
  197. {
  198. await en.DisposeAsync().ConfigureAwait(false);
  199. }
  200. }
  201. return new Maybe<TSource>();
  202. }
  203. public async Task<Maybe<TSource>> TryGetFirstAsync(CancellationToken cancellationToken)
  204. {
  205. var en = _source.GetAsyncEnumerator(cancellationToken);
  206. try
  207. {
  208. if (await SkipBeforeFirstAsync(en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
  209. {
  210. return new Maybe<TSource>(en.Current);
  211. }
  212. }
  213. finally
  214. {
  215. await en.DisposeAsync().ConfigureAwait(false);
  216. }
  217. return new Maybe<TSource>();
  218. }
  219. public async Task<Maybe<TSource>> TryGetLastAsync(CancellationToken cancellationToken)
  220. {
  221. var en = _source.GetAsyncEnumerator(cancellationToken);
  222. try
  223. {
  224. if (await SkipBeforeFirstAsync(en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
  225. {
  226. var remaining = Limit - 1; // Max number of items left, not counting the current element.
  227. var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
  228. TSource result;
  229. do
  230. {
  231. remaining--;
  232. result = en.Current;
  233. }
  234. while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));
  235. return new Maybe<TSource>(result);
  236. }
  237. }
  238. finally
  239. {
  240. await en.DisposeAsync().ConfigureAwait(false);
  241. }
  242. return new Maybe<TSource>();
  243. }
  244. public async Task<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
  245. {
  246. var en = _source.GetAsyncEnumerator(cancellationToken);
  247. try
  248. {
  249. if (await SkipBeforeFirstAsync(en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
  250. {
  251. var remaining = Limit - 1; // Max number of items left, not counting the current element.
  252. var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
  253. var maxCapacity = HasLimit ? Limit : int.MaxValue;
  254. var builder = new List<TSource>(maxCapacity);
  255. do
  256. {
  257. remaining--;
  258. builder.Add(en.Current);
  259. }
  260. while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));
  261. return builder.ToArray();
  262. }
  263. }
  264. finally
  265. {
  266. await en.DisposeAsync().ConfigureAwait(false);
  267. }
  268. #if NO_ARRAY_EMPTY
  269. return EmptyArray<TSource>.Value;
  270. #else
  271. return Array.Empty<TSource>();
  272. #endif
  273. }
  274. public async Task<List<TSource>> ToListAsync(CancellationToken cancellationToken)
  275. {
  276. var list = new List<TSource>();
  277. var en = _source.GetAsyncEnumerator(cancellationToken);
  278. try
  279. {
  280. if (await SkipBeforeFirstAsync(en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
  281. {
  282. var remaining = Limit - 1; // Max number of items left, not counting the current element.
  283. var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
  284. do
  285. {
  286. remaining--;
  287. list.Add(en.Current);
  288. }
  289. while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));
  290. }
  291. }
  292. finally
  293. {
  294. await en.DisposeAsync().ConfigureAwait(false);
  295. }
  296. return list;
  297. }
  298. private Task<bool> SkipBeforeFirstAsync(IAsyncEnumerator<TSource> en, CancellationToken cancellationToken) => SkipBeforeAsync(_minIndexInclusive, en, cancellationToken);
  299. private static async Task<bool> SkipBeforeAsync(int index, IAsyncEnumerator<TSource> en, CancellationToken cancellationToken)
  300. {
  301. var n = await SkipAndCountAsync(index, en, cancellationToken).ConfigureAwait(false);
  302. return n == index;
  303. }
  304. private static async Task<int> SkipAndCountAsync(int index, IAsyncEnumerator<TSource> en, CancellationToken cancellationToken)
  305. {
  306. Debug.Assert(index >= 0);
  307. return (int)await SkipAndCountAsync((uint)index, en, cancellationToken).ConfigureAwait(false);
  308. }
  309. private static async Task<uint> SkipAndCountAsync(uint index, IAsyncEnumerator<TSource> en, CancellationToken cancellationToken)
  310. {
  311. Debug.Assert(en != null);
  312. for (uint i = 0; i < index; i++)
  313. {
  314. if (!await en.MoveNextAsync().ConfigureAwait(false))
  315. {
  316. return i;
  317. }
  318. }
  319. return index;
  320. }
  321. }
  322. }