// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information. 
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Produces the set union of two sequences by using the default equality comparer.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
        /// <param name="first">An async-enumerable sequence whose distinct elements form the first set for the union.</param>
        /// <param name="second">An async-enumerable sequence whose distinct elements form the second set for the union.</param>
        /// <returns>An async-enumerable sequence that contains the elements from both input sequences, excluding duplicates.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            // Forward to the comparer-taking overload; passing no comparer means default equality.
            return Union(first, second, comparer: null);
        }
        /// <summary>
        /// Produces the set union of two sequences by using a specified equality comparer.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
        /// <param name="first">An async-enumerable sequence whose distinct elements form the first set for the union.</param>
        /// <param name="second">An async-enumerable sequence whose distinct elements form the second set for the union.</param>
        /// <param name="comparer">The equality comparer to compare values.</param>
        /// <returns>An async-enumerable sequence that contains the elements from both input sequences, excluding duplicates.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource>? comparer)
        {
            if (first is null)
                throw Error.ArgumentNull(nameof(first));
            if (second is null)
                throw Error.ArgumentNull(nameof(second));

            // A chain like a.Union(b).Union(c) with matching comparers is flattened: the
            // existing union iterator is extended with the new source instead of being wrapped.
            if (first is UnionAsyncIterator<TSource> existingUnion && AreEqualityComparersEqual(comparer, existingUnion._comparer))
            {
                return existingUnion.Union(second);
            }

            return new UnionAsyncIterator2<TSource>(first, second, comparer);
        }
        /// <summary>
        /// Determines whether two (possibly null) equality comparers may be treated as the same comparer.
        /// </summary>
        /// <returns>
        /// <c>true</c> when both references are identical (including both null), or both are
        /// non-null and <paramref name="first"/> reports itself equal to <paramref name="second"/>;
        /// otherwise <c>false</c>.
        /// </returns>
        private static bool AreEqualityComparersEqual<TSource>(IEqualityComparer<TSource>? first, IEqualityComparer<TSource>? second)
        {
            if (first == second)
            {
                return true;
            }

            // The references differ, so if either side is null the other one is not.
            if (first is null || second is null)
            {
                return false;
            }

            return first.Equals(second);
        }
        /// <summary>
        /// An iterator that yields distinct values from two or more <see cref="IAsyncEnumerable{T}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private abstract class UnionAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
        {
            // Comparer used for duplicate detection; internal so Union() can decide whether
            // an existing union iterator may be extended rather than wrapped.
            internal readonly IEqualityComparer<TSource>? _comparer;

            // Enumerator of the source currently being consumed; this iterator owns and disposes it.
            private IAsyncEnumerator<TSource>? _enumerator;

            // Elements already yielded; created by StoreFirst when the first element is produced.
            private Set<TSource>? _set;

            // Index of the next source to open via GetEnumerable.
            private int _index;

            protected UnionAsyncIterator(IEqualityComparer<TSource>? comparer)
            {
                _comparer = comparer;
            }

            /// <summary>
            /// Disposes the currently held source enumerator (if any), drops the set of seen
            /// elements, and then disposes the base iterator.
            /// </summary>
            public sealed override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                    _set = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Gets the source at position <paramref name="index"/>, or <c>null</c> once
            /// <paramref name="index"/> is past the last source.
            /// </summary>
            internal abstract IAsyncEnumerable<TSource>? GetEnumerable(int index);

            /// <summary>
            /// Returns a union iterator over this iterator's sources followed by <paramref name="next"/>.
            /// </summary>
            internal abstract UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next);

            /// <summary>
            /// Makes <paramref name="enumerator"/> the current source enumerator, disposing the
            /// previously held one first so that at most one source enumerator is open at a time.
            /// </summary>
            private async Task SetEnumeratorAsync(IAsyncEnumerator<TSource> enumerator)
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                }

                _enumerator = enumerator;
            }

            /// <summary>
            /// Creates the set of seen elements, seeds it with the current element of
            /// <c>_enumerator</c>, and publishes that element as <c>_current</c>.
            /// </summary>
            private void StoreFirst()
            {
                var set = new Set<TSource>(_comparer);
                var element = _enumerator!.Current;
                set.Add(element);
                _current = element;
                _set = set;
            }

            /// <summary>
            /// Advances the current source enumerator until it produces an element that has not
            /// been seen before and publishes it as <c>_current</c>.
            /// </summary>
            /// <returns><c>true</c> if a new element was found; <c>false</c> when the current source is exhausted.</returns>
            private async ValueTask<bool> GetNextAsync()
            {
                var set = _set;
                Debug.Assert(set != null);

                while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                {
                    var element = _enumerator.Current;
                    if (set!.Add(element))
                    {
                        _current = element;
                        return true;
                    }
                }

                return false;
            }

            protected sealed override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        _index = 0;

                        // Skip over leading empty sources until one yields its first element.
                        for (var enumerable = GetEnumerable(0); enumerable != null; enumerable = GetEnumerable(_index))
                        {
                            ++_index;

                            var enumerator = enumerable.GetAsyncEnumerator(_cancellationToken);

                            // Take ownership BEFORE the first MoveNextAsync so that an empty source,
                            // or one whose MoveNextAsync throws, is still disposed: SetEnumeratorAsync
                            // disposes the previously held (empty) source, and DisposeAsync disposes
                            // whichever enumerator is held last.
                            await SetEnumeratorAsync(enumerator).ConfigureAwait(false);

                            if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                            {
                                StoreFirst();

                                _state = AsyncIteratorState.Iterating;
                                return true;
                            }
                        }

                        break;

                    case AsyncIteratorState.Iterating:
                        while (true)
                        {
                            if (await GetNextAsync().ConfigureAwait(false))
                            {
                                return true;
                            }

                            // Current source exhausted: move on to the next one, if any.
                            var enumerable = GetEnumerable(_index);
                            if (enumerable == null)
                            {
                                break;
                            }

                            await SetEnumeratorAsync(enumerable.GetAsyncEnumerator(_cancellationToken)).ConfigureAwait(false);
                            ++_index;
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }

            /// <summary>
            /// Eagerly enumerates every source and collects the distinct elements into a set.
            /// </summary>
            private async Task<Set<TSource>> FillSetAsync(CancellationToken cancellationToken)
            {
                cancellationToken.ThrowIfCancellationRequested();

                var set = new Set<TSource>(_comparer);

                for (var index = 0; ; ++index)
                {
                    var enumerable = GetEnumerable(index);
                    if (enumerable == null)
                    {
                        return set;
                    }

                    await foreach (var item in enumerable.WithCancellation(cancellationToken).ConfigureAwait(false))
                    {
                        set.Add(item);
                    }
                }
            }

            public async ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                return set.ToArray();
            }

            public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
            {
                var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                return set.ToList();
            }

            /// <summary>
            /// Returns the number of distinct elements, or -1 when <paramref name="onlyIfCheap"/>
            /// is set, because counting requires enumerating every source.
            /// </summary>
            public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                if (onlyIfCheap)
                {
                    return new ValueTask<int>(-1);
                }

                return Core();

                async ValueTask<int> Core()
                {
                    var set = await FillSetAsync(cancellationToken).ConfigureAwait(false);
                    return set.Count;
                }
            }
        }
        /// <summary>
        /// An iterator that yields distinct values from exactly two <see cref="IAsyncEnumerable{T}"/> sources.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private sealed class UnionAsyncIterator2<TSource> : UnionAsyncIterator<TSource>
        {
            // Left-hand source of the union (logical index 0).
            private readonly IAsyncEnumerable<TSource> _first;

            // Right-hand source of the union (logical index 1).
            private readonly IAsyncEnumerable<TSource> _second;

            public UnionAsyncIterator2(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource>? comparer)
                : base(comparer)
            {
                _first = first;
                _second = second;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new UnionAsyncIterator2<TSource>(_first, _second, _comparer);
            }

            internal override IAsyncEnumerable<TSource>? GetEnumerable(int index)
            {
                Debug.Assert(index >= 0 && index <= 2);

                if (index == 0)
                {
                    return _first;
                }

                // Index 2 signals "no more sources".
                return index == 1 ? _second : null;
            }

            internal override UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next)
            {
                // Chain the three sources in order (_first, _second, next); the chain holds three
                // sources, so the N-ary iterator is created with head index 2.
                var withFirst = new SingleLinkedNode<IAsyncEnumerable<TSource>>(_first);
                var withSecond = withFirst.Add(_second);
                var allSources = withSecond.Add(next);

                return new UnionAsyncIteratorN<TSource>(allSources, 2, _comparer);
            }
        }
        /// <summary>
        /// An iterator that yields distinct values from three or more <see cref="IAsyncEnumerable{T}"/> sources.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private sealed class UnionAsyncIteratorN<TSource> : UnionAsyncIterator<TSource>
        {
            // Chain of sources; the head node holds the most recently added source.
            private readonly SingleLinkedNode<IAsyncEnumerable<TSource>> _sources;

            // Logical index of the head node's source; the chain holds _headIndex + 1 sources.
            private readonly int _headIndex;

            public UnionAsyncIteratorN(SingleLinkedNode<IAsyncEnumerable<TSource>> sources, int headIndex, IEqualityComparer<TSource>? comparer)
                : base(comparer)
            {
                Debug.Assert(headIndex >= 2);
                Debug.Assert(sources.GetCount() == headIndex + 1);

                _sources = sources;
                _headIndex = headIndex;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new UnionAsyncIteratorN<TSource>(_sources, _headIndex, _comparer);
            }

            internal override IAsyncEnumerable<TSource>? GetEnumerable(int index)
            {
                if (index > _headIndex)
                {
                    return null;
                }

                // Logical index 0 is the oldest source, which sits _headIndex links away from the head.
                var distanceFromHead = _headIndex - index;
                return _sources.GetNode(distanceFromHead).Item;
            }

            internal override UnionAsyncIterator<TSource> Union(IAsyncEnumerable<TSource> next)
            {
                if (_headIndex != int.MaxValue - 2)
                {
                    // The new source becomes the head of the chain with the next logical index.
                    return new UnionAsyncIteratorN<TSource>(_sources.Add(next), _headIndex + 1, _comparer);
                }

                // In the unlikely case of this many unions, producing another UnionAsyncIteratorN
                // would push the head index so close to int.MaxValue that the enumeration index
                // could overflow before reaching it. So use the naïve approach instead: treat this
                // iterator as the left sequence and the new source as the right sequence.
                return new UnionAsyncIterator2<TSource>(this, next, _comparer);
            }
        }
    }
}