// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information. 
using System.Threading;
using System.Threading.Tasks;
namespace System.Collections.Generic
{
    public static partial class AsyncEnumerator
    {
        /// <summary>
        /// Wraps the specified enumerator with an enumerator that checks for cancellation upon every invocation
        /// of the <see cref="IAsyncEnumerator{T}.MoveNextAsync"/> method.
        /// </summary>
        /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
        /// <param name="source">The enumerator to augment with cancellation support.</param>
        /// <param name="cancellationToken">The cancellation token to observe.</param>
        /// <returns>
        /// An enumerator that honors cancellation requests. When <paramref name="cancellationToken"/> equals
        /// <c>default</c>, <paramref name="source"/> itself is returned and no wrapper is allocated.
        /// </returns>
        /// <remarks>
        /// When <paramref name="source"/> is <c>null</c>, the exception created by <c>Error.ArgumentNull</c> is thrown.
        /// </remarks>
        public static IAsyncEnumerator<T> WithCancellation<T>(this IAsyncEnumerator<T> source, CancellationToken cancellationToken)
        {
            if (source == null)
            {
                throw Error.ArgumentNull(nameof(source));
            }

            // A default token gives the wrapper nothing to observe, so hand back the source unchanged.
            if (cancellationToken == default)
            {
                return source;
            }

            return new WithCancellationAsyncEnumerator<T>(source, cancellationToken);
        }

        /// <summary>
        /// Enumerator decorator that checks a <see cref="CancellationToken"/> before each call to the wrapped
        /// enumerator's <see cref="IAsyncEnumerator{T}.MoveNextAsync"/>. <see cref="Current"/> and
        /// <see cref="DisposeAsync"/> are forwarded to the wrapped enumerator without a cancellation check.
        /// </summary>
        /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
        private sealed class WithCancellationAsyncEnumerator<T> : IAsyncEnumerator<T>
        {
            private readonly IAsyncEnumerator<T> _source;
            private readonly CancellationToken _cancellationToken;

            /// <summary>
            /// Creates a wrapper around <paramref name="source"/> that observes <paramref name="cancellationToken"/>.
            /// </summary>
            /// <param name="source">The enumerator to delegate to.</param>
            /// <param name="cancellationToken">The token checked before every move.</param>
            public WithCancellationAsyncEnumerator(IAsyncEnumerator<T> source, CancellationToken cancellationToken)
            {
                _source = source;
                _cancellationToken = cancellationToken;
            }

            /// <summary>Gets the current element of the wrapped enumerator.</summary>
            public T Current => _source.Current;

            /// <summary>Disposes the wrapped enumerator.</summary>
            /// <returns>The task returned by the wrapped enumerator's <c>DisposeAsync</c>.</returns>
            public ValueTask DisposeAsync() => _source.DisposeAsync();

            /// <summary>
            /// Throws if cancellation has been requested; otherwise advances the wrapped enumerator.
            /// </summary>
            /// <returns>The task returned by the wrapped enumerator's <c>MoveNextAsync</c>.</returns>
            /// <exception cref="OperationCanceledException">
            /// Thrown synchronously (not through the returned task) when the token has been canceled.
            /// </exception>
            public ValueTask<bool> MoveNextAsync()
            {
                // The token is only inspected here, before delegating; it is not passed on to the source.
                _cancellationToken.ThrowIfCancellationRequested();

                return _source.MoveNextAsync(); // REVIEW: Signal cancellation through task or synchronously?
            }
        }
    }
}