// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerableEx
    {
        /// <summary>
        /// Returns an async-enumerable sequence that postpones calling <paramref name="factory"/>
        /// until the sequence is enumerated, and then yields the elements of the sequence it returns.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the sequence.</typeparam>
        /// <param name="factory">Function that produces the sequence to enumerate.</param>
        /// <returns>A sequence whose elements are those of the sequence produced by <paramref name="factory"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Defer<TSource>(Func<IAsyncEnumerable<TSource>> factory)
        {
            if (factory == null)
                throw Error.ArgumentNull(nameof(factory));

#if USE_ASYNC_ITERATOR
            return AsyncEnumerable.Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                // The factory is called inside the async iterator body, so it only runs once
                // enumeration actually begins (first MoveNextAsync), not when Defer is called.
                await foreach (var item in factory().WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    yield return item;
                }
            }
#else
            return new DeferIterator<TSource>(factory);
#endif
        }

        /// <summary>
        /// Returns an async-enumerable sequence that postpones calling the asynchronous
        /// <paramref name="factory"/> until the sequence is enumerated, awaits the task it returns,
        /// and then yields the elements of the resulting sequence.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the sequence.</typeparam>
        /// <param name="factory">Asynchronous function that produces the sequence to enumerate.</param>
        /// <returns>A sequence whose elements are those of the sequence produced by <paramref name="factory"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Defer<TSource>(Func<Task<IAsyncEnumerable<TSource>>> factory)
        {
            if (factory == null)
                throw Error.ArgumentNull(nameof(factory));

#if USE_ASYNC_ITERATOR
            return AsyncEnumerable.Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                await foreach (var item in (await factory().ConfigureAwait(false)).WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    yield return item;
                }
            }
#else
            return new AsyncDeferIterator<TSource>(factory);
#endif
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Returns an async-enumerable sequence that postpones calling the asynchronous
        /// <paramref name="factory"/> until the sequence is enumerated. The factory receives the
        /// enumeration's cancellation token; the task it returns is awaited and the elements of the
        /// resulting sequence are yielded.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the sequence.</typeparam>
        /// <param name="factory">
        /// Asynchronous function that produces the sequence to enumerate; it is passed the same
        /// cancellation token that is used to enumerate the produced sequence.
        /// </param>
        /// <returns>A sequence whose elements are those of the sequence produced by <paramref name="factory"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Defer<TSource>(Func<CancellationToken, Task<IAsyncEnumerable<TSource>>> factory)
        {
            if (factory == null)
                throw Error.ArgumentNull(nameof(factory));

#if USE_ASYNC_ITERATOR
            return AsyncEnumerable.Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                await foreach (var item in (await factory(cancellationToken).ConfigureAwait(false)).WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    yield return item;
                }
            }
#else
            return new AsyncDeferIteratorWithCancellation<TSource>(factory);
#endif
        }
#endif

