// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.
using System.Threading;
using System.Threading.Tasks;
namespace System.Collections.Generic
{
public static partial class AsyncEnumerator
{
    /// <summary>
    /// Wraps the specified enumerator with an enumerator that checks for cancellation upon every invocation
    /// of the <see cref="IAsyncEnumerator{T}.MoveNextAsync"/> method.
    /// </summary>
    /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
    /// <param name="source">The enumerator to augment with cancellation support. Must not be <c>null</c>.</param>
    /// <param name="cancellationToken">The cancellation token to observe.</param>
    /// <returns>
    /// An enumerator that honors cancellation requests. When <paramref name="cancellationToken"/> can never
    /// be canceled (for example <see cref="CancellationToken.None"/>), <paramref name="source"/> itself is
    /// returned, since there is nothing to observe.
    /// </returns>
    public static IAsyncEnumerator<T> WithCancellation<T>(this IAsyncEnumerator<T> source, CancellationToken cancellationToken)
    {
        if (source is null)
            throw Error.ArgumentNull(nameof(source));

        // CanBeCanceled covers default/None as well as tokens whose source can never be canceled
        // (on .NET Framework, new CancellationToken(false) is not equal to default); wrapping those
        // would only add overhead.
        if (!cancellationToken.CanBeCanceled)
            return source;

        return new WithCancellationAsyncEnumerator<T>(source, cancellationToken);
    }

    /// <summary>
    /// Enumerator that forwards every member to an underlying enumerator, observing a cancellation
    /// token before each call to <see cref="MoveNextAsync"/>.
    /// </summary>
    /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
    private sealed class WithCancellationAsyncEnumerator<T> : IAsyncEnumerator<T>
    {
        private readonly IAsyncEnumerator<T> _source;
        private readonly CancellationToken _cancellationToken;

        public WithCancellationAsyncEnumerator(IAsyncEnumerator<T> source, CancellationToken cancellationToken)
        {
            _source = source;
            _cancellationToken = cancellationToken;
        }

        /// <summary>Gets the current element of the underlying enumerator.</summary>
        public T Current => _source.Current;

        /// <summary>Disposes the underlying enumerator.</summary>
        // No cancellation check here, so the source can still be disposed after cancellation.
        public ValueTask DisposeAsync() => _source.DisposeAsync();

        /// <summary>
        /// Advances the underlying enumerator, unless cancellation has been requested. In that case
        /// a canceled task is returned without touching the source.
        /// </summary>
        /// <returns>
        /// The result of the underlying <see cref="IAsyncEnumerator{T}.MoveNextAsync"/>, or a canceled
        /// task associated with the observed token.
        /// </returns>
        public ValueTask<bool> MoveNextAsync()
        {
            // Per the Task-based Asynchronous Pattern, cancellation is not a usage error, so it is
            // surfaced through the returned task rather than thrown synchronously from this call.
            // Awaiting callers still observe an OperationCanceledException (TaskCanceledException).
            if (_cancellationToken.IsCancellationRequested)
                return new ValueTask<bool>(Task.FromCanceled<bool>(_cancellationToken));

            return _source.MoveNextAsync();
        }
    }
}
}