// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
#if INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
        // https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.union?view=net-9.0-pp
        // That one overload covers the next two methods, because it supplies a default comparer.

        /// <summary>
        /// Produces the set union of two sequences by using the default equality comparer.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
        /// <param name="first">An async-enumerable sequence whose distinct elements form the first set for the union.</param>
        /// <param name="second">An async-enumerable sequence whose distinct elements form the second set for the union.</param>
        /// <returns>An async-enumerable sequence that contains the elements from both input sequences, excluding duplicates.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second) =>
            Union(first, second, comparer: null);

        /// <summary>
        /// Produces the set union of two sequences by using a specified equality comparer.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
        /// <param name="first">An async-enumerable sequence whose distinct elements form the first set for the union.</param>
        /// <param name="second">An async-enumerable sequence whose distinct elements form the second set for the union.</param>
        /// <param name="comparer">The equality comparer to compare values, or null to use the default equality comparer.</param>
        /// <returns>An async-enumerable sequence that contains the elements from both input sequences, excluding duplicates.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource>? comparer)
        {
            if (first == null)
                throw Error.ArgumentNull(nameof(first));
            if (second == null)
                throw Error.ArgumentNull(nameof(second));

            // When 'first' is already a union built with an equivalent comparer, extend it
            // with 'second' instead of wrapping it, so chains such as a.Union(b).Union(c)
            // are flattened into a single iterator over all of their sources.
            return first is UnionAsyncIterator<TSource> union && AreEqualityComparersEqual(comparer, union._comparer)
                ? union.Union(second)
                : new UnionAsyncIterator2<TSource>(first, second, comparer);
        }

        /// <summary>
        /// Determines whether two (possibly null) equality comparers are the same instance
        /// or report themselves as equal, so that adjacent union operators may be merged.
        /// </summary>
        private static bool AreEqualityComparersEqual<TSource>(IEqualityComparer<TSource>? first, IEqualityComparer<TSource>? second)
        {
            return first == second || (first != null && second != null && first.Equals(second));
        }

        /// <summary>
        /// An iterator that yields distinct values from two or more <see cref="IAsyncEnumerable{T}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private abstract class UnionAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
        {
#pragma warning disable IDE1006 // Naming Styles
            internal readonly IEqualityComparer<TSource>? _comparer;
#pragma warning restore IDE1006 // Naming Styles

            // Enumerator of the source currently being read, if any.
            private IAsyncEnumerator<TSource>? _enumerator;

            // Elements already yielded; used to suppress duplicates across all sources.
            private Set<TSource>? _set;

            // Index of the next source to open (see GetEnumerable).
            private int _index;

            protected UnionAsyncIterator(IEqualityComparer<TSource>? comparer)
            {
                _comparer = comparer;
            }

            /// <summary>
            /// Disposes the active source enumerator, if any, drops the set of seen elements,
            /// and then disposes the base iterator.
            /// </summary>
            public sealed override async ValueTask DisposeAsync()
            {
                // Detach before disposing: if the source enumerator's DisposeAsync throws,
                // a later DisposeAsync call must not try to dispose it a second time.
                var enumerator = _enumerator;
                _enumerator = null;
                _set = null;

                try
                {
                    if (enumerator != null)
                    {
                        await enumerator.DisposeAsync().ConfigureAwait(false);
                    }
                }
                finally
                {
                    // Always release the base iterator, even if the source's dispose failed.
                    await base.DisposeAsync().ConfigureAwait(false);
                }
            }

            /// <summary>
            /// Returns the source at <paramref name="index"/> in enumeration order,
            /// or null when <paramref name="index"/> is past the last source.
            /// </summary>
            internal abstract IAsyncEnumerable<TSource>? GetEnumerable(int index);

            /// <summary>
            /// Returns a new union iterator over this iterator's sources followed by <paramref name="next"/>.
            /// </summary>
            internal abstract UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next);

            /// <summary>
            /// Makes <paramref name="enumerator"/> the active source enumerator and disposes the previous one, if any.
            /// </summary>
            private async Task SetEnumeratorAsync(IAsyncEnumerator<TSource> enumerator)
            {
                // Take ownership of the new enumerator before disposing the old one, so that
                // if the old enumerator's DisposeAsync throws, the new one is still reachable
                // through _enumerator and gets released by this iterator's DisposeAsync.
                var previous = _enumerator;
                _enumerator = enumerator;

                if (previous != null)
                {
                    await previous.DisposeAsync().ConfigureAwait(false);
                }
            }

            /// <summary>
            /// Creates the set of seen elements seeded with the active enumerator's current
            /// element, and publishes that element as the iterator's current value.
            /// </summary>
            private void StoreFirst()
            {
                var set = new Set<TSource>(_comparer);
                var element = _enumerator!.Current;
                set.Add(element);
                _current = element;
                _set = set;
            }

            /// <summary>
            /// Advances the active enumerator to the next element not yet seen.
            /// Returns false when the active source is exhausted.
            /// </summary>
            private async ValueTask<bool> GetNextAsync()
            {
                var set = _set;
                Debug.Assert(set != null);

                while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                {
                    var element = _enumerator.Current;

                    // Set.Add is used as the "first time seen" test for the element.
                    if (set!.Add(element))
                    {
                        _current = element;
                        return true;
                    }
                }

                return false;
            }

            protected sealed override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        _index = 0;

                        // Skip leading empty sources until one yields a first element.
                        for (var enumerable = GetEnumerable(0); enumerable != null; enumerable = GetEnumerable(_index))
                        {
                            ++_index;

                            var enumerator = enumerable.GetAsyncEnumerator(_cancellationToken);
                            await SetEnumeratorAsync(enumerator).ConfigureAwait(false);

                            if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                            {
                                StoreFirst();
                                _state = AsyncIteratorState.Iterating;
                                return true;
                            }
                        }

                        break;

                    case AsyncIteratorState.Iterating:
                        while (true)
                        {
                            if (await GetNextAsync().ConfigureAwait(false))
                            {
                                return true;
                            }

                            // Active source exhausted; move on to the next one, if any.
                            var enumerable = GetEnumerable(_index);
                            if (enumerable == null)
                            {
                                break;
                            }

                            await SetEnumeratorAsync(enumerable.GetAsyncEnumerator(_cancellationToken)).ConfigureAwait(false);
                            ++_index;
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }

            /// <summary>
            /// Eagerly enumerates every source and collects the distinct elements into a set.
            /// </summary>
            private async Task<Set<TSource>> FillSetAsync(CancellationToken cancellationToken)
            {
                cancellationToken.ThrowIfCancellationRequested();

                var set = new Set<TSource>(_comparer);

                for (var index = 0; ; ++index)
                {
                    var enumerable = GetEnumerable(index);
                    if (enumerable == null)
                    {
                        return set;
                    }

                    await foreach (var item in enumerable.WithCancellation(cancellationToken).ConfigureAwait(false))
                    {
                        set.Add(item);
                    }
                }
            }

            public async ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                return set.ToArray();
            }

            public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
            {
                var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                return set.ToList();
            }

            public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                // Counting requires enumerating every source, which is never cheap.
                if (onlyIfCheap)
                {
                    return new ValueTask<int>(-1);
                }

                return Core();

                async ValueTask<int> Core()
                {
                    var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                    return set.Count;
                }
            }
        }

        /// <summary>
        /// An iterator that yields distinct values from two <see cref="IAsyncEnumerable{T}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private sealed class UnionAsyncIterator2<TSource> : UnionAsyncIterator<TSource>
        {
            private readonly IAsyncEnumerable<TSource> _first;
            private readonly IAsyncEnumerable<TSource> _second;

            public UnionAsyncIterator2(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource>? comparer)
                : base(comparer)
            {
                _first = first;
                _second = second;
            }

            public override AsyncIteratorBase<TSource> Clone() => new UnionAsyncIterator2<TSource>(_first, _second, _comparer);

            internal override IAsyncEnumerable<TSource>? GetEnumerable(int index)
            {
                Debug.Assert(index >= 0 && index <= 2);

                return index switch
                {
                    0 => _first,
                    1 => _second,
                    _ => null,
                };
            }

            internal override UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next)
            {
                // The linked list is built tail-first: its head is the newest source (index 2).
                var sources = new SingleLinkedNode<IAsyncEnumerable<TSource>>(_first).Add(_second).Add(next);
                return new UnionAsyncIteratorN<TSource>(sources, 2, _comparer);
            }
        }

        /// <summary>
        /// An iterator that yields distinct values from three or more <see cref="IAsyncEnumerable{T}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private sealed class UnionAsyncIteratorN<TSource> : UnionAsyncIterator<TSource>
        {
            // Sources in reverse order: the head node is the most recently added source.
            private readonly SingleLinkedNode<IAsyncEnumerable<TSource>> _sources;

            // Enumeration index of the head node (i.e. number of sources minus one).
            private readonly int _headIndex;

            public UnionAsyncIteratorN(SingleLinkedNode<IAsyncEnumerable<TSource>> sources, int headIndex, IEqualityComparer<TSource>? comparer)
                : base(comparer)
            {
                Debug.Assert(headIndex >= 2);
                Debug.Assert(sources.GetCount() == headIndex + 1);

                _sources = sources;
                _headIndex = headIndex;
            }

            public override AsyncIteratorBase<TSource> Clone() => new UnionAsyncIteratorN<TSource>(_sources, _headIndex, _comparer);

            // Index 0 is the oldest source, which sits _headIndex nodes away from the head.
            internal override IAsyncEnumerable<TSource>? GetEnumerable(int index) =>
                index > _headIndex ? null : _sources.GetNode(_headIndex - index).Item;

            internal override UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next)
            {
                if (_headIndex == int.MaxValue - 2)
                {
                    // In the unlikely case of this many unions, if we produced a UnionAsyncIteratorN
                    // with a head index of int.MaxValue, the _index counter would overflow before it
                    // moved past that index. So we use the naïve approach of just having a left and
                    // right sequence.
                    return new UnionAsyncIterator2<TSource>(this, next, _comparer);
                }

                return new UnionAsyncIteratorN<TSource>(_sources.Add(next), _headIndex + 1, _comparer);
            }
        }
#endif // INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
    }
}