// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Concatenates two async-enumerable sequences.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the sequences.</typeparam>
        /// <param name="first">The sequence whose elements are produced first.</param>
        /// <param name="second">The sequence whose elements are produced after <paramref name="first"/>.</param>
        /// <returns>A sequence that yields the elements of <paramref name="first"/> followed by those of <paramref name="second"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Concat<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
            {
                throw new ArgumentNullException(nameof(first));
            }

            if (second == null)
            {
                throw new ArgumentNullException(nameof(second));
            }

            // If the left-hand side is already a concat iterator, extend its chain
            // instead of wrapping it in another two-source iterator.
            return first is ConcatAsyncIterator<TSource> concatFirst
                ? concatFirst.Concat(second)
                : new Concat2AsyncIterator<TSource>(first, second);
        }

        /// <summary>
        /// Concat iterator over exactly two sources (indices 0 and 1).
        /// </summary>
        private sealed class Concat2AsyncIterator<TSource> : ConcatAsyncIterator<TSource>
        {
            private readonly IAsyncEnumerable<TSource> first;
            private readonly IAsyncEnumerable<TSource> second;

            internal Concat2AsyncIterator(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
            {
                Debug.Assert(first != null);
                Debug.Assert(second != null);

                this.first = first;
                this.second = second;
            }

            /// <summary>Creates a fresh iterator over the same two sources.</summary>
            public override AsyncIterator<TSource> Clone()
            {
                return new Concat2AsyncIterator<TSource>(first, second);
            }

            /// <summary>Appends a third source by starting a chained iterator at index 2.</summary>
            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
            {
                return new ConcatNAsyncIterator<TSource>(this, next, 2);
            }

            /// <summary>Returns the source at <paramref name="index"/>, or null when the index is past the last source.</summary>
            internal override IAsyncEnumerable<TSource> GetAsyncEnumerable(int index)
            {
                switch (index)
                {
                    case 0:
                        return first;
                    case 1:
                        return second;
                    default:
                        return null;
                }
            }
        }

        /// <summary>
        /// Shared base for the concat iterators. Derived classes expose their sources
        /// by index through <see cref="GetAsyncEnumerable"/>; this class walks them in order.
        /// </summary>
        private abstract class ConcatAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
        {
            // One more than the index of the next source to request; set to 2 when iteration starts.
            private int counter;

            // Enumerator over the source currently being drained.
            private IAsyncEnumerator<TSource> enumerator;

            /// <summary>Materializes all sources, in order, into an array.</summary>
            public Task<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                return AsyncEnumerableHelpers.ToArray(this, cancellationToken);
            }

            /// <summary>Materializes all sources, in order, into a list.</summary>
            public async Task<List<TSource>> ToListAsync(CancellationToken cancellationToken)
            {
                var list = new List<TSource>();

                for (var i = 0; ; i++)
                {
                    var source = GetAsyncEnumerable(i);
                    if (source == null)
                    {
                        // A null source marks the end of the chain.
                        break;
                    }

                    var e = source.GetAsyncEnumerator(cancellationToken);

                    try
                    {
                        while (await e.MoveNextAsync().ConfigureAwait(false))
                        {
                            list.Add(e.Current);
                        }
                    }
                    finally
                    {
                        await e.DisposeAsync().ConfigureAwait(false);
                    }
                }

                return list;
            }

            /// <summary>
            /// Sums the counts of all sources. Returns -1 without touching the sources
            /// when <paramref name="onlyIfCheap"/> is true.
            /// </summary>
            /// <exception cref="OverflowException">The total exceeds <see cref="int.MaxValue"/>.</exception>
            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                if (onlyIfCheap)
                {
                    return -1;
                }

                var count = 0;

                for (var i = 0; ; i++)
                {
                    var source = GetAsyncEnumerable(i);
                    if (source == null)
                    {
                        break;
                    }

                    checked
                    {
                        count += await source.Count(cancellationToken).ConfigureAwait(false);
                    }
                }

                return count;
            }

            /// <summary>Disposes the active source enumerator, if any, then the base iterator.</summary>
            public override async ValueTask DisposeAsync()
            {
                if (enumerator != null)
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                    enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Advances to the next element, moving on to the next source whenever
            /// the current one is exhausted.
            /// </summary>
            /// <returns>true if an element was stored in current; false once every source is exhausted.</returns>
            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                if (state == AsyncIteratorState.Allocated)
                {
                    enumerator = GetAsyncEnumerable(0).GetAsyncEnumerator(cancellationToken);
                    state = AsyncIteratorState.Iterating;
                    counter = 2;
                }

                if (state == AsyncIteratorState.Iterating)
                {
                    while (true)
                    {
                        if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                        {
                            current = enumerator.Current;
                            return true;
                        }

                        // note, this is simply to match the logic of
                        // https://github.com/dotnet/corefx/blob/ec2685715b01d12f16b08d0dfa326649b12db8ec/src/system.linq/src/system/linq/concatenate.cs#L173-L173
                        var next = GetAsyncEnumerable(counter++ - 1);
                        if (next != null)
                        {
                            // Release the exhausted enumerator before switching to the next source.
                            await enumerator.DisposeAsync().ConfigureAwait(false);
                            enumerator = next.GetAsyncEnumerator(cancellationToken);
                            continue;
                        }

                        await DisposeAsync().ConfigureAwait(false);
                        break;
                    }
                }

                return false;
            }

            /// <summary>Returns an iterator that yields this iterator's sources followed by <paramref name="next"/>.</summary>
            internal abstract ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next);

            /// <summary>Returns the source at <paramref name="index"/>, or null when there is no such source.</summary>
            internal abstract IAsyncEnumerable<TSource> GetAsyncEnumerable(int index);
        }

        // To handle chains of >= 3 sources, we chain the concat iterators together and allow
        // GetAsyncEnumerable to fetch enumerables from the previous sources. This means that rather
        // than each MoveNext/Current call having to traverse all of the previous sources, we
        // only have to traverse all of the previous sources once per chained enumerable. An
        // alternative would be to use an array to store all of the enumerables, but this has
        // a much better memory profile and without much additional run-time cost.
        private sealed class ConcatNAsyncIterator<TSource> : ConcatAsyncIterator<TSource>
        {
            private readonly IAsyncEnumerable<TSource> next;
            private readonly int nextIndex;
            private readonly ConcatAsyncIterator<TSource> previousConcat;

            internal ConcatNAsyncIterator(ConcatAsyncIterator<TSource> previousConcat, IAsyncEnumerable<TSource> next, int nextIndex)
            {
                Debug.Assert(previousConcat != null);
                Debug.Assert(next != null);
                Debug.Assert(nextIndex >= 2);

                this.previousConcat = previousConcat;
                this.next = next;
                this.nextIndex = nextIndex;
            }

            /// <summary>Creates a fresh iterator over the same chain of sources.</summary>
            public override AsyncIterator<TSource> Clone()
            {
                return new ConcatNAsyncIterator<TSource>(previousConcat, next, nextIndex);
            }

            /// <summary>Appends another source to the chain at index <c>nextIndex + 1</c>.</summary>
            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
            {
                if (nextIndex == int.MaxValue - 2)
                {
                    // In the unlikely case of this many concatenations, if we produced a ConcatNAsyncIterator
                    // with int.MaxValue then the iteration counter would overflow before it matched its index.
                    // So we use the naïve approach of just having a left and right sequence.
                    return new Concat2AsyncIterator<TSource>(this, next);
                }

                return new ConcatNAsyncIterator<TSource>(this, next, nextIndex + 1);
            }

            /// <summary>Returns the source at <paramref name="index"/>, or null when the index is past the end of the chain.</summary>
            internal override IAsyncEnumerable<TSource> GetAsyncEnumerable(int index)
            {
                if (index > nextIndex)
                {
                    return null;
                }

                // Walk back through the chain of ConcatNAsyncIterators looking for the one
                // that has its nextIndex equal to index. If we don't find one, then it
                // must be prior to any of them, so call GetAsyncEnumerable on the previous
                // Concat2AsyncIterator. This avoids a deep recursive call chain.
                var node = this;
                while (true)
                {
                    if (index == node.nextIndex)
                    {
                        return node.next;
                    }

                    if (node.previousConcat is ConcatNAsyncIterator<TSource> prevN)
                    {
                        node = prevN;
                        continue;
                    }

                    Debug.Assert(node.previousConcat is Concat2AsyncIterator<TSource>);
                    Debug.Assert(index == 0 || index == 1);

                    return node.previousConcat.GetAsyncEnumerable(index);
                }
            }
        }
    }
}