// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Creates an async sequence whose enumerators are produced by <paramref name="getEnumerator"/>.
        /// </summary>
        /// <typeparam name="T">The element type of the sequence.</typeparam>
        /// <param name="getEnumerator">Factory invoked on every call to <c>GetAsyncEnumerator</c>.</param>
        /// <returns>An async sequence backed by <paramref name="getEnumerator"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="getEnumerator"/> is <c>null</c>.</exception>
        public static IAsyncEnumerable<T> CreateEnumerable<T>(Func<IAsyncEnumerator<T>> getEnumerator)
        {
            if (getEnumerator == null)
            {
                throw new ArgumentNullException(nameof(getEnumerator));
            }

            return new AnonymousAsyncEnumerable<T>(getEnumerator);
        }

        /// <summary>
        /// Creates an async sequence whose enumerators are obtained asynchronously from
        /// <paramref name="getEnumerator"/>.
        /// </summary>
        /// <remarks>
        /// Each call to <c>GetAsyncEnumerator</c> returns a wrapper that defers invoking
        /// <paramref name="getEnumerator"/> until its first <c>MoveNextAsync</c> call.
        /// </remarks>
        /// <typeparam name="T">The element type of the sequence.</typeparam>
        /// <param name="getEnumerator">Asynchronous factory for the underlying enumerator.</param>
        /// <returns>An async sequence backed by <paramref name="getEnumerator"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="getEnumerator"/> is <c>null</c>.</exception>
        public static IAsyncEnumerable<T> CreateEnumerable<T>(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
        {
            if (getEnumerator == null)
            {
                throw new ArgumentNullException(nameof(getEnumerator));
            }

            return new AnonymousAsyncEnumerableWithTask<T>(getEnumerator);
        }

        /// <summary>
        /// Creates an async enumerator from a set of delegates.
        /// </summary>
        /// <typeparam name="T">The element type produced by the enumerator.</typeparam>
        /// <param name="moveNext">Advances the enumerator; returns <c>false</c> when the sequence is exhausted.</param>
        /// <param name="current">
        /// Reads the current element; invoked only after <paramref name="moveNext"/> returns <c>true</c>.
        /// Not validated.
        /// </param>
        /// <param name="dispose">Optional cleanup callback run on disposal; may be <c>null</c>.</param>
        /// <returns>An enumerator driven by the supplied delegates.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="moveNext"/> is <c>null</c>.</exception>
        public static IAsyncEnumerator<T> CreateEnumerator<T>(Func<Task<bool>> moveNext, Func<T> current, Func<Task> dispose)
        {
            if (moveNext == null)
            {
                throw new ArgumentNullException(nameof(moveNext));
            }

            // Note: Many methods pass null in for the second two params. We're assuming
            // that the caller is responsible and knows what they're doing.
            return new AnonymousAsyncIterator<T>(moveNext, current, dispose);
        }

        /// <summary>
        /// Creates an async enumerator whose <paramref name="moveNext"/> callback receives a fresh
        /// <see cref="TaskCompletionSource{TResult}"/> on every advance.
        /// </summary>
        private static IAsyncEnumerator<T> CreateEnumerator<T>(Func<TaskCompletionSource<bool>, Task<bool>> moveNext, Func<T> current, Func<Task> dispose)
        {
            return new AnonymousAsyncIterator<T>(
                async () =>
                {
                    // A new completion source per MoveNext call; the callback decides how to complete it.
                    var tcs = new TaskCompletionSource<bool>();
                    return await moveNext(tcs).ConfigureAwait(false);
                },
                current,
                dispose);
        }

        /// <summary>
        /// Async sequence that forwards <c>GetAsyncEnumerator</c> to a user-supplied factory.
        /// </summary>
        private sealed class AnonymousAsyncEnumerable<T> : IAsyncEnumerable<T>
        {
            private readonly Func<IAsyncEnumerator<T>> getEnumerator;

            public AnonymousAsyncEnumerable(Func<IAsyncEnumerator<T>> getEnumerator)
            {
                Debug.Assert(getEnumerator != null);

                this.getEnumerator = getEnumerator;
            }

            public IAsyncEnumerator<T> GetAsyncEnumerator() => getEnumerator();
        }

        /// <summary>
        /// Async sequence whose enumerators are created by an asynchronous factory.
        /// </summary>
        private sealed class AnonymousAsyncEnumerableWithTask<T> : IAsyncEnumerable<T>
        {
            private readonly Func<Task<IAsyncEnumerator<T>>> getEnumerator;

            public AnonymousAsyncEnumerableWithTask(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
            {
                Debug.Assert(getEnumerator != null);

                this.getEnumerator = getEnumerator;
            }

            public IAsyncEnumerator<T> GetAsyncEnumerator() => new Enumerator(getEnumerator);

            /// <summary>
            /// Wrapper that awaits the factory on the first <see cref="MoveNextAsync"/> and then
            /// delegates to the enumerator it produced.
            /// </summary>
            private sealed class Enumerator : IAsyncEnumerator<T>
            {
                // Cleared after the first initialization attempt so the factory is not retained.
                private Func<Task<IAsyncEnumerator<T>>> getEnumerator;

                // null until initialized; DisposedEnumerator.Instance after disposal.
                private IAsyncEnumerator<T> enumerator;

                public Enumerator(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
                {
                    Debug.Assert(getEnumerator != null);

                    this.getEnumerator = getEnumerator;
                }

                public T Current
                {
                    get
                    {
                        if (enumerator == null)
                        {
                            throw new InvalidOperationException();
                        }

                        return enumerator.Current;
                    }
                }

                public async Task DisposeAsync()
                {
                    // Install the sentinel first so later MoveNextAsync/Current calls throw
                    // ObjectDisposedException, then dispose the enumerator we actually owned.
                    // (Disposing 'enumerator' here would only hit the no-op sentinel and leak 'old'.)
                    var old = Interlocked.Exchange(ref enumerator, DisposedEnumerator.Instance);

                    if (old != null)
                    {
                        await old.DisposeAsync().ConfigureAwait(false);
                    }
                }

                public Task<bool> MoveNextAsync()
                {
                    if (enumerator == null)
                    {
                        return InitAndMoveNextAsync();
                    }

                    return enumerator.MoveNextAsync();
                }

                /// <summary>
                /// Awaits the factory, stores the resulting enumerator and advances it once.
                /// </summary>
                private async Task<bool> InitAndMoveNextAsync()
                {
                    try
                    {
                        enumerator = await getEnumerator().ConfigureAwait(false);
                    }
                    catch (Exception ex)
                    {
                        // Remember the failure so subsequent MoveNextAsync calls surface it again
                        // instead of re-invoking the factory.
                        enumerator = Throw<T>(ex).GetAsyncEnumerator();
                        throw;
                    }
                    finally
                    {
                        getEnumerator = null;
                    }

                    return await enumerator.MoveNextAsync().ConfigureAwait(false);
                }

                /// <summary>
                /// Sentinel installed by <see cref="Enumerator.DisposeAsync"/>; any further use throws.
                /// </summary>
                private sealed class DisposedEnumerator : IAsyncEnumerator<T>
                {
                    public static readonly DisposedEnumerator Instance = new DisposedEnumerator();

                    public T Current => throw new ObjectDisposedException("this");

                    public Task DisposeAsync() => TaskExt.CompletedTask;

                    public Task<bool> MoveNextAsync() => throw new ObjectDisposedException("this");
                }
            }
        }

        /// <summary>
        /// Iterator driven by user-supplied moveNext/current/dispose delegates.
        /// </summary>
        private sealed class AnonymousAsyncIterator<T> : AsyncIterator<T>
        {
            private readonly Func<T> currentFunc;
            private readonly Func<Task> dispose;
            private readonly Func<Task<bool>> moveNext;

            public AnonymousAsyncIterator(Func<Task<bool>> moveNext, Func<T> currentFunc, Func<Task> dispose)
            {
                Debug.Assert(moveNext != null);

                this.moveNext = moveNext;
                this.currentFunc = currentFunc;
                this.dispose = dispose;

                // Explicit call to initialize enumerator mode
                GetAsyncEnumerator();
            }

            public override AsyncIterator<T> Clone()
            {
                throw new NotSupportedException("AnonymousAsyncIterator cannot be cloned. It is only intended for use as an iterator.");
            }

            /// <summary>
            /// Runs the user dispose callback (when supplied) and then the base disposal.
            /// </summary>
            /// <remarks>
            /// No guard against repeated calls exists in this override: the callback runs on every call,
            /// including the implicit call made by <see cref="MoveNextCore"/> at end of sequence.
            /// </remarks>
            public override async Task DisposeAsync()
            {
                if (dispose != null)
                {
                    await dispose().ConfigureAwait(false);
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Advances via the user moveNext delegate; on exhaustion disposes itself and returns false.
            /// </summary>
            protected override async Task<bool> MoveNextCore()
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        if (await moveNext().ConfigureAwait(false))
                        {
                            current = currentFunc();
                            return true;
                        }

                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }
    }
}