// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Concatenates the second async-enumerable sequence to the first async-enumerable sequence upon successful termination of the first.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequences.</typeparam>
        /// <param name="first">First async-enumerable sequence.</param>
        /// <param name="second">Second async-enumerable sequence.</param>
        /// <returns>An async-enumerable sequence that contains the elements of the first sequence, followed by those of the second sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Concat<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
            {
                throw Error.ArgumentNull(nameof(first));
            }

            if (second == null)
            {
                throw Error.ArgumentNull(nameof(second));
            }

            // When first is already a concat iterator, extend its chain of sources
            // (see ConcatNAsyncIterator) instead of wrapping it in a new two-source iterator.
            return first is ConcatAsyncIterator<TSource> concatFirst
                ? concatFirst.Concat(second)
                : new Concat2AsyncIterator<TSource>(first, second);
        }

        /// <summary>
        /// Concat iterator over exactly two sources: index 0 is <c>_first</c>, index 1 is <c>_second</c>.
        /// </summary>
        private sealed class Concat2AsyncIterator<TSource> : ConcatAsyncIterator<TSource>
        {
            private readonly IAsyncEnumerable<TSource> _first;
            private readonly IAsyncEnumerable<TSource> _second;

            internal Concat2AsyncIterator(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
            {
                _first = first;
                _second = second;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new Concat2AsyncIterator<TSource>(_first, _second);
            }

            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
            {
                // The appended source becomes index 2 of the chain.
                return new ConcatNAsyncIterator<TSource>(this, next, 2);
            }

            internal override IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index)
            {
                return index switch
                {
                    0 => _first,
                    1 => _second,
                    _ => null,
                };
            }
        }

        /// <summary>
        /// Base class for the concat iterators. Sources are addressed by a zero-based index through
        /// <see cref="GetAsyncEnumerable(int)"/>, which returns null once the index is past the last source.
        /// Enumeration walks the sources in index order, holding at most one source enumerator at a time.
        /// </summary>
        private abstract class ConcatAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
        {
            // While iterating, (_counter - 1) is the index of the next source to enumerate.
            private int _counter;

            // Enumerator of the source currently being enumerated; null before start and after disposal.
            private IAsyncEnumerator<TSource>? _enumerator;

            /// <summary>
            /// Materializes the concatenated elements into an array via <c>AsyncEnumerableHelpers.ToArray</c>.
            /// </summary>
            public ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                return AsyncEnumerableHelpers.ToArray(this, cancellationToken);
            }

            /// <summary>
            /// Materializes every source, in index order, into a single list.
            /// </summary>
            public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
            {
                cancellationToken.ThrowIfCancellationRequested();

                var list = new List<TSource>();

                for (var i = 0; ; i++)
                {
                    var source = GetAsyncEnumerable(i);
                    if (source == null)
                    {
                        break;
                    }

                    await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                    {
                        list.Add(item);
                    }
                }

                return list;
            }

            /// <summary>
            /// Sums the element counts of all sources. When <paramref name="onlyIfCheap"/> is true,
            /// returns -1 without touching any source, since counting requires asking every source.
            /// </summary>
            /// <exception cref="OverflowException">The total count exceeds <see cref="int.MaxValue"/>.</exception>
            public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                if (onlyIfCheap)
                {
                    return new ValueTask<int>(-1);
                }

                return Core();

                async ValueTask<int> Core()
                {
                    cancellationToken.ThrowIfCancellationRequested();

                    var count = 0;

                    for (var i = 0; ; i++)
                    {
                        var source = GetAsyncEnumerable(i);
                        if (source == null)
                        {
                            break;
                        }

                        checked
                        {
                            count += await source.CountAsync(cancellationToken).ConfigureAwait(false);
                        }
                    }

                    return count;
                }
            }

            public override async ValueTask DisposeAsync()
            {
                var enumerator = _enumerator;
                if (enumerator != null)
                {
                    // Clear the field before awaiting so that a repeated DisposeAsync call
                    // (for example after this dispose throws) never disposes it a second time.
                    _enumerator = null;
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                if (_state == AsyncIteratorState.Allocated)
                {
                    // Every concat iterator has at least two sources, so index 0 always exists.
                    _enumerator = GetAsyncEnumerable(0)!.GetAsyncEnumerator(_cancellationToken);
                    _state = AsyncIteratorState.Iterating;
                    _counter = 2;
                }

                if (_state == AsyncIteratorState.Iterating)
                {
                    while (true)
                    {
                        if (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            _current = _enumerator.Current;
                            return true;
                        }

                        //
                        // NB: This is simply to match the logic of
                        //     https://github.com/dotnet/corefx/blob/f7539b726c4bc2385b7f49e5751c1cff2f2c7368/src/System.Linq/src/System/Linq/Concat.cs#L240
                        //

                        var next = GetAsyncEnumerable(_counter++ - 1);
                        if (next != null)
                        {
                            // Clear the field before disposing the exhausted enumerator: if either that
                            // dispose or next.GetAsyncEnumerator throws, a later DisposeAsync must not
                            // dispose the already-disposed enumerator again.
                            var exhausted = _enumerator;
                            _enumerator = null;
                            await exhausted.DisposeAsync().ConfigureAwait(false);
                            _enumerator = next.GetAsyncEnumerator(_cancellationToken);
                            continue;
                        }

                        // All sources are exhausted.
                        await DisposeAsync().ConfigureAwait(false);
                        break;
                    }
                }

                return false;
            }

            /// <summary>
            /// Returns a new concat iterator that yields this iterator's sources followed by <paramref name="next"/>.
            /// </summary>
            internal abstract ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next);

            /// <summary>
            /// Returns the source at zero-based <paramref name="index"/>, or null when the index is past the last source.
            /// </summary>
            internal abstract IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index);
        }

        // To handle chains of >= 3 sources, we chain the concat iterators together and allow
        // GetAsyncEnumerable to fetch enumerables from the previous sources. This means that rather
        // than each MoveNext/Current call having to traverse all of the previous sources, we
        // only have to traverse all of the previous sources once per chained enumerable. An
        // alternative would be to use an array to store all of the enumerables, but this has
        // a much better memory profile and without much additional run-time cost.
        private sealed class ConcatNAsyncIterator<TSource> : ConcatAsyncIterator<TSource>
        {
            private readonly IAsyncEnumerable<TSource> _next;
            private readonly int _nextIndex;
            private readonly ConcatAsyncIterator<TSource> _previousConcat;

            internal ConcatNAsyncIterator(ConcatAsyncIterator<TSource> previousConcat, IAsyncEnumerable<TSource> next, int nextIndex)
            {
                Debug.Assert(nextIndex >= 2);

                _previousConcat = previousConcat;
                _next = next;
                _nextIndex = nextIndex;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new ConcatNAsyncIterator<TSource>(_previousConcat, _next, _nextIndex);
            }

            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
            {
                if (_nextIndex == int.MaxValue - 2)
                {
                    // In the unlikely case of this many concatenations, if we produced a ConcatNAsyncIterator
                    // with int.MaxValue then _counter would overflow before it matched its index.
                    // So we use the naïve approach of just having a left and right sequence.
                    return new Concat2AsyncIterator<TSource>(this, next);
                }

                return new ConcatNAsyncIterator<TSource>(this, next, _nextIndex + 1);
            }

            internal override IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index)
            {
                if (index > _nextIndex)
                {
                    return null;
                }

                // Walk back through the chain of ConcatNAsyncIterators looking for the one
                // that has its _nextIndex equal to index. If we don't find one, then it
                // must be prior to any of them, so call GetAsyncEnumerable on the previous
                // Concat2AsyncIterator. This avoids a deep recursive call chain.
                var current = this;
                while (true)
                {
                    if (index == current._nextIndex)
                    {
                        return current._next;
                    }

                    if (current._previousConcat is ConcatNAsyncIterator<TSource> prevN)
                    {
                        current = prevN;
                        continue;
                    }

                    Debug.Assert(current._previousConcat is Concat2AsyncIterator<TSource>);
                    Debug.Assert(index == 0 || index == 1);

                    return current._previousConcat.GetAsyncEnumerable(index);
                }
            }
        }
    }
}