// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace System.Collections.Generic
{
    /// <summary>
    /// Provides a set of extension methods for <see cref="IAsyncEnumerator{T}"/>.
    /// </summary>
    public static class AsyncEnumerator
    {
        /// <summary>
        /// Creates a new enumerator using the specified delegates implementing the members of <see cref="IAsyncEnumerator{T}"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
        /// <param name="moveNext">The delegate implementing the <see cref="IAsyncEnumerator{T}.MoveNextAsync"/> method.</param>
        /// <param name="current">
        /// The delegate implementing the <see cref="IAsyncEnumerator{T}.Current"/> property getter. It is only invoked
        /// after <paramref name="moveNext"/> has returned <c>true</c>, so it may be <c>null</c> for a sequence that
        /// never produces an element.
        /// </param>
        /// <param name="dispose">
        /// The delegate implementing the asynchronous dispose operation of the enumerator. It is invoked at most once.
        /// May be <c>null</c>, in which case no custom cleanup runs.
        /// </param>
        /// <returns>A new enumerator instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="moveNext"/> is <c>null</c>.</exception>
        public static IAsyncEnumerator<T> Create<T>(Func<Task<bool>> moveNext, Func<T> current, Func<Task> dispose)
        {
            if (moveNext == null)
                throw new ArgumentNullException(nameof(moveNext));

            // Note: many callers pass null for the last two parameters. We assume that
            // the caller is responsible and knows what it is doing.
            return new AnonymousAsyncIterator<T>(moveNext, current, dispose);
        }

        /// <summary>
        /// Advances the enumerator to the next element in the sequence, returning the result asynchronously.
        /// </summary>
        /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
        /// <param name="source">The enumerator to advance.</param>
        /// <param name="cancellationToken">
        /// Cancellation token that can be used to cancel the operation. It is checked once, before the enumerator
        /// is advanced; it is not forwarded to <paramref name="source"/>.
        /// </param>
        /// <returns>
        /// Task containing the result of the operation: true if the enumerator was successfully advanced
        /// to the next element; false if the enumerator has passed the end of the sequence.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> has already been canceled.</exception>
        public static Task<bool> MoveNextAsync<T>(this IAsyncEnumerator<T> source, CancellationToken cancellationToken)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));

            cancellationToken.ThrowIfCancellationRequested();

            return source.MoveNextAsync();
        }

        /// <summary>
        /// Wraps the specified enumerator with an enumerator that checks for cancellation upon every invocation
        /// of the <see cref="IAsyncEnumerator{T}.MoveNextAsync"/> method.
        /// </summary>
        /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
        /// <param name="source">The enumerator to augment with cancellation support.</param>
        /// <param name="cancellationToken">The cancellation token to observe.</param>
        /// <returns>
        /// An enumerator that honors cancellation requests. Its <c>Current</c> reads through to
        /// <paramref name="source"/>, and disposing it disposes <paramref name="source"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        public static IAsyncEnumerator<T> WithCancellation<T>(this IAsyncEnumerator<T> source, CancellationToken cancellationToken)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));

            return new AnonymousAsyncIterator<T>(
                moveNext: () =>
                {
                    // Observed before every advance of the underlying enumerator.
                    cancellationToken.ThrowIfCancellationRequested();

                    return source.MoveNextAsync();
                },
                currentFunc: () => source.Current,
                dispose: source.DisposeAsync
            );
        }

        /// <summary>
        /// Wraps the specified enumerator in an enumerable.
        /// </summary>
        /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
        /// <param name="source">The enumerator to wrap.</param>
        /// <returns>An enumerable wrapping the specified enumerator.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        public static IAsyncEnumerable<T> AsEnumerable<T>(this IAsyncEnumerator<T> source)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));

            // NOTE(review): the factory always hands back the same enumerator instance, so the
            // resulting enumerable is presumably only meaningful to enumerate once.
            return AsyncEnumerable.CreateEnumerable(() => source);
        }

        /// <summary>
        /// Creates a new enumerator whose <paramref name="moveNext"/> delegate receives a freshly created
        /// <see cref="TaskCompletionSource{TResult}"/> on every advance of the enumerator.
        /// </summary>
        /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
        /// <param name="moveNext">
        /// The delegate implementing the advance operation; it is given a new completion source per call.
        /// </param>
        /// <param name="current">The delegate implementing the <see cref="IAsyncEnumerator{T}.Current"/> property getter.</param>
        /// <param name="dispose">The delegate implementing the asynchronous dispose operation; may be <c>null</c>.</param>
        /// <returns>A new enumerator instance.</returns>
        internal static IAsyncEnumerator<T> Create<T>(Func<TaskCompletionSource<bool>, Task<bool>> moveNext, Func<T> current, Func<Task> dispose)
        {
            // Internal callers are trusted, but assert the contract here: the iterator's own
            // assertion only ever sees the wrapping lambda below, which is never null.
            Debug.Assert(moveNext != null);

            return new AnonymousAsyncIterator<T>(
                async () =>
                {
                    // NOTE(review): how this completion source is used is up to the caller. It is created
                    // without TaskCreationOptions, so continuations on tcs.Task may run synchronously.
                    var tcs = new TaskCompletionSource<bool>();

                    return await moveNext(tcs).ConfigureAwait(false);
                },
                current,
                dispose
            );
        }

        /// <summary>
        /// An <see cref="AsyncIterator{T}"/> whose behavior is supplied entirely by delegates.
        /// Instances are used only as enumerators and cannot be cloned.
        /// </summary>
        /// <typeparam name="T">The type of the elements returned by the enumerator.</typeparam>
        private sealed class AnonymousAsyncIterator<T> : AsyncIterator<T>
        {
            private readonly Func<T> currentFunc;
            private readonly Func<Task<bool>> moveNext;

            // Not readonly: swapped to null by the first DisposeAsync call so the delegate runs at most once.
            private Func<Task> dispose;

            public AnonymousAsyncIterator(Func<Task<bool>> moveNext, Func<T> currentFunc, Func<Task> dispose)
            {
                Debug.Assert(moveNext != null);

                this.moveNext = moveNext;
                this.currentFunc = currentFunc;
                this.dispose = dispose;

                // Explicit call to initialize enumerator mode. The returned enumerator is ignored; the call
                // presumably moves this instance into the Allocated state handled by MoveNextCore.
                GetAsyncEnumerator();
            }

            public override AsyncIterator<T> Clone()
            {
                throw new NotSupportedException("AnonymousAsyncIterator cannot be cloned. It is only intended for use as an iterator.");
            }

            public override async Task DisposeAsync()
            {
                // Atomically take the delegate so it runs at most once, even if DisposeAsync is called
                // both at end of sequence (see MoveNextCore) and again by the consumer.
                var dispose = Interlocked.Exchange(ref this.dispose, null);

                try
                {
                    if (dispose != null)
                    {
                        await dispose().ConfigureAwait(false);
                    }
                }
                finally
                {
                    // Run the base cleanup even if the user-supplied delegate throws. Otherwise the iterator
                    // would stay in the Iterating state, and a later MoveNextAsync would call moveNext again.
                    await base.DisposeAsync().ConfigureAwait(false);
                }
            }

            protected override async Task<bool> MoveNextCore()
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        if (await moveNext().ConfigureAwait(false))
                        {
                            // currentFunc is only consulted once an element is known to exist.
                            current = currentFunc();
                            return true;
                        }

                        // End of sequence: release resources eagerly.
                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }
    }
}