// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerableEx
    {
        // REVIEW: Should we convert Task-based overloads to ValueTask?

        // All public Do overloads validate their arguments eagerly, at call time, via Error.ArgumentNull
        // (presumably producing ArgumentNullException -- confirm against the Error helper). The source is
        // only enumerated when the returned sequence is enumerated.

        /// <summary>
        /// Invokes an action for each element of the async-enumerable sequence as it flows through.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Action invoked for each element, before the element is yielded.</param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
        }

        /// <summary>
        /// Invokes an action for each element of the async-enumerable sequence, and another one once the
        /// sequence is exhausted.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Action invoked for each element, before the element is yielded.</param>
        /// <param name="onCompleted">Action invoked once the source sequence reports no more elements.</param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
        }

        /// <summary>
        /// Invokes an action for each element of the async-enumerable sequence, and another one when
        /// enumeration fails.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Action invoked for each element, before the element is yielded.</param>
        /// <param name="onError">
        /// Action invoked with an exception thrown while advancing the source or by <paramref name="onNext"/>
        /// (except <see cref="OperationCanceledException"/>); the exception is rethrown afterwards.
        /// </param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));

            return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
        }

        /// <summary>
        /// Invokes actions for each element of the async-enumerable sequence, on failure, and on completion.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Action invoked for each element, before the element is yielded.</param>
        /// <param name="onError">
        /// Action invoked with an exception thrown while advancing the source or by <paramref name="onNext"/>
        /// (except <see cref="OperationCanceledException"/>); the exception is rethrown afterwards.
        /// </param>
        /// <param name="onCompleted">Action invoked once the source sequence reports no more elements.</param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext, onError, onCompleted);
        }

        /// <summary>
        /// Awaits an asynchronous action for each element of the async-enumerable sequence as it flows through.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Asynchronous action awaited for each element, before the element is yielded.</param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
        }

        /// <summary>
        /// Awaits an asynchronous action for each element of the async-enumerable sequence, and another one
        /// once the sequence is exhausted.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Asynchronous action awaited for each element, before the element is yielded.</param>
        /// <param name="onCompleted">Asynchronous action awaited once the source sequence reports no more elements.</param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Task> onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
        }

        /// <summary>
        /// Awaits an asynchronous action for each element of the async-enumerable sequence, and another one
        /// when enumeration fails.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Asynchronous action awaited for each element, before the element is yielded.</param>
        /// <param name="onError">
        /// Asynchronous action awaited with an exception thrown while advancing the source or by
        /// <paramref name="onNext"/> (except <see cref="OperationCanceledException"/>); the exception is rethrown afterwards.
        /// </param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));

            return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
        }

        /// <summary>
        /// Awaits asynchronous actions for each element of the async-enumerable sequence, on failure, and on completion.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Asynchronous action awaited for each element, before the element is yielded.</param>
        /// <param name="onError">
        /// Asynchronous action awaited with an exception thrown while advancing the source or by
        /// <paramref name="onNext"/> (except <see cref="OperationCanceledException"/>); the exception is rethrown afterwards.
        /// </param>
        /// <param name="onCompleted">Asynchronous action awaited once the source sequence reports no more elements.</param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext, onError, onCompleted);
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Awaits an asynchronous, cancellable action for each element of the async-enumerable sequence.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">
        /// Asynchronous action awaited for each element, before the element is yielded; it receives the
        /// cancellation token used for the enumeration.
        /// </param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
        }

        /// <summary>
        /// Awaits asynchronous, cancellable actions for each element of the async-enumerable sequence and
        /// once the sequence is exhausted.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">
        /// Asynchronous action awaited for each element, before the element is yielded; it receives the
        /// cancellation token used for the enumeration.
        /// </param>
        /// <param name="onCompleted">
        /// Asynchronous action awaited once the source sequence reports no more elements; it receives the
        /// cancellation token used for the enumeration.
        /// </param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<CancellationToken, Task> onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
        }

        /// <summary>
        /// Awaits asynchronous, cancellable actions for each element of the async-enumerable sequence and
        /// when enumeration fails.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">
        /// Asynchronous action awaited for each element, before the element is yielded; it receives the
        /// cancellation token used for the enumeration.
        /// </param>
        /// <param name="onError">
        /// Asynchronous action awaited with an exception thrown while advancing the source or by
        /// <paramref name="onNext"/> (except <see cref="OperationCanceledException"/>) and the enumeration's
        /// cancellation token; the exception is rethrown afterwards.
        /// </param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));

            return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
        }

        /// <summary>
        /// Awaits asynchronous, cancellable actions for each element of the async-enumerable sequence, on
        /// failure, and on completion.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">
        /// Asynchronous action awaited for each element, before the element is yielded; it receives the
        /// cancellation token used for the enumeration.
        /// </param>
        /// <param name="onError">
        /// Asynchronous action awaited with an exception thrown while advancing the source or by
        /// <paramref name="onNext"/> (except <see cref="OperationCanceledException"/>) and the enumeration's
        /// cancellation token; the exception is rethrown afterwards.
        /// </param>
        /// <param name="onCompleted">
        /// Asynchronous action awaited once the source sequence reports no more elements; it receives the
        /// cancellation token used for the enumeration.
        /// </param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext, onError, onCompleted);
        }
