// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

// Copied from https://github.com/dotnet/corefx/blob/5f1dd8298e4355b63bb760d88d437a91b3ca808c/src/System.Linq/src/System/Linq/Partition.cs

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    /// <summary>
    /// An iterator that yields the items of part of an <see cref="IAsyncEnumerable{T}"/>:
    /// the elements whose zero-based indexes fall in [_minIndexInclusive, _maxIndexInclusive],
    /// or everything from _minIndexInclusive onwards when _maxIndexInclusive is -1.
    /// </summary>
    /// <typeparam name="TSource">The type of the source enumerable.</typeparam>
    internal sealed class AsyncEnumerablePartition<TSource> : AsyncIterator<TSource>, IAsyncPartition<TSource>
    {
        private readonly IAsyncEnumerable<TSource> _source;
        private readonly int _minIndexInclusive;

        // -1 if we want everything past _minIndexInclusive.
        // If this is -1, it's impossible to set a limit on the count.
        private readonly int _maxIndexInclusive;

        // Source enumerator used by MoveNextCore; null until iteration starts and after disposal.
        private IAsyncEnumerator<TSource> _enumerator;

        // Iteration state for MoveNextCore: whether the leading _minIndexInclusive elements have
        // already been skipped, and how many elements have been yielded so far (only tracked when
        // there is a limit).
        private bool _hasSkipped;
        private int _taken;

        /// <summary>
        /// Creates a partition of <paramref name="source"/>.
        /// </summary>
        /// <param name="source">The sequence to partition; must not be an <see cref="IList{T}"/>.</param>
        /// <param name="minIndexInclusive">Index of the first element to yield (&gt;= 0).</param>
        /// <param name="maxIndexInclusive">Index of the last element to yield, or -1 for no upper bound.</param>
        internal AsyncEnumerablePartition(IAsyncEnumerable<TSource> source, int minIndexInclusive, int maxIndexInclusive)
        {
            Debug.Assert(source != null);
            Debug.Assert(!(source is IList<TSource>), $"The caller needs to check for {nameof(IList<TSource>)}.");
            Debug.Assert(minIndexInclusive >= 0);
            Debug.Assert(maxIndexInclusive >= -1);

            // Note that although maxIndexInclusive can't grow, it can still be int.MaxValue.
            // We support partitioning enumerables with > 2B elements. For example, e.Skip(1).Take(int.MaxValue) should work.
            // But if it is int.MaxValue, then minIndexInclusive must != 0. Otherwise, our count may overflow.
            Debug.Assert(maxIndexInclusive == -1 || (maxIndexInclusive - minIndexInclusive < int.MaxValue), $"{nameof(Limit)} will overflow!");
            Debug.Assert(maxIndexInclusive == -1 || minIndexInclusive <= maxIndexInclusive);

            _source = source;
            _minIndexInclusive = minIndexInclusive;
            _maxIndexInclusive = maxIndexInclusive;
        }

        // If this is true (e.g. at least one Take call was made), then we have an upper bound
        // on how many elements we can have.
        private bool HasLimit => _maxIndexInclusive != -1;

        // This is that upper bound. When _maxIndexInclusive is int.MaxValue the "+ 1" wraps, but
        // because _minIndexInclusive > 0 in that case the final subtraction wraps back to the
        // correct positive value; unchecked keeps this true even in checked builds.
        private int Limit => unchecked((_maxIndexInclusive + 1) - _minIndexInclusive);

        /// <summary>Creates a fresh, unstarted iterator over the same partition.</summary>
        public override AsyncIterator<TSource> Clone()
        {
            return new AsyncEnumerablePartition<TSource>(_source, _minIndexInclusive, _maxIndexInclusive);
        }

        /// <summary>Disposes the source enumerator (if any) and then the base iterator.</summary>
        public override async ValueTask DisposeAsync()
        {
            if (_enumerator != null)
            {
                await _enumerator.DisposeAsync().ConfigureAwait(false);
                _enumerator = null;
            }

            await base.DisposeAsync().ConfigureAwait(false);
        }

        /// <summary>
        /// Counts the elements in this partition.
        /// </summary>
        /// <param name="onlyIfCheap">When true, returns -1 instead of enumerating the source.</param>
        /// <param name="cancellationToken">Token flowed to the source enumeration.</param>
        /// <returns>The element count, or -1 when <paramref name="onlyIfCheap"/> is true.</returns>
        public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
        {
            if (onlyIfCheap)
            {
                return -1;
            }

            if (!HasLimit)
            {
                // If HasLimit is false, we contain everything past _minIndexInclusive.
                // Therefore, we have to iterate the whole enumerable.
                return Math.Max(await _source.Count(cancellationToken).ConfigureAwait(false) - _minIndexInclusive, 0);
            }

            var en = _source.GetAsyncEnumerator(cancellationToken);

            try
            {
                // We only want to iterate up to _maxIndexInclusive + 1.
                // Past that, we know the enumerable will be able to fit this partition,
                // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.

                // Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
                // so + 1 may result in signed integer overflow. We need to handle this.
                // At the same time, however, we are guaranteed that our max count can fit
                // in an int because if that is true, then _minIndexInclusive must > 0.

                var count = await SkipAndCountAsync((uint)_maxIndexInclusive + 1, en, cancellationToken).ConfigureAwait(false);
                Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");

                // (int)count may wrap to int.MinValue when count == 2^31; subtracting the
                // (positive) _minIndexInclusive then wraps back to the correct result.
                return Math.Max(unchecked((int)count - _minIndexInclusive), 0);
            }
            finally
            {
                await en.DisposeAsync().ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Advances the iterator: on first call skips the leading _minIndexInclusive elements,
        /// then yields elements until the source ends or Limit elements have been taken.
        /// </summary>
        protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
        {
            switch (state)
            {
                case AsyncIteratorState.Allocated:
                    _enumerator = _source.GetAsyncEnumerator(cancellationToken);
                    _hasSkipped = false;
                    _taken = 0;

                    state = AsyncIteratorState.Iterating;
                    goto case AsyncIteratorState.Iterating;

                case AsyncIteratorState.Iterating:
                    if (!_hasSkipped)
                    {
                        // Flow the caller's token so a long skip can be cancelled.
                        if (!await SkipBeforeFirstAsync(_enumerator, cancellationToken).ConfigureAwait(false))
                        {
                            // Reached the end before we finished skipping.
                            break;
                        }

                        _hasSkipped = true;
                    }

                    if ((!HasLimit || _taken < Limit) && await _enumerator.MoveNextAsync().ConfigureAwait(false))
                    {
                        if (HasLimit)
                        {
                            // Only count when there is a limit to compare against. With no limit the
                            // sequence may exceed int.MaxValue elements, and counting would overflow
                            // _taken for no benefit.
                            _taken++;
                        }

                        current = _enumerator.Current;
                        return true;
                    }

                    break;
            }

            await DisposeAsync().ConfigureAwait(false);
            return false;
        }

#if NOTYET
        public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
        {
            return new SelectIPartitionIterator<TSource, TResult>(this, selector);
        }

        public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, Task<TResult>> selector)
        {
            return new SelectIPartitionIterator<TSource, TResult>(this, selector);
        }
#endif

        /// <summary>
        /// Returns a partition that skips a further <paramref name="count"/> elements of this one.
        /// </summary>
        /// <remarks>NOTE(review): assumes callers only pass count &gt; 0 — confirm at call sites.</remarks>
        public IAsyncPartition<TSource> Skip(int count)
        {
            // May deliberately wrap; the wrapped (negative) value is detected below.
            var minIndex = unchecked(_minIndexInclusive + count);

            if (!HasLimit)
            {
                if (minIndex < 0)
                {
                    // If we don't know our max count and minIndex can no longer fit in a positive int,
                    // then we will need to wrap ourselves in another iterator.
                    // This can happen, for example, during e.Skip(int.MaxValue).Skip(int.MaxValue).
                    return new AsyncEnumerablePartition<TSource>(this, count, -1);
                }
            }
            else if ((uint)minIndex > (uint)_maxIndexInclusive)
            {
                // If minIndex overflows and we have an upper bound, we will go down this branch.
                // We know our upper bound must be smaller than minIndex, since our upper bound fits in an int.
                // This branch should not be taken if we don't have a bound.
                return AsyncEnumerable.EmptyAsyncIterator<TSource>.Instance;
            }

            Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows.");
            return new AsyncEnumerablePartition<TSource>(_source, minIndex, _maxIndexInclusive);
        }

        /// <summary>
        /// Returns a partition containing at most the first <paramref name="count"/> elements of this one.
        /// </summary>
        /// <remarks>NOTE(review): assumes callers only pass count &gt; 0 — confirm at call sites.</remarks>
        public IAsyncPartition<TSource> Take(int count)
        {
            // May deliberately wrap; the wrapped (negative) value is detected below.
            var maxIndex = unchecked(_minIndexInclusive + count - 1);

            if (!HasLimit)
            {
                if (maxIndex < 0)
                {
                    // If we don't know our max count and maxIndex can no longer fit in a positive int,
                    // then we will need to wrap ourselves in another iterator.
                    // Note that although maxIndex may be too large, the difference between it and
                    // _minIndexInclusive (which is count - 1) must fit in an int.
                    // Example: e.Skip(50).Take(int.MaxValue).
                    return new AsyncEnumerablePartition<TSource>(this, 0, count - 1);
                }
            }
            else if ((uint)maxIndex >= (uint)_maxIndexInclusive)
            {
                // If we don't know our max count, we can't go down this branch.
                // It's always possible for us to contain more than count items, as the rest
                // of the enumerable past _minIndexInclusive can be arbitrarily long.
                return this;
            }

            Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows.");
            return new AsyncEnumerablePartition<TSource>(_source, _minIndexInclusive, maxIndex);
        }

        /// <summary>
        /// Gets the element at <paramref name="index"/> within this partition, if it exists.
        /// </summary>
        public async Task<Maybe<TSource>> TryGetElementAsync(int index, CancellationToken cancellationToken)
        {
            // If the index is negative or >= our max count, return early.
            if (index >= 0 && (!HasLimit || index < Limit))
            {
                var en = _source.GetAsyncEnumerator(cancellationToken);

                try
                {
                    Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow.");

                    if (await SkipBeforeAsync(_minIndexInclusive + index, en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                    {
                        return new Maybe<TSource>(en.Current);
                    }
                }
                finally
                {
                    await en.DisposeAsync().ConfigureAwait(false);
                }
            }

            return new Maybe<TSource>();
        }

        /// <summary>Gets the first element of this partition, if any.</summary>
        public async Task<Maybe<TSource>> TryGetFirstAsync(CancellationToken cancellationToken)
        {
            var en = _source.GetAsyncEnumerator(cancellationToken);

            try
            {
                if (await SkipBeforeFirstAsync(en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                {
                    return new Maybe<TSource>(en.Current);
                }
            }
            finally
            {
                await en.DisposeAsync().ConfigureAwait(false);
            }

            return new Maybe<TSource>();
        }

        /// <summary>Gets the last element of this partition, if any.</summary>
        public async Task<Maybe<TSource>> TryGetLastAsync(CancellationToken cancellationToken)
        {
            var en = _source.GetAsyncEnumerator(cancellationToken);

            try
            {
                if (await SkipBeforeFirstAsync(en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                {
                    var remaining = Limit - 1; // Max number of items left, not counting the current element.
                    var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
                    TSource result;

                    do
                    {
                        remaining--;
                        result = en.Current;
                    }
                    while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));

                    return new Maybe<TSource>(result);
                }
            }
            finally
            {
                await en.DisposeAsync().ConfigureAwait(false);
            }

            return new Maybe<TSource>();
        }

        /// <summary>Materializes this partition into an array.</summary>
        public async Task<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
        {
            var en = _source.GetAsyncEnumerator(cancellationToken);

            try
            {
                if (await SkipBeforeFirstAsync(en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                {
                    var remaining = Limit - 1; // Max number of items left, not counting the current element.
                    var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

                    // Do NOT presize from Limit: it is only an upper bound (the source may be much
                    // shorter) and can be as large as int.MaxValue, e.g. for Take(int.MaxValue) or an
                    // unbounded Skip. Presizing would eagerly allocate a huge backing array and throw
                    // OutOfMemoryException.
                    var builder = new List<TSource>();

                    do
                    {
                        remaining--;
                        builder.Add(en.Current);
                    }
                    while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));

                    return builder.ToArray();
                }
            }
            finally
            {
                await en.DisposeAsync().ConfigureAwait(false);
            }

#if NO_ARRAY_EMPTY
            return EmptyArray<TSource>.Value;
#else
            return Array.Empty<TSource>();
#endif
        }

        /// <summary>Materializes this partition into a new list.</summary>
        public async Task<List<TSource>> ToListAsync(CancellationToken cancellationToken)
        {
            var list = new List<TSource>();

            var en = _source.GetAsyncEnumerator(cancellationToken);

            try
            {
                if (await SkipBeforeFirstAsync(en, cancellationToken).ConfigureAwait(false) && await en.MoveNextAsync().ConfigureAwait(false))
                {
                    var remaining = Limit - 1; // Max number of items left, not counting the current element.
                    var comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

                    do
                    {
                        remaining--;
                        list.Add(en.Current);
                    }
                    while (remaining >= comparand && await en.MoveNextAsync().ConfigureAwait(false));
                }
            }
            finally
            {
                await en.DisposeAsync().ConfigureAwait(false);
            }

            return list;
        }

        // Advances en past the first _minIndexInclusive elements; true if all of them existed.
        private Task<bool> SkipBeforeFirstAsync(IAsyncEnumerator<TSource> en, CancellationToken cancellationToken) => SkipBeforeAsync(_minIndexInclusive, en, cancellationToken);

        // Advances en past index elements; true if all of them existed.
        private static async Task<bool> SkipBeforeAsync(int index, IAsyncEnumerator<TSource> en, CancellationToken cancellationToken)
        {
            var n = await SkipAndCountAsync(index, en, cancellationToken).ConfigureAwait(false);
            return n == index;
        }

        // Non-negative int overload of SkipAndCountAsync; the result never exceeds index, so the cast back is safe.
        private static async Task<int> SkipAndCountAsync(int index, IAsyncEnumerator<TSource> en, CancellationToken cancellationToken)
        {
            Debug.Assert(index >= 0);
            return (int)await SkipAndCountAsync((uint)index, en, cancellationToken).ConfigureAwait(false);
        }

        // Calls MoveNextAsync up to index times and returns how many calls succeeded.
        // Observes cancellationToken before each step so that long skips can be cancelled.
        private static async Task<uint> SkipAndCountAsync(uint index, IAsyncEnumerator<TSource> en, CancellationToken cancellationToken)
        {
            Debug.Assert(en != null);

            for (uint i = 0; i < index; i++)
            {
                cancellationToken.ThrowIfCancellationRequested();

                if (!await en.MoveNextAsync().ConfigureAwait(false))
                {
                    return i;
                }
            }

            return index;
        }
    }
}