// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Produces the union of two async-enumerable sequences, passing a <c>null</c> comparer
        /// to the underlying set (presumably selecting the default equality comparer for
        /// <typeparamref name="TSource"/> - confirm against <c>Set</c>).
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
        /// <param name="first">The sequence whose distinct elements are yielded first.</param>
        /// <param name="second">The sequence whose not-yet-seen elements are yielded afterwards.</param>
        /// <returns>A sequence containing each distinct element of both inputs once.</returns>
        /// <remarks>
        /// Throws the exception created by <c>Error.ArgumentNull</c> when <paramref name="first"/>
        /// or <paramref name="second"/> is <c>null</c>.
        /// </remarks>
        public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
            {
                throw Error.ArgumentNull(nameof(first));
            }

            if (second == null)
            {
                throw Error.ArgumentNull(nameof(second));
            }

            return UnionCore(first, second, comparer: null);
        }

        /// <summary>
        /// Produces the union of two async-enumerable sequences, using the specified comparer
        /// to decide whether two elements are equal.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
        /// <param name="first">The sequence whose distinct elements are yielded first.</param>
        /// <param name="second">The sequence whose not-yet-seen elements are yielded afterwards.</param>
        /// <param name="comparer">The comparer handed to the underlying set; may be <c>null</c>.</param>
        /// <returns>A sequence containing each distinct element of both inputs once.</returns>
        /// <remarks>
        /// Throws the exception created by <c>Error.ArgumentNull</c> when <paramref name="first"/>
        /// or <paramref name="second"/> is <c>null</c>.
        /// </remarks>
        public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            if (first == null)
            {
                throw Error.ArgumentNull(nameof(first));
            }

            if (second == null)
            {
                throw Error.ArgumentNull(nameof(second));
            }

            return UnionCore(first, second, comparer);
        }

        /// <summary>
        /// Builds the union iterator. When <paramref name="first"/> is already a union iterator
        /// that uses an equal comparer, <paramref name="second"/> is appended to it instead of
        /// nesting a new iterator around it.
        /// </summary>
        private static IAsyncEnumerable<TSource> UnionCore<TSource>(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            return first is UnionAsyncIterator<TSource> union && AreEqualityComparersEqual(comparer, union._comparer)
                ? union.Union(second)
                : new UnionAsyncIterator2<TSource>(first, second, comparer);
        }

        /// <summary>
        /// Returns <c>true</c> when both comparers are the same reference (including both being
        /// <c>null</c>), or when both are non-null and <c>first.Equals(second)</c> holds.
        /// </summary>
        private static bool AreEqualityComparersEqual<TSource>(IEqualityComparer<TSource> first, IEqualityComparer<TSource> second)
        {
            return first == second || (first != null && second != null && first.Equals(second));
        }

        /// <summary>
        /// An iterator that yields distinct values from two or more <see cref="IAsyncEnumerable{TSource}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private abstract class UnionAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
        {
            internal readonly IEqualityComparer<TSource> _comparer;

            // Enumerator of the source currently being drained; null before iteration and after disposal.
            private IAsyncEnumerator<TSource> _enumerator;

            // Elements yielded so far; created when the first element is produced.
            private Set<TSource> _set;

            // Index of the next source to request from GetEnumerable.
            private int _index;

            protected UnionAsyncIterator(IEqualityComparer<TSource> comparer)
            {
                _comparer = comparer;
            }

            /// <summary>
            /// Disposes the current source enumerator (if any), drops the set of seen elements,
            /// and then disposes the base iterator.
            /// </summary>
            public sealed override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                    _set = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Returns the source at the given position, or <c>null</c> when <paramref name="index"/>
            /// is past the last source.
            /// </summary>
            internal abstract IAsyncEnumerable<TSource> GetEnumerable(int index);

            /// <summary>
            /// Returns a new union iterator over this iterator's sources followed by <paramref name="next"/>.
            /// </summary>
            internal abstract UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next);

            /// <summary>
            /// Disposes the enumerator currently held (if any) and replaces it with <paramref name="enumerator"/>.
            /// </summary>
            private async Task SetEnumeratorAsync(IAsyncEnumerator<TSource> enumerator)
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                }

                _enumerator = enumerator;
            }

            /// <summary>
            /// Creates the set, records the current element of <see cref="_enumerator"/> in it and
            /// publishes that element as the iterator's current value.
            /// </summary>
            private void StoreFirst()
            {
                var set = new Set<TSource>(_comparer);
                var element = _enumerator.Current;
                set.Add(element);
                current = element;
                _set = set;
            }

            /// <summary>
            /// Advances the current enumerator until it produces an element that was not seen before.
            /// </summary>
            /// <returns><c>true</c> if such an element was found and published; <c>false</c> if the current source is exhausted.</returns>
            private async Task<bool> GetNextAsync()
            {
                var set = _set;
                Debug.Assert(set != null);

                while (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                {
                    var element = _enumerator.Current;

                    if (set.Add(element))
                    {
                        current = element;
                        return true;
                    }
                }

                return false;
            }

            protected sealed override async ValueTask<bool> MoveNextCore()
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _index = 0;

                        for (var enumerable = GetEnumerable(0); enumerable != null; enumerable = GetEnumerable(_index))
                        {
                            ++_index;

                            // Register the enumerator in _enumerator *before* advancing it, exactly as the
                            // Iterating branch does. An enumerator whose source turns out to be empty (or whose
                            // first MoveNextAsync throws) is then disposed by the next SetEnumeratorAsync call
                            // or by DisposeAsync instead of being leaked.
                            await SetEnumeratorAsync(enumerable.GetAsyncEnumerator(cancellationToken)).ConfigureAwait(false);

                            if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                            {
                                StoreFirst();

                                state = AsyncIteratorState.Iterating;
                                return true;
                            }
                        }

                        break;

                    case AsyncIteratorState.Iterating:
                        while (true)
                        {
                            if (await GetNextAsync().ConfigureAwait(false))
                            {
                                return true;
                            }

                            // Current source is exhausted; move on to the next one, if any.
                            var enumerable = GetEnumerable(_index);
                            if (enumerable == null)
                            {
                                break;
                            }

                            await SetEnumeratorAsync(enumerable.GetAsyncEnumerator(cancellationToken)).ConfigureAwait(false);
                            ++_index;
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }

            /// <summary>
            /// Drains every source in order into a fresh set, disposing each source enumerator
            /// once it has been consumed.
            /// </summary>
            private async Task<Set<TSource>> FillSetAsync(CancellationToken cancellationToken)
            {
                var set = new Set<TSource>(_comparer);

                for (var index = 0; ; ++index)
                {
                    var enumerable = GetEnumerable(index);
                    if (enumerable == null)
                    {
                        return set;
                    }

                    var e = enumerable.GetAsyncEnumerator(cancellationToken);

                    try
                    {
                        while (await e.MoveNextAsync().ConfigureAwait(false))
                        {
                            set.Add(e.Current);
                        }
                    }
                    finally
                    {
                        await e.DisposeAsync().ConfigureAwait(false);
                    }
                }
            }

            // The three methods below await FillSetAsync rather than chaining ContinueWith(t => t.Result...):
            // reading Result in a continuation wraps failures in AggregateException and reports cancellation
            // as a fault, whereas await propagates the original exception / cancellation to the caller.

            /// <summary>Materializes the union of all sources into an array.</summary>
            public async Task<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                return set.ToArray();
            }

            /// <summary>Materializes the union of all sources into a list.</summary>
            public async Task<List<TSource>> ToListAsync(CancellationToken cancellationToken)
            {
                var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                return set.ToList();
            }

            /// <summary>
            /// Returns the number of distinct elements across all sources, or the value of
            /// <c>TaskExt.MinusOne</c> when <paramref name="onlyIfCheap"/> is <c>true</c>, because
            /// counting requires draining every source.
            /// </summary>
            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                return onlyIfCheap ? TaskExt.MinusOne : CountAsync();

                async Task<int> CountAsync()
                {
                    var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                    return set.Count;
                }
            }
        }

        /// <summary>
        /// An iterator that yields distinct values from two <see cref="IAsyncEnumerable{TSource}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private sealed class UnionAsyncIterator2<TSource> : UnionAsyncIterator<TSource>
        {
            private readonly IAsyncEnumerable<TSource> _first;
            private readonly IAsyncEnumerable<TSource> _second;

            public UnionAsyncIterator2(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
                : base(comparer)
            {
                Debug.Assert(first != null);
                Debug.Assert(second != null);

                _first = first;
                _second = second;
            }

            public override AsyncIterator<TSource> Clone() => new UnionAsyncIterator2<TSource>(_first, _second, _comparer);

            internal override IAsyncEnumerable<TSource> GetEnumerable(int index)
            {
                Debug.Assert(index >= 0 && index <= 2);

                switch (index)
                {
                    case 0:
                        return _first;
                    case 1:
                        return _second;
                    default:
                        return null;
                }
            }

            internal override UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next)
            {
                var sources = new SingleLinkedNode<IAsyncEnumerable<TSource>>(_first).Add(_second).Add(next);
                return new UnionAsyncIteratorN<TSource>(sources, 2, _comparer);
            }
        }

        /// <summary>
        /// An iterator that yields distinct values from three or more <see cref="IAsyncEnumerable{TSource}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private sealed class UnionAsyncIteratorN<TSource> : UnionAsyncIterator<TSource>
        {
            // Linked list of sources; GetEnumerable maps source index i to node (_headIndex - i).
            private readonly SingleLinkedNode<IAsyncEnumerable<TSource>> _sources;

            // Index of the last source (i.e. number of sources minus one).
            private readonly int _headIndex;

            public UnionAsyncIteratorN(SingleLinkedNode<IAsyncEnumerable<TSource>> sources, int headIndex, IEqualityComparer<TSource> comparer)
                : base(comparer)
            {
                Debug.Assert(headIndex >= 2);
                Debug.Assert(sources?.GetCount() == headIndex + 1);

                _sources = sources;
                _headIndex = headIndex;
            }

            public override AsyncIterator<TSource> Clone() => new UnionAsyncIteratorN<TSource>(_sources, _headIndex, _comparer);

            internal override IAsyncEnumerable<TSource> GetEnumerable(int index) => index > _headIndex ? null : _sources.GetNode(_headIndex - index).Item;

            internal override UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next)
            {
                if (_headIndex == int.MaxValue - 2)
                {
                    // In the unlikely case of this many unions, if we produced a UnionAsyncIteratorN
                    // with int.MaxValue then state would overflow before it matched its index.
                    // So we use the naïve approach of just having a left and right sequence.
                    return new UnionAsyncIterator2<TSource>(this, next, _comparer);
                }

                return new UnionAsyncIteratorN<TSource>(_sources.Add(next), _headIndex + 1, _comparer);
            }
        }
    }
}