// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information. 
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        // NOTE(review): the generic type-parameter list seems to be missing from this signature; the
        // previous doc comment also carried the line "The type of the elements in the source sequence."
        // with no tag around it. Confirm against the upstream file before relying on the signature text.

        /// <summary>
        /// Filters the elements of an async-enumerable sequence based on a predicate.
        /// </summary>
        /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
        /// <param name="predicate">A function to test each source element for a condition.</param>
        /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        public static IAsyncEnumerable Where(this IAsyncEnumerable source, Func predicate)
        {
            if (source == null)
            {
                throw Error.ArgumentNull(nameof(source));
            }

            if (predicate == null)
            {
                throw Error.ArgumentNull(nameof(predicate));
            }

            // A source that is itself an AsyncIteratorBase gets to build the filtered sequence
            // through its own Where override; any other source is wrapped in a new iterator.
            // TODO: Can we add array/list optimizations here, does it make sense?
            return source is AsyncIteratorBase iterator
                ? iterator.Where(predicate)
                : new WhereEnumerableAsyncIterator(source, predicate);
        }
        // NOTE(review): the generic type-parameter list seems to be missing from this signature; the
        // previous doc comment also carried the line "The type of the elements in the source sequence."
        // with no tag around it. Confirm against the upstream file before relying on the signature text.

        /// <summary>
        /// Filters the elements of an async-enumerable sequence based on a predicate by incorporating the element's index.
        /// </summary>
        /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
        /// <param name="predicate">A function to test each source element for a condition; the second parameter of the function represents the index of the source element.</param>
        /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        public static IAsyncEnumerable Where(this IAsyncEnumerable source, Func predicate)
        {
            if (source == null)
            {
                throw Error.ArgumentNull(nameof(source));
            }

            if (predicate == null)
            {
                throw Error.ArgumentNull(nameof(predicate));
            }

            return Create(Core);

            // Enumerates the source and yields each element the predicate accepts,
            // passing the element's zero-based position alongside it.
            async IAsyncEnumerator Core(CancellationToken cancellationToken)
            {
                // Starts at -1 so that the first increment produces position 0.
                var position = -1;

                await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    // Checked arithmetic: overflowing the int counter throws instead of wrapping around.
                    position = checked(position + 1);

                    if (!predicate(item, position))
                    {
                        continue;
                    }

                    yield return item;
                }
            }
        }
        /// <summary>
        /// Filters the elements of an async-enumerable sequence using a predicate whose result is awaited.
        /// </summary>
        /// <param name="source">The async-enumerable sequence to filter.</param>
        /// <param name="predicate">The function invoked for each element; its awaited result decides whether the element is kept.</param>
        /// <returns>
        /// The result of the source's own <c>Where</c> method when the source is an <c>AsyncIteratorBase</c>;
        /// otherwise a new <c>WhereEnumerableAsyncIteratorWithTask</c> wrapping the source.
        /// </returns>
        internal static IAsyncEnumerable WhereAwaitCore(this IAsyncEnumerable source, Func> predicate)
        {
            if (source == null)
            {
                throw Error.ArgumentNull(nameof(source));
            }

            if (predicate == null)
            {
                throw Error.ArgumentNull(nameof(predicate));
            }

            // TODO: Can we add array/list optimizations here, does it make sense?
            return source is AsyncIteratorBase iterator
                ? iterator.Where(predicate)
                : new WhereEnumerableAsyncIteratorWithTask(source, predicate);
        }
#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Filters the elements of an async-enumerable sequence using an awaited predicate that also
        /// receives a cancellation token.
        /// </summary>
        /// <param name="source">The async-enumerable sequence to filter.</param>
        /// <param name="predicate">The function invoked for each element; its awaited result decides whether the element is kept.</param>
        /// <returns>
        /// The result of the source's own <c>Where</c> method when the source is an <c>AsyncIteratorBase</c>;
        /// otherwise a new <c>WhereEnumerableAsyncIteratorWithTaskAndCancellation</c> wrapping the source.
        /// </returns>
        internal static IAsyncEnumerable WhereAwaitWithCancellationCore(this IAsyncEnumerable source, Func> predicate)
        {
            if (source == null)
            {
                throw Error.ArgumentNull(nameof(source));
            }

            if (predicate == null)
            {
                throw Error.ArgumentNull(nameof(predicate));
            }

            // TODO: Can we add array/list optimizations here, does it make sense?
            return source is AsyncIteratorBase iterator
                ? iterator.Where(predicate)
                : new WhereEnumerableAsyncIteratorWithTaskAndCancellation(source, predicate);
        }