#endif

        /// <summary>
        /// Forwards the elements, the failure, and the completion of the async-enumerable sequence to an observer.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="observer">
        /// Observer whose <c>OnNext</c>, <c>OnError</c> and <c>OnCompleted</c> methods are used as the
        /// element, failure and completion callbacks respectively.
        /// </param>
        /// <returns>The source sequence with the side-effecting behavior applied.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (observer == null)
                throw Error.ArgumentNull(nameof(observer));

            // Explicit delegate construction pins overload resolution to the synchronous (Action-based) DoCore.
            return DoCore(source, new Action<TSource>(observer.OnNext), new Action<Exception>(observer.OnError), new Action(observer.OnCompleted));
        }

        // Shared implementation for the synchronous-callback overloads. onError and onCompleted may be null,
        // in which case they are skipped.
        private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
        {
#if USE_ASYNC_ITERATOR
            return AsyncEnumerable.Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false);

                try // TODO: Switch to `await using` in preview 3 (https://github.com/dotnet/roslyn/pull/32731)
                {
                    while (true)
                    {
                        TSource item;

                        // Both advancing the source and the onNext callback are guarded, so onError also
                        // observes failures of onNext. Cancellation is deliberately not reported to onError.
                        try
                        {
                            if (!await e.MoveNextAsync())
                            {
                                break;
                            }

                            item = e.Current;

                            onNext(item);
                        }
                        catch (OperationCanceledException)
                        {
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            onError(ex);
                            throw;
                        }

                        // C# forbids yielding inside a try block that has a catch clause, hence the yield lives here.
                        yield return item;
                    }

                    onCompleted?.Invoke();
                }
                finally
                {
                    await e.DisposeAsync();
                }
            }
#else
            return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
#endif
        }

        // Shared implementation for the Task-returning-callback overloads. onError and onCompleted may be null,
        // in which case they are skipped.
        private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
        {
#if USE_ASYNC_ITERATOR
            return AsyncEnumerable.Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false);

                try // TODO: Switch to `await using` in preview 3 (https://github.com/dotnet/roslyn/pull/32731)
                {
                    while (true)
                    {
                        TSource item;

                        // Both advancing the source and the onNext callback are guarded, so onError also
                        // observes failures of onNext. Cancellation is deliberately not reported to onError.
                        try
                        {
                            if (!await e.MoveNextAsync())
                            {
                                break;
                            }

                            item = e.Current;

                            await onNext(item).ConfigureAwait(false);
                        }
                        catch (OperationCanceledException)
                        {
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            await onError(ex).ConfigureAwait(false);
                            throw;
                        }

                        yield return item;
                    }

                    if (onCompleted != null)
                    {
                        await onCompleted().ConfigureAwait(false);
                    }
                }
                finally
                {
                    await e.DisposeAsync();
                }
            }
#else
            return new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
#endif
        }

#if !NO_DEEP_CANCELLATION
        // Shared implementation for the cancellable Task-returning-callback overloads. Every callback receives
        // the same cancellation token that is used to enumerate the source. onError and onCompleted may be null.
        private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
        {
#if USE_ASYNC_ITERATOR
            return AsyncEnumerable.Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false);

                try // TODO: Switch to `await using` in preview 3 (https://github.com/dotnet/roslyn/pull/32731)
                {
                    while (true)
                    {
                        TSource item;

                        // Both advancing the source and the onNext callback are guarded, so onError also
                        // observes failures of onNext. Cancellation is deliberately not reported to onError.
                        try
                        {
                            if (!await e.MoveNextAsync())
                            {
                                break;
                            }

                            item = e.Current;

                            await onNext(item, cancellationToken).ConfigureAwait(false);
                        }
                        catch (OperationCanceledException)
                        {
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            await onError(ex, cancellationToken).ConfigureAwait(false);
                            throw;
                        }

                        yield return item;
                    }

                    if (onCompleted != null)
                    {
                        await onCompleted(cancellationToken).ConfigureAwait(false);
                    }
                }
                finally
                {
                    await e.DisposeAsync();
                }
            }
