AsyncEnumerablePartition.cs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. // Copied from https://github.com/dotnet/corefx/blob/5f1dd8298e4355b63bb760d88d437a91b3ca808c/src/System.Linq/src/System/Linq/Partition.cs
  5. using System.Collections.Generic;
  6. using System.Diagnostics;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  namespace System.Linq
  {
      /// <summary>
      /// An iterator that yields the items of part of an <see cref="IAsyncEnumerable{TSource}"/>:
      /// the elements at source positions <c>_minIndexInclusive</c> through <c>_maxIndexInclusive</c>
      /// (inclusive), or through the end of the source when the upper bound is -1.
      /// Calling <see cref="Skip(int)"/> or <see cref="Take(int)"/> on a partition produces a new,
      /// narrower partition rather than stacking iterators where possible.
      /// </summary>
      /// <typeparam name="TSource">The type of the source enumerable.</typeparam>
      internal sealed class AsyncEnumerablePartition<TSource> : AsyncIterator<TSource>, IAsyncPartition<TSource>
      {
          private readonly IAsyncEnumerable<TSource> _source;

          // First source index (0-based) that belongs to the partition.
          private readonly int _minIndexInclusive;

          // Last source index that belongs to the partition, or -1 if we want everything past
          // _minIndexInclusive. If this is -1, it's impossible to set a limit on the count.
          private readonly int _maxIndexInclusive;

          // Enumeration state used by MoveNextCore.
          private IAsyncEnumerator<TSource>? _enumerator;
          private bool _hasSkipped;
          private int _taken;

          /// <summary>
          /// Creates a partition over <paramref name="source"/>.
          /// </summary>
          /// <param name="source">The sequence to slice.</param>
          /// <param name="minIndexInclusive">First source index to yield; must be non-negative.</param>
          /// <param name="maxIndexInclusive">Last source index to yield, or -1 for "no upper bound".</param>
          internal AsyncEnumerablePartition(IAsyncEnumerable<TSource> source, int minIndexInclusive, int maxIndexInclusive)
          {
              Debug.Assert(!(source is IList<TSource>), $"The caller needs to check for {nameof(IList<TSource>)}.");
              Debug.Assert(minIndexInclusive >= 0);
              Debug.Assert(maxIndexInclusive >= -1);

              // Note that although maxIndexInclusive can't grow, it can still be int.MaxValue.
              // We support partitioning enumerables with > 2B elements. For example, e.Skip(1).Take(int.MaxValue) should work.
              // But if it is int.MaxValue, then minIndexInclusive must != 0. Otherwise, our count may overflow.
              Debug.Assert(maxIndexInclusive == -1 || (maxIndexInclusive - minIndexInclusive < int.MaxValue), $"{nameof(Limit)} will overflow!");
              Debug.Assert(maxIndexInclusive == -1 || minIndexInclusive <= maxIndexInclusive);

              _source = source;
              _minIndexInclusive = minIndexInclusive;
              _maxIndexInclusive = maxIndexInclusive;
          }

          // If this is true (e.g. at least one Take call was made), then we have an upper bound
          // on how many elements we can have.
          private bool HasLimit => _maxIndexInclusive != -1;

          // The upper bound on the element count (only a real bound when HasLimit is true).
          // _maxIndexInclusive may be int.MaxValue (in which case _minIndexInclusive > 0), so the
          // intermediate "+ 1" can wrap around; the final difference still fits in an int.
          private int Limit => unchecked((_maxIndexInclusive + 1) - _minIndexInclusive);

          /// <summary>Creates a fresh, unenumerated copy of this partition.</summary>
          public override AsyncIteratorBase<TSource> Clone()
          {
              return new AsyncEnumerablePartition<TSource>(_source, _minIndexInclusive, _maxIndexInclusive);
          }

          /// <summary>
          /// Disposes the source enumerator created by <see cref="MoveNextCore"/> (if any),
          /// then disposes the base iterator.
          /// </summary>
          public override async ValueTask DisposeAsync()
          {
              if (_enumerator != null)
              {
                  await _enumerator.DisposeAsync().ConfigureAwait(false);
                  _enumerator = null;
              }

              await base.DisposeAsync().ConfigureAwait(false);
          }

          /// <summary>
          /// Returns the number of elements in the partition.
          /// </summary>
          /// <param name="onlyIfCheap">When true, returns -1 immediately instead of enumerating the source.</param>
          /// <param name="cancellationToken">Token passed to the source enumerator.</param>
          /// <returns>The element count, or -1 when <paramref name="onlyIfCheap"/> is true.</returns>
          public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
          {
              if (onlyIfCheap)
              {
                  return new ValueTask<int>(-1);
              }

              return Core();

              async ValueTask<int> Core()
              {
                  if (!HasLimit)
                  {
                      // If HasLimit is false, we contain everything past _minIndexInclusive.
                      // Therefore, we have to iterate the whole enumerable.
                      return Math.Max(await _source.CountAsync(cancellationToken).ConfigureAwait(false) - _minIndexInclusive, 0);
                  }

                  var en = _source.GetAsyncEnumerator(cancellationToken);

                  try
                  {
                      // We only want to iterate up to _maxIndexInclusive + 1.
                      // Past that, we know the enumerable will be able to fit this partition,
                      // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.
                      // _maxIndexInclusive may be int.MaxValue, so "+ 1" is done in uint where it cannot overflow.
                      // If the count is then (uint)int.MaxValue + 1, _minIndexInclusive must be > 0, and the
                      // wrapped int subtraction below yields the correct (int-sized) result.
                      var count = await SkipAndCountAsync((uint)_maxIndexInclusive + 1, en).ConfigureAwait(false);
                      Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
                      return Math.Max(unchecked((int)count - _minIndexInclusive), 0);
                  }
                  finally
                  {
                      await en.DisposeAsync().ConfigureAwait(false);
                  }
              }
          }

          /// <summary>
          /// Advances the iterator: on first use, opens the source enumerator and skips the leading
          /// elements, then yields elements until the source ends or the limit is reached.
          /// </summary>
          protected override async ValueTask<bool> MoveNextCore()
          {
              switch (_state)
              {
                  case AsyncIteratorState.Allocated:
                      _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                      _hasSkipped = false;
                      _taken = 0;
                      _state = AsyncIteratorState.Iterating;
                      goto case AsyncIteratorState.Iterating;

                  case AsyncIteratorState.Iterating:
                      // The leading elements are skipped lazily, on the first advance.
                      if (!_hasSkipped)
                      {
                          if (!await SkipBeforeFirstAsync(_enumerator!).ConfigureAwait(false))
                          {
                              // Reached the end before we finished skipping.
                              break;
                          }

                          _hasSkipped = true;
                      }

                      if ((!HasLimit || _taken < Limit) && await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                      {
                          if (HasLimit)
                          {
                              // _taken is only consulted when there is a limit. Counting in the unbounded
                              // case would be wasted work and could overflow for sources with more than
                              // int.MaxValue elements.
                              _taken++;
                          }

                          _current = _enumerator.Current;
                          return true;
                      }

                      break;
              }

              // Exhausted (or not in an iterating state): release the source enumerator.
              await DisposeAsync().ConfigureAwait(false);
              return false;
          }

  #if NOTYET
          public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
          {
              return new SelectIPartitionIterator<TSource, TResult>(this, selector);
          }

          public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, ValueTask<TResult>> selector)
          {
              return new SelectIPartitionIterator<TSource, TResult>(this, selector);
          }
  #endif

          /// <summary>
          /// Returns a partition that additionally skips <paramref name="count"/> elements.
          /// </summary>
          public IAsyncPartition<TSource> Skip(int count)
          {
              // May wrap around; the overflow is detected below.
              var minIndex = unchecked(_minIndexInclusive + count);

              if (!HasLimit)
              {
                  if (minIndex < 0)
                  {
                      // If we don't know our max count and minIndex can no longer fit in a positive int,
                      // then we will need to wrap ourselves in another iterator.
                      // This can happen, for example, during e.Skip(int.MaxValue).Skip(int.MaxValue).
                      return new AsyncEnumerablePartition<TSource>(this, count, -1);
                  }
              }
              else if ((uint)minIndex > (uint)_maxIndexInclusive)
              {
                  // If minIndex overflows and we have an upper bound, we will go down this branch.
                  // We know our upper bound must be smaller than minIndex, since our upper bound fits in an int.
                  // This branch should not be taken if we don't have a bound.
                  return AsyncEnumerable.EmptyAsyncIterator<TSource>.Instance;
              }

              Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows.");
              return new AsyncEnumerablePartition<TSource>(_source, minIndex, _maxIndexInclusive);
          }

          /// <summary>
          /// Returns a partition that yields at most <paramref name="count"/> of this partition's elements.
          /// </summary>
          public IAsyncPartition<TSource> Take(int count)
          {
              // May wrap around; the overflow is detected below.
              var maxIndex = unchecked(_minIndexInclusive + count - 1);

              if (!HasLimit)
              {
                  if (maxIndex < 0)
                  {
                      // If we don't know our max count and maxIndex can no longer fit in a positive int,
                      // then we will need to wrap ourselves in another iterator.
                      // Note that although maxIndex may be too large, the difference between it and
                      // _minIndexInclusive (which is count - 1) must fit in an int.
                      // Example: e.Skip(50).Take(int.MaxValue).
                      return new AsyncEnumerablePartition<TSource>(this, 0, count - 1);
                  }
              }
              else if ((uint)maxIndex >= (uint)_maxIndexInclusive)
              {
                  // The requested bound is not tighter than the existing one, so nothing changes.
                  // (Without a bound we can't go down this branch: the rest of the enumerable past
                  // _minIndexInclusive can be arbitrarily long.)
                  return this;
              }

              Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows.");
              return new AsyncEnumerablePartition<TSource>(_source, _minIndexInclusive, maxIndex);
          }

          /// <summary>
          /// Tries to get the element at <paramref name="index"/> within the partition.
          /// </summary>
          /// <returns>The element, or an empty <c>Maybe</c> if the index is out of range.</returns>
          public async ValueTask<Maybe<TSource>> TryGetElementAtAsync(int index, CancellationToken cancellationToken)
          {
              // If the index is negative or >= our max count, return early.
              if (index >= 0 && (!HasLimit || index < Limit))
              {
                  var en = _source.GetAsyncEnumerator(cancellationToken);

                  try
                  {
                      // Both operands are non-negative ints, so their sum (at most 2 * int.MaxValue)
                      // always fits in a uint. Doing the addition in int would overflow for an
                      // unbounded partition whose _minIndexInclusive + index exceeds int.MaxValue.
                      var sourceIndex = (uint)_minIndexInclusive + (uint)index;

                      if (await SkipAndCountAsync(sourceIndex, en).ConfigureAwait(false) == sourceIndex &&
                          await en.MoveNextAsync().ConfigureAwait(false))
                      {
                          return new Maybe<TSource>(en.Current);
                      }
                  }
                  finally
                  {
                      await en.DisposeAsync().ConfigureAwait(false);
                  }
              }

              return new Maybe<TSource>();
          }

          /// <summary>
          /// Tries to get the first element of the partition.
          /// </summary>
          public async ValueTask<Maybe<TSource>> TryGetFirstAsync(CancellationToken cancellationToken)
          {
              var en = _source.GetAsyncEnumerator(cancellationToken);

              try
              {
                  if (await SkipBeforeFirstAsync(en).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                  {
                      return new Maybe<TSource>(en.Current);
                  }
              }
              finally
              {
                  await en.DisposeAsync().ConfigureAwait(false);
              }

              return new Maybe<TSource>();
          }

          /// <summary>
          /// Tries to get the last element of the partition, stopping at the limit when there is one.
          /// </summary>
          public async ValueTask<Maybe<TSource>> TryGetLastAsync(CancellationToken cancellationToken)
          {
              var en = _source.GetAsyncEnumerator(cancellationToken);

              try
              {
                  if (await SkipBeforeFirstAsync(en).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                  {
                      var remaining = Limit - 1; // Max number of items left, not counting the current element.
                      var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
                      TSource result;

                      do
                      {
                          remaining--;
                          result = en.Current;
                      }
                      while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));

                      return new Maybe<TSource>(result);
                  }
              }
              finally
              {
                  await en.DisposeAsync().ConfigureAwait(false);
              }

              return new Maybe<TSource>();
          }

          /// <summary>
          /// Materializes the partition into an array.
          /// </summary>
          public async ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
          {
              var en = _source.GetAsyncEnumerator(cancellationToken);

              try
              {
                  if (await SkipBeforeFirstAsync(en).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                  {
                      var remaining = Limit - 1; // Max number of items left, not counting the current element.
                      var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

                      // Do not presize with Limit: it is only an upper bound, not the actual count.
                      // For example, Skip(1).Take(int.MaxValue) has Limit == int.MaxValue, and presizing
                      // would try to allocate a ~2-billion-element buffer regardless of the source's length.
                      // REVIEW: If this ends up in corefx, the code below can use LargeArrayBuilder<T>.
                      var builder = new List<TSource>();

                      do
                      {
                          remaining--;
                          builder.Add(en.Current);
                      }
                      while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));

                      return builder.ToArray();
                  }
              }
              finally
              {
                  await en.DisposeAsync().ConfigureAwait(false);
              }

  #if NO_ARRAY_EMPTY
              return EmptyArray<TSource>.Value;
  #else
              return Array.Empty<TSource>();
  #endif
          }

          /// <summary>
          /// Materializes the partition into a list.
          /// </summary>
          public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
          {
              var list = new List<TSource>();
              var en = _source.GetAsyncEnumerator(cancellationToken);

              try
              {
                  if (await SkipBeforeFirstAsync(en).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                  {
                      var remaining = Limit - 1; // Max number of items left, not counting the current element.
                      var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

                      do
                      {
                          remaining--;
                          list.Add(en.Current);
                      }
                      while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));
                  }
              }
              finally
              {
                  await en.DisposeAsync().ConfigureAwait(false);
              }

              return list;
          }

          // Advances en past the elements before _minIndexInclusive; false if the source ran out first.
          private ValueTask<bool> SkipBeforeFirstAsync(IAsyncEnumerator<TSource> en) => SkipBeforeAsync(_minIndexInclusive, en);

          // Advances en by index elements; true only if all index elements were available.
          private static async ValueTask<bool> SkipBeforeAsync(int index, IAsyncEnumerator<TSource> en)
          {
              var n = await SkipAndCountAsync(index, en).ConfigureAwait(false);
              return n == index;
          }

          // Advances en by up to index elements and returns how many were actually consumed.
          private static async ValueTask<int> SkipAndCountAsync(int index, IAsyncEnumerator<TSource> en)
          {
              Debug.Assert(index >= 0);
              return (int)await SkipAndCountAsync((uint)index, en).ConfigureAwait(false);
          }

          // uint variant: lets callers skip past int.MaxValue elements without overflowing.
          private static async ValueTask<uint> SkipAndCountAsync(uint index, IAsyncEnumerator<TSource> en)
          {
              for (uint i = 0; i < index; i++)
              {
                  if (!await en.MoveNextAsync().ConfigureAwait(false))
                  {
                      return i;
                  }
              }

              return index;
          }
      }
  }