#endif
        /// <summary>
        /// Filters the elements of an async-enumerable sequence using an awaited predicate that
        /// receives each element together with its zero-based index.
        /// </summary>
        /// <param name="source">The async-enumerable sequence to filter.</param>
        /// <param name="predicate">The function invoked with each element and its index; its awaited result decides whether the element is kept.</param>
        /// <returns>The sequence produced by passing the local iterator function to <c>Create</c>.</returns>
        internal static IAsyncEnumerable WhereAwaitCore(this IAsyncEnumerable source, Func> predicate)
        {
            if (source == null)
            {
                throw Error.ArgumentNull(nameof(source));
            }

            if (predicate == null)
            {
                throw Error.ArgumentNull(nameof(predicate));
            }

            return Create(Core);

            // Enumerates the source and yields each element for which the awaited predicate is true.
            async IAsyncEnumerator Core(CancellationToken cancellationToken)
            {
                // Starts at -1 so that the first increment produces index 0.
                var position = -1;

                await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    // Checked arithmetic: overflowing the int counter throws instead of wrapping around.
                    position = checked(position + 1);

                    var keep = await predicate(item, position).ConfigureAwait(false);
                    if (!keep)
                    {
                        continue;
                    }

                    yield return item;
                }
            }
        }
#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Filters the elements of an async-enumerable sequence using an awaited predicate that
        /// receives each element, its zero-based index, and the enumeration's cancellation token.
        /// </summary>
        /// <param name="source">The async-enumerable sequence to filter.</param>
        /// <param name="predicate">The function invoked with each element, its index, and a cancellation token; its awaited result decides whether the element is kept.</param>
        /// <returns>The sequence produced by passing the local iterator function to <c>Create</c>.</returns>
        internal static IAsyncEnumerable WhereAwaitWithCancellationCore(this IAsyncEnumerable source, Func> predicate)
        {
            if (source == null)
            {
                throw Error.ArgumentNull(nameof(source));
            }

            if (predicate == null)
            {
                throw Error.ArgumentNull(nameof(predicate));
            }

            return Create(Core);

            // Enumerates the source and yields each element for which the awaited predicate is true.
            async IAsyncEnumerator Core(CancellationToken cancellationToken)
            {
                // Starts at -1 so that the first increment produces index 0.
                var position = -1;

                await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    // Checked arithmetic: overflowing the int counter throws instead of wrapping around.
                    position = checked(position + 1);

                    // The token used to enumerate the source is also handed to the predicate.
                    var keep = await predicate(item, position, cancellationToken).ConfigureAwait(false);
                    if (!keep)
                    {
                        continue;
                    }

                    yield return item;
                }
            }
        }
