// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
public static partial class AsyncEnumerable
{
#if INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
// https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.concat?view=net-9.0-pp

/// <summary>
/// Concatenates the second async-enumerable sequence to the first async-enumerable sequence upon successful termination of the first.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequences.</typeparam>
/// <param name="first">First async-enumerable sequence.</param>
/// <param name="second">Second async-enumerable sequence.</param>
/// <returns>An async-enumerable sequence that contains the elements of the first sequence, followed by those of the second sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
/// <remarks>
/// When <paramref name="first"/> is itself the result of a previous Concat call, <paramref name="second"/>
/// is appended to that existing chain instead of wrapping it in another two-source iterator.
/// </remarks>
public static IAsyncEnumerable<TSource> Concat<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
{
    if (first is null)
    {
        throw Error.ArgumentNull(nameof(first));
    }

    if (second is null)
    {
        throw Error.ArgumentNull(nameof(second));
    }

    // Extend an existing concat chain rather than nesting iterators.
    if (first is ConcatAsyncIterator<TSource> existingChain)
    {
        return existingChain.Concat(second);
    }

    return new Concat2AsyncIterator<TSource>(first, second);
}
/// <summary>
/// Concat iterator over exactly two sources: index 0 maps to the first sequence and index 1 to the second.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequences.</typeparam>
private sealed class Concat2AsyncIterator<TSource> : ConcatAsyncIterator<TSource>
{
    private readonly IAsyncEnumerable<TSource> _first;
    private readonly IAsyncEnumerable<TSource> _second;

    internal Concat2AsyncIterator(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
    {
        _first = first;
        _second = second;
    }

    public override AsyncIteratorBase<TSource> Clone() =>
        new Concat2AsyncIterator<TSource>(_first, _second);

    // A third source starts a ConcatN chain on top of this iterator; the appended source takes index 2.
    internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next) =>
        new ConcatNAsyncIterator<TSource>(this, next, 2);

    // Any index other than 0 or 1 is past the end, signalled by null.
    internal override IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index)
    {
        if (index == 0)
        {
            return _first;
        }

        if (index == 1)
        {
            return _second;
        }

        return null;
    }
}
/// <summary>
/// Shared base of the concat iterators. Derived classes expose their sources by position through
/// <see cref="GetAsyncEnumerable(int)"/>; this class drains those sources one after another.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequences.</typeparam>
private abstract class ConcatAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
{
    // Runs one ahead of the index of the next source to fetch (next index == _counter - 1).
    private int _counter;

    // Enumerator over the source currently being drained; null until enumeration starts and after DisposeAsync.
    private IAsyncEnumerator<TSource>? _enumerator;

