// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information. 
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Returns the elements of the specified sequence or the type parameter's default value in a singleton sequence if the sequence is empty.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence (if any), whose default value will be taken if the sequence is empty.</typeparam>
        /// <param name="source">The sequence to return a default value for if it is empty.</param>
        /// <returns>An async-enumerable sequence that contains the default value for the <typeparamref name="TSource"/> type if the source is empty; otherwise, the elements of the source itself.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        /// <remarks>
        /// Delegates to the overload that takes an explicit default value, passing <c>default(TSource)</c>.
        /// </remarks>
        public static IAsyncEnumerable<TSource> DefaultIfEmpty<TSource>(this IAsyncEnumerable<TSource> source) =>
            // `default!` suppresses the nullable-reference warning: for reference types default(TSource)
            // is null, which is exactly the fill value this overload is meant to produce.
            DefaultIfEmpty(source, default!);
        /// <summary>
        /// Returns the elements of the specified sequence or the specified value in a singleton sequence if the sequence is empty.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence (if any), and the specified default value which will be taken if the sequence is empty.</typeparam>
        /// <param name="source">The sequence to return the specified value for if it is empty.</param>
        /// <param name="defaultValue">The value to return if the sequence is empty.</param>
        /// <returns>An async-enumerable sequence that contains the specified default value if the source is empty; otherwise, the elements of the source itself.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        /// <remarks>
        /// The null check runs eagerly when this method is called. The returned iterator only stores
        /// <paramref name="source"/>; the source is not enumerated until the result is enumerated.
        /// </remarks>
        public static IAsyncEnumerable<TSource> DefaultIfEmpty<TSource>(this IAsyncEnumerable<TSource> source, TSource defaultValue)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));

            return new DefaultIfEmptyAsyncIterator<TSource>(source, defaultValue);
        }
        /// <summary>
        /// Async iterator behind <c>DefaultIfEmpty</c>: yields every element of the source sequence, or exactly
        /// one default value when the source produces no elements.
        /// </summary>
        /// <typeparam name="TSource">The element type of the source sequence.</typeparam>
        /// <remarks>
        /// The <c>IAsyncIListProvider</c> members (<c>ToArrayAsync</c>, <c>ToListAsync</c>, <c>GetCountAsync</c>)
        /// compute their results directly from the source rather than by stepping through <c>MoveNextCore</c>.
        /// </remarks>
        private sealed class DefaultIfEmptyAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
        {
            private readonly IAsyncEnumerable<TSource> _source;
            private readonly TSource _defaultValue;

            // Enumerator over _source. It stays null until the first MoveNextCore call and is reset to null
            // once it has been disposed.
            private IAsyncEnumerator<TSource>? _enumerator;

            /// <summary>
            /// Captures the source sequence and the value to yield if it turns out to be empty.
            /// </summary>
            /// <param name="source">The sequence to forward elements from.</param>
            /// <param name="defaultValue">The single value yielded when <paramref name="source"/> is empty.</param>
            public DefaultIfEmptyAsyncIterator(IAsyncEnumerable<TSource> source, TSource defaultValue)
            {
                _source = source;
                _defaultValue = defaultValue;
            }

            /// <summary>
            /// Creates a new, independent iterator over the same source and default value.
            /// </summary>
            public override AsyncIteratorBase<TSource> Clone()
            {
                return new DefaultIfEmptyAsyncIterator<TSource>(_source, _defaultValue);
            }

            /// <summary>
            /// Disposes the source enumerator if one is still open, then runs the base class's disposal.
            /// </summary>
            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);

                    // Clear the field so a repeated DisposeAsync does not dispose the same enumerator twice.
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Advances the iterator.
            /// <list type="bullet">
            /// <item>First call: opens the source enumerator. It yields the source's first element, or the
            /// default value if the source is empty.</item>
            /// <item>Later calls: forward to the source enumerator.</item>
            /// </list>
            /// </summary>
            /// <returns><c>true</c> if <c>_current</c> now holds the next element; <c>false</c> once the sequence is finished.</returns>
            protected override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                        if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                        {
                            _current = _enumerator.Current;
                            _state = AsyncIteratorState.Iterating;
                        }
                        else
                        {
                            // Empty source: yield the default value exactly once. The source enumerator is
                            // released immediately. After the move to Disposed, a further MoveNextCore call
                            // matches neither case below and goes straight to the dispose/return-false path.
                            _current = _defaultValue;
                            await _enumerator.DisposeAsync().ConfigureAwait(false);
                            _enumerator = null;

                            _state = AsyncIteratorState.Disposed;
                        }

                        return true;

                    case AsyncIteratorState.Iterating:
                        // Iterating is only entered from the Allocated case after _enumerator was assigned,
                        // so the null-forgiving operator is safe here.
                        if (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            _current = _enumerator.Current;
                            return true;
                        }

                        break;
                }

                // The source is exhausted, or the iterator is in a state not handled above: release resources and stop.
                await DisposeAsync().ConfigureAwait(false);
                return false;
            }

            /// <summary>
            /// Materializes the source into an array. An empty result is replaced by a one-element array holding
            /// the default value.
            /// </summary>
            /// <param name="cancellationToken">Token forwarded to the source materialization.</param>
            public async ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                var array = await _source.ToArrayAsync(cancellationToken).ConfigureAwait(false);
                return array.Length == 0 ? new[] { _defaultValue } : array;
            }

            /// <summary>
            /// Materializes the source into a list. If that list is empty, the default value is appended to it.
            /// </summary>
            /// <param name="cancellationToken">Token forwarded to the source materialization.</param>
            public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
            {
                var list = await _source.ToListAsync(cancellationToken).ConfigureAwait(false);
                if (list.Count == 0)
                {
                    list.Add(_defaultValue);
                }

                return list;
            }

            /// <summary>
            /// Returns the number of elements this iterator would yield: the source count, except that an empty
            /// source counts as 1 (the single default value).
            /// </summary>
            /// <param name="onlyIfCheap">
            /// When <c>true</c>, only these paths are used to count:
            /// <list type="bullet">
            /// <item>the source also implements <c>ICollection&lt;TSource&gt;</c> or <c>ICollection</c>;</item>
            /// <item>the source is an <c>IAsyncIListProvider</c>, whose cheap count is used.</item>
            /// </list>
            /// If neither applies, -1 is returned.
            /// </param>
            /// <param name="cancellationToken">Token forwarded to the underlying count.</param>
            public async ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                int count;

                if (!onlyIfCheap || _source is ICollection<TSource> || _source is ICollection)
                {
                    count = await _source.CountAsync(cancellationToken).ConfigureAwait(false);
                }
                else if (_source is IAsyncIListProvider<TSource> listProv)
                {
                    count = await listProv.GetCountAsync(onlyIfCheap: true, cancellationToken).ConfigureAwait(false);
                }
                else
                {
                    // No cheap way to count: report "unknown".
                    count = -1;
                }

                // Only an empty source is adjusted. A -1 "unknown" result passes through unchanged.
                return count == 0 ? 1 : count;
            }
        }
    }
}