// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
public static partial class AsyncEnumerable
{
#if INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
// https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.union?view=net-9.0-pp
// That single overload covers the next two methods, because it supplies a default value for the comparer.
/// <summary>
/// Produces the set union of two sequences by using the default equality comparer.
/// </summary>
/// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
/// <param name="first">An async-enumerable sequence whose distinct elements form the first set for the union.</param>
/// <param name="second">An async-enumerable sequence whose distinct elements form the second set for the union.</param>
/// <returns>An async-enumerable sequence that contains the elements from both input sequences, excluding duplicates.</returns>
/// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second) =>
    // A null comparer selects the default equality comparer in the overload below.
    Union(first, second, comparer: null);
/// <summary>
/// Produces the set union of two sequences by using a specified equality comparer.
/// </summary>
/// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
/// <param name="first">An async-enumerable sequence whose distinct elements form the first set for the union.</param>
/// <param name="second">An async-enumerable sequence whose distinct elements form the second set for the union.</param>
/// <param name="comparer">The equality comparer to compare values, or <c>null</c> to use the default comparer.</param>
/// <returns>An async-enumerable sequence that contains the elements from both input sequences, excluding duplicates.</returns>
/// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource>? comparer)
{
    if (first == null)
    {
        throw Error.ArgumentNull(nameof(first));
    }

    if (second == null)
    {
        throw Error.ArgumentNull(nameof(second));
    }

    // When `first` is itself the result of a Union with an equal comparer, append `second`
    // to its list of sources instead of nesting another iterator around it.
    return first is UnionAsyncIterator<TSource> union && AreEqualityComparersEqual(comparer, union._comparer)
        ? union.Union(second)
        : new UnionAsyncIterator2<TSource>(first, second, comparer);
}
/// <summary>
/// Determines whether two (possibly null) equality comparers can be treated as the same comparer.
/// </summary>
/// <typeparam name="TSource">The element type compared by the comparers.</typeparam>
/// <param name="first">The first comparer, or <c>null</c>.</param>
/// <param name="second">The second comparer, or <c>null</c>.</param>
/// <returns>
/// <c>true</c> if both references are identical (including both null), or both are non-null
/// and <paramref name="first"/> reports itself equal to <paramref name="second"/>; otherwise <c>false</c>.
/// </returns>
private static bool AreEqualityComparersEqual<TSource>(IEqualityComparer<TSource>? first, IEqualityComparer<TSource>? second) =>
    first == second || (first != null && second != null && first.Equals(second));
/// <summary>
/// An iterator that yields distinct values from two or more <see cref="IAsyncEnumerable{T}"/>.
/// </summary>
/// <typeparam name="TSource">The type of the source enumerables.</typeparam>
private abstract class UnionAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
{
#pragma warning disable IDE1006 // Naming Styles
    internal readonly IEqualityComparer<TSource>? _comparer;
#pragma warning restore IDE1006 // Naming Styles

    // Enumerator of the source currently being drained; null before enumeration starts and after disposal.
    private IAsyncEnumerator<TSource>? _enumerator;

    // Every element yielded so far; used to suppress duplicates across all sources.
    // Created when the first element is found and released on disposal.
    private Set<TSource>? _set;

    // Index (as passed to GetEnumerable) of the next source to open.
    private int _index;

    protected UnionAsyncIterator(IEqualityComparer<TSource>? comparer)
    {
        _comparer = comparer;
    }