#endif
        /// <summary>
        /// Async iterator that yields the elements of a source sequence for which a synchronous
        /// predicate returns true.
        /// </summary>
        internal sealed class WhereEnumerableAsyncIterator : AsyncIterator
        {
            private readonly Func _predicate;
            private readonly IAsyncEnumerable _source;

            // Obtained from the source on the first MoveNextCore call; cleared again by DisposeAsync.
            private IAsyncEnumerator? _enumerator;

            public WhereEnumerableAsyncIterator(IAsyncEnumerable source, Func predicate)
            {
                _predicate = predicate;
                _source = source;
            }

            /// <summary>Creates a new iterator over the same source and predicate.</summary>
            public override AsyncIteratorBase Clone() => new WhereEnumerableAsyncIterator(_source, _predicate);

            /// <summary>Disposes the source enumerator, if one was created, and then the base iterator.</summary>
            public override async ValueTask DisposeAsync()
            {
                var enumerator = _enumerator;
                if (enumerator != null)
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Returns a <c>WhereSelectEnumerableAsyncIterator</c> over the same source that carries
            /// this iterator's predicate together with the given selector.
            /// </summary>
            public override IAsyncEnumerable Select(Func selector) =>
                new WhereSelectEnumerableAsyncIterator(_source, _predicate, selector);

            /// <summary>
            /// Returns a new iterator over the same source whose predicate is the result of
            /// <c>CombinePredicates</c> applied to the existing predicate and the given one.
            /// </summary>
            public override IAsyncEnumerable Where(Func predicate) =>
                new WhereEnumerableAsyncIterator(_source, CombinePredicates(_predicate, predicate));

            /// <summary>
            /// Advances to the next source element accepted by the predicate, storing it in <c>_current</c>.
            /// Disposes the iterator when the source is exhausted.
            /// </summary>
            protected override async ValueTask MoveNextCore()
            {
                // First call: obtain the source enumerator, then continue straight into iteration.
                if (_state == AsyncIteratorState.Allocated)
                {
                    _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                    _state = AsyncIteratorState.Iterating;
                }

                // Any state other than Iterating produces no further elements.
                if (_state != AsyncIteratorState.Iterating)
                {
                    return false;
                }

                while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                {
                    var candidate = _enumerator.Current;
                    if (_predicate(candidate))
                    {
                        _current = candidate;
                        return true;
                    }
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }
        /// <summary>
        /// Async iterator that yields the elements of a source sequence for which an awaited
        /// predicate evaluates to true.
        /// </summary>
        internal sealed class WhereEnumerableAsyncIteratorWithTask : AsyncIterator
        {
            private readonly Func> _predicate;
            private readonly IAsyncEnumerable _source;

            // Obtained from the source on the first MoveNextCore call; cleared again by DisposeAsync.
            private IAsyncEnumerator? _enumerator;

            public WhereEnumerableAsyncIteratorWithTask(IAsyncEnumerable source, Func> predicate)
            {
                _predicate = predicate;
                _source = source;
            }

            /// <summary>Creates a new iterator over the same source and predicate.</summary>
            public override AsyncIteratorBase Clone() => new WhereEnumerableAsyncIteratorWithTask(_source, _predicate);

            /// <summary>Disposes the source enumerator, if one was created, and then the base iterator.</summary>
            public override async ValueTask DisposeAsync()
            {
                var enumerator = _enumerator;
                if (enumerator != null)
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Returns a new iterator over the same source whose predicate is the result of
            /// <c>CombinePredicates</c> applied to the existing predicate and the given one.
            /// </summary>
            public override IAsyncEnumerable Where(Func> predicate) =>
                new WhereEnumerableAsyncIteratorWithTask(_source, CombinePredicates(_predicate, predicate));

            /// <summary>
            /// Advances to the next source element accepted by the awaited predicate, storing it in
            /// <c>_current</c>. Disposes the iterator when the source is exhausted.
            /// </summary>
            protected override async ValueTask MoveNextCore()
            {
                // First call: obtain the source enumerator, then continue straight into iteration.
                if (_state == AsyncIteratorState.Allocated)
                {
                    _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                    _state = AsyncIteratorState.Iterating;
                }

                // Any state other than Iterating produces no further elements.
                if (_state != AsyncIteratorState.Iterating)
                {
                    return false;
                }

                while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                {
                    var candidate = _enumerator.Current;
                    if (await _predicate(candidate).ConfigureAwait(false))
                    {
                        _current = candidate;
                        return true;
                    }
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }
#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Async iterator that yields the elements of a source sequence for which an awaited
        /// predicate evaluates to true; the predicate is also given the iterator's cancellation token.
        /// </summary>
        internal sealed class WhereEnumerableAsyncIteratorWithTaskAndCancellation : AsyncIterator
        {
            private readonly Func> _predicate;
            private readonly IAsyncEnumerable _source;

            // Obtained from the source on the first MoveNextCore call; cleared again by DisposeAsync.
            private IAsyncEnumerator? _enumerator;

            public WhereEnumerableAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable source, Func> predicate)
            {
                _predicate = predicate;
                _source = source;
            }

            /// <summary>Creates a new iterator over the same source and predicate.</summary>
            public override AsyncIteratorBase Clone() => new WhereEnumerableAsyncIteratorWithTaskAndCancellation(_source, _predicate);

            /// <summary>Disposes the source enumerator, if one was created, and then the base iterator.</summary>
            public override async ValueTask DisposeAsync()
            {
                var enumerator = _enumerator;
                if (enumerator != null)
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Returns a new iterator over the same source whose predicate is the result of
            /// <c>CombinePredicates</c> applied to the existing predicate and the given one.
            /// </summary>
            public override IAsyncEnumerable Where(Func> predicate) =>
                new WhereEnumerableAsyncIteratorWithTaskAndCancellation(_source, CombinePredicates(_predicate, predicate));

            /// <summary>
            /// Advances to the next source element accepted by the awaited predicate, storing it in
            /// <c>_current</c>. Disposes the iterator when the source is exhausted.
            /// </summary>
            protected override async ValueTask MoveNextCore()
            {
                // First call: obtain the source enumerator, then continue straight into iteration.
                if (_state == AsyncIteratorState.Allocated)
                {
                    _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                    _state = AsyncIteratorState.Iterating;
                }

                // Any state other than Iterating produces no further elements.
                if (_state != AsyncIteratorState.Iterating)
                {
                    return false;
                }

                while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                {
                    var candidate = _enumerator.Current;

                    // The same token that was used to obtain the enumerator is passed to the predicate.
                    if (await _predicate(candidate, _cancellationToken).ConfigureAwait(false))
                    {
                        _current = candidate;
                        return true;
                    }
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }
#endif
        /// <summary>
        /// Async iterator that filters a source sequence with a synchronous predicate and yields the
        /// result of applying a selector to each accepted element.
        /// </summary>
        private sealed class WhereSelectEnumerableAsyncIterator : AsyncIterator
        {
            private readonly Func _predicate;
            private readonly Func _selector;
            private readonly IAsyncEnumerable _source;

            // Obtained from the source on the first MoveNextCore call; cleared again by DisposeAsync.
            private IAsyncEnumerator? _enumerator;

            public WhereSelectEnumerableAsyncIterator(IAsyncEnumerable source, Func predicate, Func selector)
            {
                _selector = selector;
                _predicate = predicate;
                _source = source;
            }

            /// <summary>Creates a new iterator over the same source, predicate and selector.</summary>
            public override AsyncIteratorBase Clone() => new WhereSelectEnumerableAsyncIterator(_source, _predicate, _selector);

            /// <summary>Disposes the source enumerator, if one was created, and then the base iterator.</summary>
            public override async ValueTask DisposeAsync()
            {
                var enumerator = _enumerator;
                if (enumerator != null)
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Returns a new iterator over the same source and predicate whose selector is the result of
            /// <c>CombineSelectors</c> applied to the existing selector and the given one.
            /// </summary>
            public override IAsyncEnumerable Select(Func selector) =>
                new WhereSelectEnumerableAsyncIterator(_source, _predicate, CombineSelectors(_selector, selector));

            /// <summary>
            /// Advances to the next source element accepted by the predicate and stores the selector's
            /// result for it in <c>_current</c>. Disposes the iterator when the source is exhausted.
            /// </summary>
            protected override async ValueTask MoveNextCore()
            {
                // First call: obtain the source enumerator, then continue straight into iteration.
                if (_state == AsyncIteratorState.Allocated)
                {
                    _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                    _state = AsyncIteratorState.Iterating;
                }

                // Any state other than Iterating produces no further elements.
                if (_state != AsyncIteratorState.Iterating)
                {
                    return false;
                }

                while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                {
                    var candidate = _enumerator.Current;
                    if (_predicate(candidate))
                    {
                        // The selector runs only for elements that passed the predicate.
                        _current = _selector(candidate);
                        return true;
                    }
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }
    }
}