#if !USE_ASYNC_ITERATOR
        /// <summary>
        /// Fallback iterator (compiled only when USE_ASYNC_ITERATOR is not defined) for the
        /// <c>Defer</c> overload that takes a synchronous factory. The factory is called on the
        /// first MoveNextAsync of each iterator instance.
        /// </summary>
        private sealed class DeferIterator<T> : AsyncIteratorBase<T>
        {
            private readonly Func<IAsyncEnumerable<T>> _factory;

            // Enumerator over the sequence produced by _factory; null until the first MoveNextAsync
            // and again after DisposeAsync.
            private IAsyncEnumerator<T> _enumerator;

            public DeferIterator(Func<IAsyncEnumerable<T>> factory)
            {
                Debug.Assert(factory != null);

                _factory = factory;
            }

            public override T Current => _enumerator == null ? default : _enumerator.Current;

            public override AsyncIteratorBase<T> Clone()
            {
                // New instance with no enumerator, so the factory is invoked again for it.
                return new DeferIterator<T>(_factory);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override ValueTask<bool> MoveNextCore()
            {
                if (_enumerator == null)
                {
                    return InitializeAndMoveNextAsync();
                }

                return _enumerator.MoveNextAsync();
            }

            private async ValueTask<bool> InitializeAndMoveNextAsync()
            {
                // NB: Using an async method to ensure any exception is reported via the task.

                try
                {
                    _enumerator = _factory().GetAsyncEnumerator(_cancellationToken);
                }
                catch (Exception ex)
                {
                    // Leave a non-null enumerator behind so later MoveNextAsync calls do not re-run
                    // the factory; presumably Throw<T>(ex) yields a sequence that fails with ex.
                    _enumerator = Throw<T>(ex).GetAsyncEnumerator(_cancellationToken);
                    throw;
                }

                return await _enumerator.MoveNextAsync().ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Fallback iterator (compiled only when USE_ASYNC_ITERATOR is not defined) for the
        /// <c>Defer</c> overload whose factory returns a task. The factory is called, and its task
        /// awaited, on the first MoveNextAsync of each iterator instance.
        /// </summary>
        private sealed class AsyncDeferIterator<T> : AsyncIteratorBase<T>
        {
            private readonly Func<Task<IAsyncEnumerable<T>>> _factory;

            // Enumerator over the sequence produced by _factory; null until the first MoveNextAsync
            // and again after DisposeAsync.
            private IAsyncEnumerator<T> _enumerator;

            public AsyncDeferIterator(Func<Task<IAsyncEnumerable<T>>> factory)
            {
                Debug.Assert(factory != null);

                _factory = factory;
            }

            public override T Current => _enumerator == null ? default : _enumerator.Current;

            public override AsyncIteratorBase<T> Clone()
            {
                // New instance with no enumerator, so the factory is invoked again for it.
                return new AsyncDeferIterator<T>(_factory);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override ValueTask<bool> MoveNextCore()
            {
                if (_enumerator == null)
                {
                    return InitializeAndMoveNextAsync();
                }

                return _enumerator.MoveNextAsync();
            }

            private async ValueTask<bool> InitializeAndMoveNextAsync()
            {
                // NB: Using an async method to ensure any exception is reported via the task.

                try
                {
                    _enumerator = (await _factory().ConfigureAwait(false)).GetAsyncEnumerator(_cancellationToken);
                }
                catch (Exception ex)
                {
                    // Leave a non-null enumerator behind so later MoveNextAsync calls do not re-run
                    // the factory; presumably Throw<T>(ex) yields a sequence that fails with ex.
                    _enumerator = Throw<T>(ex).GetAsyncEnumerator(_cancellationToken);
                    throw;
                }

                return await _enumerator.MoveNextAsync().ConfigureAwait(false);
            }
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Fallback iterator (compiled only when USE_ASYNC_ITERATOR is not defined) for the
        /// <c>Defer</c> overload whose factory takes a cancellation token. The factory is passed
        /// this iterator's cancellation token on the first MoveNextAsync of each iterator instance.
        /// </summary>
        private sealed class AsyncDeferIteratorWithCancellation<T> : AsyncIteratorBase<T>
        {
            private readonly Func<CancellationToken, Task<IAsyncEnumerable<T>>> _factory;

            // Enumerator over the sequence produced by _factory; null until the first MoveNextAsync
            // and again after DisposeAsync.
            private IAsyncEnumerator<T> _enumerator;

            public AsyncDeferIteratorWithCancellation(Func<CancellationToken, Task<IAsyncEnumerable<T>>> factory)
            {
                Debug.Assert(factory != null);

                _factory = factory;
            }

            public override T Current => _enumerator == null ? default : _enumerator.Current;

            public override AsyncIteratorBase<T> Clone()
            {
                // New instance with no enumerator, so the factory is invoked again for it.
                return new AsyncDeferIteratorWithCancellation<T>(_factory);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override ValueTask<bool> MoveNextCore()
            {
                if (_enumerator == null)
                {
                    return InitializeAndMoveNextAsync();
                }

                return _enumerator.MoveNextAsync();
            }

            private async ValueTask<bool> InitializeAndMoveNextAsync()
            {
                // NB: Using an async method to ensure any exception is reported via the task.

                try
                {
                    // The same token is handed to the factory and to the produced sequence's enumerator.
                    _enumerator = (await _factory(_cancellationToken).ConfigureAwait(false)).GetAsyncEnumerator(_cancellationToken);
                }
                catch (Exception ex)
                {
                    // Leave a non-null enumerator behind so later MoveNextAsync calls do not re-run
                    // the factory; presumably Throw<T>(ex) yields a sequence that fails with ex.
                    _enumerator = Throw<T>(ex).GetAsyncEnumerator(_cancellationToken);
                    throw;
                }

                return await _enumerator.MoveNextAsync().ConfigureAwait(false);
            }
        }
#endif
#endif
    }
}