    public sealed override async ValueTask DisposeAsync()
    {
        if (_enumerator != null)
        {
            await _enumerator.DisposeAsync().ConfigureAwait(false);
            _enumerator = null;
            _set = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Returns the source at <paramref name="index"/>, or <c>null</c> once past the last source.
    /// </summary>
    internal abstract IAsyncEnumerable<TSource>? GetEnumerable(int index);

    /// <summary>
    /// Returns a new iterator that unions this iterator's sources with <paramref name="next"/>.
    /// </summary>
    internal abstract UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next);

    // Disposes the previously active enumerator (if any) before switching to the new one.
    private async Task SetEnumeratorAsync(IAsyncEnumerator<TSource> enumerator)
    {
        if (_enumerator != null)
        {
            await _enumerator.DisposeAsync().ConfigureAwait(false);
        }

        _enumerator = enumerator;
    }

    // Called once the very first element has been found: creates the set, records the element and yields it.
    private void StoreFirst()
    {
        var set = new Set<TSource>(_comparer);
        var element = _enumerator!.Current;
        set.Add(element);
        _current = element;
        _set = set;
    }

    // Advances the current enumerator until an element not yet in the set is found.
    // Returns false when the current source is exhausted.
    private async ValueTask<bool> GetNextAsync()
    {
        var set = _set;
        Debug.Assert(set != null);

        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
        {
            var element = _enumerator.Current;
            if (set!.Add(element))
            {
                _current = element;
                return true;
            }
        }

        return false;
    }

    protected sealed override async ValueTask<bool> MoveNextCore()
    {
        switch (_state)
        {
            case AsyncIteratorState.Allocated:
                // Open sources in order until one yields an element; that element cannot be a
                // duplicate, so it seeds the set directly.
                _index = 0;

                for (var enumerable = GetEnumerable(0); enumerable != null; enumerable = GetEnumerable(_index))
                {
                    ++_index;

                    var enumerator = enumerable.GetAsyncEnumerator(_cancellationToken);

                    await SetEnumeratorAsync(enumerator).ConfigureAwait(false);

                    if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                    {
                        StoreFirst();

                        _state = AsyncIteratorState.Iterating;
                        return true;
                    }
                }

                break;

            case AsyncIteratorState.Iterating:
                // Drain the current source; when it is exhausted, move on to the next one.
                while (true)
                {
                    if (await GetNextAsync().ConfigureAwait(false))
                    {
                        return true;
                    }

                    var enumerable = GetEnumerable(_index);
                    if (enumerable == null)
                    {
                        break;
                    }

                    await SetEnumeratorAsync(enumerable.GetAsyncEnumerator(_cancellationToken)).ConfigureAwait(false);
                    ++_index;
                }

                break;
        }

        // All sources are exhausted (or there were no elements at all).
        await DisposeAsync().ConfigureAwait(false);
        return false;
    }

    // Eagerly enumerates every source into a single set, honoring the supplied cancellation token.
    private async Task<Set<TSource>> FillSetAsync(CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();

        var set = new Set<TSource>(_comparer);

        for (var index = 0; ; ++index)
        {
            var enumerable = GetEnumerable(index);
            if (enumerable == null)
            {
                return set;
            }

            await foreach (var item in enumerable.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                set.Add(item);
            }
        }
    }

    public async ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
    {
        var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
        return set.ToArray();
    }

    public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
    {
        var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
        return set.ToList();
    }

    public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
    {
        // Counting requires draining every source, so no cheap count is available; -1 is
        // returned in that case.
        if (onlyIfCheap)
        {
            return new ValueTask<int>(-1);
        }

        return Core();

        async ValueTask<int> Core()
        {
            var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
            return set.Count;
        }
    }
}
/// <summary>
/// An iterator that yields distinct values from two <see cref="IAsyncEnumerable{T}"/>.
/// </summary>
/// <typeparam name="TSource">The type of the source enumerables.</typeparam>
private sealed class UnionAsyncIterator2<TSource> : UnionAsyncIterator<TSource>
{
    private readonly IAsyncEnumerable<TSource> _first;
    private readonly IAsyncEnumerable<TSource> _second;

    public UnionAsyncIterator2(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource>? comparer)
        : base(comparer)
    {
        _first = first;
        _second = second;
    }

    public override AsyncIteratorBase<TSource> Clone() => new UnionAsyncIterator2<TSource>(_first, _second, _comparer);

    // Index 0 is the first source, 1 the second; 2 signals the end of the sources.
    internal override IAsyncEnumerable<TSource>? GetEnumerable(int index)
    {
        Debug.Assert(index >= 0 && index <= 2);

        return index switch
        {
            0 => _first,
            1 => _second,
            _ => null,
        };
    }

    // Promotes to the N-ary iterator: the list holds first, second, next (added in that
    // order), and the head index of 2 refers to the last of them.
    internal override UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next)
    {
        var sources = new SingleLinkedNode<IAsyncEnumerable<TSource>>(_first).Add(_second).Add(next);
        return new UnionAsyncIteratorN<TSource>(sources, 2, _comparer);
    }
}
/// <summary>
/// An iterator that yields distinct values from three or more <see cref="IAsyncEnumerable{T}"/>.
/// </summary>
/// <typeparam name="TSource">The type of the source enumerables.</typeparam>
private sealed class UnionAsyncIteratorN<TSource> : UnionAsyncIterator<TSource>
{
    private readonly SingleLinkedNode<IAsyncEnumerable<TSource>> _sources;

    // Index of the most recently added source; also the node count of _sources minus one.
    private readonly int _headIndex;

    public UnionAsyncIteratorN(SingleLinkedNode<IAsyncEnumerable<TSource>> sources, int headIndex, IEqualityComparer<TSource>? comparer)
        : base(comparer)
    {
        Debug.Assert(headIndex >= 2);
        Debug.Assert(sources.GetCount() == headIndex + 1);

        _sources = sources;
        _headIndex = headIndex;
    }

    public override AsyncIteratorBase<TSource> Clone() => new UnionAsyncIteratorN<TSource>(_sources, _headIndex, _comparer);

    // Maps source index i to node (_headIndex - i); assumes GetNode counts from the most
    // recently added node, so index 0 resolves to the oldest source. Past the end, returns null.
    internal override IAsyncEnumerable<TSource>? GetEnumerable(int index) => index > _headIndex ? null : _sources.GetNode(_headIndex - index).Item;

    internal override UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next)
    {
        if (_headIndex == int.MaxValue - 2)
        {
            // In the unlikely case of this many unions, producing a UnionAsyncIteratorN with
            // a head index of int.MaxValue would let the index counter overflow before it moved
            // past its last source. So we use the naïve approach of just having a left and
            // a right sequence.
            return new UnionAsyncIterator2<TSource>(this, next, _comparer);
        }

        return new UnionAsyncIteratorN<TSource>(_sources.Add(next), _headIndex + 1, _comparer);
    }
}
#endif // INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
}
}