#else
            return new DoAsyncIteratorWithTaskAndCancellation<TSource>(source, onNext, onError, onCompleted);
#endif
        }
#endif

#if !USE_ASYNC_ITERATOR
        // Hand-written iterator used when compiler-generated async iterators are unavailable; mirrors the
        // synchronous-callback Core above: onNext runs before an element is surfaced, onError sees non-cancellation
        // failures (then rethrows), onCompleted runs once the source is exhausted.
        private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            private readonly Action _onCompleted;
            private readonly Action<Exception> _onError;
            private readonly Action<TSource> _onNext;
            private readonly IAsyncEnumerable<TSource> _source;

            // Created lazily on the first MoveNextCore call; cleared on disposal.
            private IAsyncEnumerator<TSource> _enumerator;

            public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
            {
                Debug.Assert(source != null);
                Debug.Assert(onNext != null);

                _source = source;
                _onNext = onNext;
                _onError = onError;
                _onCompleted = onCompleted;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new DoAsyncIterator<TSource>(_source, _onNext, _onError, _onCompleted);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                        _state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        try
                        {
                            if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                            {
                                _current = _enumerator.Current;
                                _onNext(_current);

                                return true;
                            }
                        }
                        catch (OperationCanceledException)
                        {
                            throw;
                        }
                        catch (Exception ex) when (_onError != null)
                        {
                            _onError(ex);
                            throw;
                        }

                        // Source exhausted: signal completion, then release the underlying enumerator.
                        _onCompleted?.Invoke();

                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }

        // Hand-written counterpart of the Task-returning-callback Core above; every callback is awaited.
        private sealed class DoAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
        {
            private readonly Func<Task> _onCompleted;
            private readonly Func<Exception, Task> _onError;
            private readonly Func<TSource, Task> _onNext;
            private readonly IAsyncEnumerable<TSource> _source;

            // Created lazily on the first MoveNextCore call; cleared on disposal.
            private IAsyncEnumerator<TSource> _enumerator;

            public DoAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
            {
                Debug.Assert(source != null);
                Debug.Assert(onNext != null);

                _source = source;
                _onNext = onNext;
                _onError = onError;
                _onCompleted = onCompleted;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new DoAsyncIteratorWithTask<TSource>(_source, _onNext, _onError, _onCompleted);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                        _state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        try
                        {
                            if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                            {
                                _current = _enumerator.Current;
                                await _onNext(_current).ConfigureAwait(false);

                                return true;
                            }
                        }
                        catch (OperationCanceledException)
                        {
                            throw;
                        }
                        catch (Exception ex) when (_onError != null)
                        {
                            await _onError(ex).ConfigureAwait(false);
                            throw;
                        }

                        // Source exhausted: signal completion, then release the underlying enumerator.
                        if (_onCompleted != null)
                        {
                            await _onCompleted().ConfigureAwait(false);
                        }

                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }

#if !NO_DEEP_CANCELLATION
        // Hand-written counterpart of the cancellable Core above; every callback receives the same token
        // (_cancellationToken) that was passed to the source's GetAsyncEnumerator.
        private sealed class DoAsyncIteratorWithTaskAndCancellation<TSource> : AsyncIterator<TSource>
        {
            private readonly Func<CancellationToken, Task> _onCompleted;
            private readonly Func<Exception, CancellationToken, Task> _onError;
            private readonly Func<TSource, CancellationToken, Task> _onNext;
            private readonly IAsyncEnumerable<TSource> _source;

            // Created lazily on the first MoveNextCore call; cleared on disposal.
            private IAsyncEnumerator<TSource> _enumerator;

            public DoAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
            {
                Debug.Assert(source != null);
                Debug.Assert(onNext != null);

                _source = source;
                _onNext = onNext;
                _onError = onError;
                _onCompleted = onCompleted;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new DoAsyncIteratorWithTaskAndCancellation<TSource>(_source, _onNext, _onError, _onCompleted);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                        _state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        try
                        {
                            if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                            {
                                _current = _enumerator.Current;
                                await _onNext(_current, _cancellationToken).ConfigureAwait(false);

                                return true;
                            }
                        }
                        catch (OperationCanceledException)
                        {
                            throw;
                        }
                        catch (Exception ex) when (_onError != null)
                        {
                            await _onError(ex, _cancellationToken).ConfigureAwait(false);
                            throw;
                        }

                        // Source exhausted: signal completion, then release the underlying enumerator.
                        if (_onCompleted != null)
                        {
                            await _onCompleted(_cancellationToken).ConfigureAwait(false);
                        }

                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }
#endif
#endif
    }
}