// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Concatenates two async-enumerable sequences. This call does not enumerate either
        /// source; it only builds the iterator that reads them when the result is iterated.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the sequences.</typeparam>
        /// <param name="first">The sequence whose elements are produced first.</param>
        /// <param name="second">The sequence whose elements are produced once <paramref name="first"/> is exhausted.</param>
        /// <returns>
        /// A sequence yielding the elements of <paramref name="first"/> followed by those of
        /// <paramref name="second"/>. When <paramref name="first"/> is itself a concatenation
        /// produced by this method, its chain is extended instead of being wrapped in a new
        /// two-source iterator.
        /// </returns>
        /// <remarks>
        /// If either argument is null, the error created by <c>Error.ArgumentNull</c> is thrown.
        /// </remarks>
        public static IAsyncEnumerable<TSource> Concat<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
            {
                throw Error.ArgumentNull(nameof(first));
            }

            if (second == null)
            {
                throw Error.ArgumentNull(nameof(second));
            }

            return first is ConcatAsyncIterator<TSource> concatFirst
                ? concatFirst.Concat(second)
                : new Concat2AsyncIterator<TSource>(first, second);
        }

        /// <summary>
        /// Concatenation of exactly two sources. ConcatN chains bottom out in one of these.
        /// </summary>
        private sealed class Concat2AsyncIterator<TSource> : ConcatAsyncIterator<TSource>
        {
            private readonly IAsyncEnumerable<TSource> _first;
            private readonly IAsyncEnumerable<TSource> _second;

            internal Concat2AsyncIterator(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
            {
                Debug.Assert(first != null);
                Debug.Assert(second != null);

                _first = first;
                _second = second;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new Concat2AsyncIterator<TSource>(_first, _second);
            }

            // Appending a third source starts a ConcatN chain; the new source gets index 2.
            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
            {
                return new ConcatNAsyncIterator<TSource>(this, next, 2);
            }

            // Index 0 is the first source and 1 the second; null signals "no more sources".
            internal override IAsyncEnumerable<TSource> GetAsyncEnumerable(int index)
            {
                switch (index)
                {
                    case 0:
                        return _first;
                    case 1:
                        return _second;
                    default:
                        return null;
                }
            }
        }

        /// <summary>
        /// Shared enumeration, materialization and counting logic for concatenations.
        /// Subclasses expose their sources by position through <see cref="GetAsyncEnumerable"/>,
        /// which must return null once <c>index</c> is past the last source.
        /// </summary>
        private abstract class ConcatAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
        {
            // One more than the index of the next source to fetch (see MoveNextCore).
            private int _counter;

            // Enumerator over the source currently being drained; null when none is open.
            private IAsyncEnumerator<TSource> _enumerator;

            public ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                return AsyncEnumerableHelpers.ToArray(this, cancellationToken);
            }

            /// <summary>
            /// Drains every source in order into a single list. Each source's enumerator is
            /// disposed before the next source is opened.
            /// </summary>
            public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
            {
                var list = new List<TSource>();

                for (var i = 0; ; i++)
                {
                    var source = GetAsyncEnumerable(i);
                    if (source == null)
                    {
                        break;
                    }

                    var e = source.GetAsyncEnumerator(cancellationToken);

                    try
                    {
                        while (await e.MoveNextAsync().ConfigureAwait(false))
                        {
                            list.Add(e.Current);
                        }
                    }
                    finally
                    {
                        await e.DisposeAsync().ConfigureAwait(false);
                    }
                }

                return list;
            }

            /// <summary>
            /// Returns -1 when <paramref name="onlyIfCheap"/> is set, because an exact count needs a
            /// CountAsync call on every source. Otherwise sums those counts, throwing an
            /// <see cref="OverflowException"/> if the total does not fit in an <see cref="int"/>.
            /// </summary>
            public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                if (onlyIfCheap)
                {
                    return new ValueTask<int>(-1);
                }

                return Core();

                async ValueTask<int> Core()
                {
                    var count = 0;

                    for (var i = 0; ; i++)
                    {
                        var source = GetAsyncEnumerable(i);
                        if (source == null)
                        {
                            break;
                        }

                        checked
                        {
                            count += await source.CountAsync(cancellationToken).ConfigureAwait(false);
                        }
                    }

                    return count;
                }
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                if (_state == AsyncIteratorState.Allocated)
                {
                    _enumerator = GetAsyncEnumerable(0).GetAsyncEnumerator(_cancellationToken);
                    _state = AsyncIteratorState.Iterating;
                    _counter = 2;
                }

                if (_state == AsyncIteratorState.Iterating)
                {
                    while (true)
                    {
                        if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                        {
                            _current = _enumerator.Current;
                            return true;
                        }

                        // The current source is exhausted; move on to the next one.
                        // note, this is simply to match the logic of
                        // https://github.com/dotnet/corefx/blob/ec2685715b01d12f16b08d0dfa326649b12db8ec/src/system.linq/src/system/linq/concatenate.cs#L173-L173
                        var next = GetAsyncEnumerable(_counter++ - 1);
                        if (next != null)
                        {
                            // Detach the exhausted enumerator before disposing it. If its
                            // DisposeAsync, or next.GetAsyncEnumerator below, throws, our own
                            // DisposeAsync must not dispose the same enumerator a second time.
                            var exhausted = _enumerator;
                            _enumerator = null;
                            await exhausted.DisposeAsync().ConfigureAwait(false);

                            _enumerator = next.GetAsyncEnumerator(_cancellationToken);
                            continue;
                        }

                        await DisposeAsync().ConfigureAwait(false);
                        break;
                    }
                }

                return false;
            }

            internal abstract ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next);

            internal abstract IAsyncEnumerable<TSource> GetAsyncEnumerable(int index);
        }

        // To handle chains of >= 3 sources, we chain the concat iterators together and allow
        // GetAsyncEnumerable to fetch enumerables from the previous sources. This means that rather
        // than each MoveNext/Current call having to traverse all of the previous sources, we
        // only have to traverse all of the previous sources once per chained enumerable. An
        // alternative would be to use an array to store all of the enumerables, but this has
        // a much better memory profile and without much additional run-time cost.
        private sealed class ConcatNAsyncIterator<TSource> : ConcatAsyncIterator<TSource>
        {
            private readonly IAsyncEnumerable<TSource> _next;
            private readonly int _nextIndex;
            private readonly ConcatAsyncIterator<TSource> _previousConcat;

            internal ConcatNAsyncIterator(ConcatAsyncIterator<TSource> previousConcat, IAsyncEnumerable<TSource> next, int nextIndex)
            {
                Debug.Assert(previousConcat != null);
                Debug.Assert(next != null);
                Debug.Assert(nextIndex >= 2);

                _previousConcat = previousConcat;
                _next = next;
                _nextIndex = nextIndex;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new ConcatNAsyncIterator<TSource>(_previousConcat, _next, _nextIndex);
            }

            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
            {
                if (_nextIndex == int.MaxValue - 2)
                {
                    // In the unlikely case of this many concatenations, if we produced a
                    // ConcatNAsyncIterator with int.MaxValue then _counter would overflow before it
                    // matched its index. So we use the naïve approach of just having a left and
                    // right sequence.
                    return new Concat2AsyncIterator<TSource>(this, next);
                }

                return new ConcatNAsyncIterator<TSource>(this, next, _nextIndex + 1);
            }

            internal override IAsyncEnumerable<TSource> GetAsyncEnumerable(int index)
            {
                if (index > _nextIndex)
                {
                    return null;
                }

                // Walk back through the chain of ConcatNAsyncIterators looking for the one
                // that has its _nextIndex equal to index. If we don't find one, then it
                // must be prior to any of them, so call GetAsyncEnumerable on the underlying
                // Concat2AsyncIterator. The walk is a loop, which avoids a deep recursive call chain.
                var current = this;
                while (true)
                {
                    if (index == current._nextIndex)
                    {
                        return current._next;
                    }

                    if (current._previousConcat is ConcatNAsyncIterator<TSource> prevN)
                    {
                        current = prevN;
                        continue;
                    }

                    Debug.Assert(current._previousConcat is Concat2AsyncIterator<TSource>);
                    Debug.Assert(index == 0 || index == 1);

                    return current._previousConcat.GetAsyncEnumerable(index);
                }
            }
        }
    }
}