    public ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken) =>
        AsyncEnumerableHelpers.ToArray(this, cancellationToken);

    /// <summary>
    /// Materializes every source, in order, into a single list.
    /// </summary>
    public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();

        var result = new List<TSource>();
        var sourceIndex = 0;

        // GetAsyncEnumerable returns null once the index runs past the last source.
        while (GetAsyncEnumerable(sourceIndex++) is { } source)
        {
            await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                result.Add(element);
            }
        }

        return result;
    }

    /// <summary>
    /// Returns -1 when <paramref name="onlyIfCheap"/> is true instead of counting; otherwise sums the
    /// counts of all sources, throwing <see cref="OverflowException"/> if the total exceeds <see cref="int.MaxValue"/>.
    /// </summary>
    public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
    {
        if (!onlyIfCheap)
        {
            return SumSourceCountsAsync();
        }

        return new ValueTask<int>(-1);

        async ValueTask<int> SumSourceCountsAsync()
        {
            cancellationToken.ThrowIfCancellationRequested();

            var total = 0;
            var sourceIndex = 0;

            while (GetAsyncEnumerable(sourceIndex++) is { } source)
            {
                var sourceCount = await source.CountAsync(cancellationToken).ConfigureAwait(false);
                total = checked(total + sourceCount);
            }

            return total;
        }
    }

    public override async ValueTask DisposeAsync()
    {
        // Release the in-flight source enumerator (if any) before the base iterator's own cleanup.
        if (_enumerator is { } active)
        {
            await active.DisposeAsync().ConfigureAwait(false);
            _enumerator = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    protected override async ValueTask<bool> MoveNextCore()
    {
        if (_state == AsyncIteratorState.Allocated)
        {
            // First call: start on source 0; the next source fetched will be at index _counter - 1 == 1.
            _enumerator = GetAsyncEnumerable(0)!.GetAsyncEnumerator(_cancellationToken);
            _state = AsyncIteratorState.Iterating;
            _counter = 2;
        }

        if (_state != AsyncIteratorState.Iterating)
        {
            return false;
        }

        // Keep advancing to the following source until one of them produces an element.
        while (!await _enumerator!.MoveNextAsync().ConfigureAwait(false))
        {
            //
            // NB: This is simply to match the logic of
            // https://github.com/dotnet/corefx/blob/f7539b726c4bc2385b7f49e5751c1cff2f2c7368/src/System.Linq/src/System/Linq/Concat.cs#L240
            //
            var following = GetAsyncEnumerable(_counter++ - 1);

            if (following == null)
            {
                // Every source is exhausted.
                await DisposeAsync().ConfigureAwait(false);
                return false;
            }

            await _enumerator.DisposeAsync().ConfigureAwait(false);
            _enumerator = following.GetAsyncEnumerator(_cancellationToken);
        }

        _current = _enumerator.Current;
        return true;
    }

    /// <summary>
    /// Returns a concat iterator that yields this iterator's sources followed by <paramref name="next"/>.
    /// </summary>
    internal abstract ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next);

    /// <summary>
    /// Gets the source at position <paramref name="index"/>, or null when <paramref name="index"/> is past the last source.
    /// </summary>
    internal abstract IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index);
}
// To handle chains of three or more sources, we chain the concat iterators together and let
// GetAsyncEnumerable fetch enumerables from the previous links in the chain. This means that
// rather than each MoveNextAsync/Current call having to traverse all of the previous sources,
// we only traverse the previous sources once per chained enumerable. An alternative would be
// to store all of the enumerables in an array, but chaining has a much better memory profile
// without much additional run-time cost.
/// <summary>
/// One link in a chain of three or more concatenated sources. The link owns the source at index
/// <c>_nextIndex</c> and defers every lower index to the iterator it extends.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequences.</typeparam>
private sealed class ConcatNAsyncIterator<TSource> : ConcatAsyncIterator<TSource>
{
    // Source appended by this link, exposed at index _nextIndex.
    private readonly IAsyncEnumerable<TSource> _next;

    // Position of _next in the overall sequence of sources; always >= 2.
    private readonly int _nextIndex;

    // Iterator this link extends; it supplies every index below _nextIndex.
    private readonly ConcatAsyncIterator<TSource> _previousConcat;

    internal ConcatNAsyncIterator(ConcatAsyncIterator<TSource> previousConcat, IAsyncEnumerable<TSource> next, int nextIndex)
    {
        Debug.Assert(nextIndex >= 2);

        _previousConcat = previousConcat;
        _next = next;
        _nextIndex = nextIndex;
    }

    public override AsyncIteratorBase<TSource> Clone() =>
        new ConcatNAsyncIterator<TSource>(_previousConcat, _next, _nextIndex);

    internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
    {
        if (_nextIndex != int.MaxValue - 2)
        {
            return new ConcatNAsyncIterator<TSource>(this, next, _nextIndex + 1);
        }

        // In the unlikely case of this many concatenations, growing the chain further would let the
        // enumeration counter (which runs one ahead of the source index) overflow before it matched
        // the last index. Instead, treat this whole chain as the left source of a plain two-source
        // iterator, which restarts the index numbering.
        return new Concat2AsyncIterator<TSource>(this, next);
    }

    internal override IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index)
    {
        if (index > _nextIndex)
        {
            return null;
        }

        // Walk down the chain iteratively (avoiding a deep recursive call chain) until reaching the
        // link that owns this index. Indices below every ConcatN link belong to the Concat2 iterator
        // at the bottom of the chain.
        var link = this;
        while (index != link._nextIndex)
        {
            if (!(link._previousConcat is ConcatNAsyncIterator<TSource> earlier))
            {
                Debug.Assert(link._previousConcat is Concat2AsyncIterator<TSource>);
                Debug.Assert(index == 0 || index == 1);
                return link._previousConcat.GetAsyncEnumerable(index);
            }

            link = earlier;
        }

        return link._next;
    }
}
#endif // INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
}
}