// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerableEx
    {
        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/> and invokes
        /// <paramref name="onNext"/> for each element before it is yielded.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="onNext">Callback invoked with each element.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="onNext"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/>, invoking
        /// <paramref name="onNext"/> for each element and <paramref name="onCompleted"/> once the
        /// source reports that it has no more elements.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="onNext">Callback invoked with each element.</param>
        /// <param name="onCompleted">Callback invoked when the source is exhausted.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onCompleted == null)
                throw new ArgumentNullException(nameof(onCompleted));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/>, invoking
        /// <paramref name="onNext"/> for each element and <paramref name="onError"/> when advancing
        /// the source or running <paramref name="onNext"/> throws.
        /// </summary>
        /// <remarks>
        /// The exception is rethrown after <paramref name="onError"/> returns.
        /// <see cref="OperationCanceledException"/> is rethrown without calling <paramref name="onError"/>.
        /// </remarks>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="onNext">Callback invoked with each element.</param>
        /// <param name="onError">Callback invoked with the exception before it is rethrown.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onError == null)
                throw new ArgumentNullException(nameof(onError));

            return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/>, invoking
        /// <paramref name="onNext"/> for each element, <paramref name="onError"/> on failure and
        /// <paramref name="onCompleted"/> once the source is exhausted.
        /// </summary>
        /// <remarks>
        /// The exception is rethrown after <paramref name="onError"/> returns.
        /// <see cref="OperationCanceledException"/> is rethrown without calling <paramref name="onError"/>.
        /// </remarks>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="onNext">Callback invoked with each element.</param>
        /// <param name="onError">Callback invoked with the exception before it is rethrown.</param>
        /// <param name="onCompleted">Callback invoked when the source is exhausted.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onError == null)
                throw new ArgumentNullException(nameof(onError));
            if (onCompleted == null)
                throw new ArgumentNullException(nameof(onCompleted));

            return DoCore(source, onNext, onError, onCompleted);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/> and awaits
        /// <paramref name="onNext"/> for each element before it is yielded.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="onNext">Asynchronous callback awaited with each element.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="onNext"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/>, awaiting
        /// <paramref name="onNext"/> for each element and <paramref name="onCompleted"/> once the
        /// source reports that it has no more elements.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="onNext">Asynchronous callback awaited with each element.</param>
        /// <param name="onCompleted">Asynchronous callback awaited when the source is exhausted.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Task> onCompleted)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onCompleted == null)
                throw new ArgumentNullException(nameof(onCompleted));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/>, awaiting
        /// <paramref name="onNext"/> for each element and <paramref name="onError"/> when advancing
        /// the source or running <paramref name="onNext"/> throws.
        /// </summary>
        /// <remarks>
        /// The exception is rethrown after <paramref name="onError"/> completes.
        /// <see cref="OperationCanceledException"/> is rethrown without calling <paramref name="onError"/>.
        /// </remarks>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="onNext">Asynchronous callback awaited with each element.</param>
        /// <param name="onError">Asynchronous callback awaited with the exception before it is rethrown.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onError == null)
                throw new ArgumentNullException(nameof(onError));

            return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/>, awaiting
        /// <paramref name="onNext"/> for each element, <paramref name="onError"/> on failure and
        /// <paramref name="onCompleted"/> once the source is exhausted.
        /// </summary>
        /// <remarks>
        /// The exception is rethrown after <paramref name="onError"/> completes.
        /// <see cref="OperationCanceledException"/> is rethrown without calling <paramref name="onError"/>.
        /// </remarks>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="onNext">Asynchronous callback awaited with each element.</param>
        /// <param name="onError">Asynchronous callback awaited with the exception before it is rethrown.</param>
        /// <param name="onCompleted">Asynchronous callback awaited when the source is exhausted.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onError == null)
                throw new ArgumentNullException(nameof(onError));
            if (onCompleted == null)
                throw new ArgumentNullException(nameof(onCompleted));

            return DoCore(source, onNext, onError, onCompleted);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/> and forwards
        /// each element, failure and completion to the matching method of <paramref name="observer"/>.
        /// </summary>
        /// <remarks>
        /// The exception is rethrown after <c>OnError</c> returns.
        /// <see cref="OperationCanceledException"/> is rethrown without calling <c>OnError</c>.
        /// </remarks>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to observe.</param>
        /// <param name="observer">Observer whose <c>OnNext</c>, <c>OnError</c> and <c>OnCompleted</c> methods are used as callbacks.</param>
        /// <returns>A sequence with the same elements as <paramref name="source"/>; callbacks run while it is enumerated.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="observer"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (observer == null)
                throw new ArgumentNullException(nameof(observer));

            // Explicit delegate construction selects the synchronous DoCore overload.
            return DoCore(source, new Action<TSource>(observer.OnNext), new Action<Exception>(observer.OnError), new Action(observer.OnCompleted));
        }

        // Wraps the source in the iterator that runs synchronous callbacks.
        // onError and onCompleted may be null; onNext may not.
        private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
        {
            return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
        }

        // Wraps the source in the iterator that awaits Task-returning callbacks.
        // onError and onCompleted may be null; onNext may not.
        private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
        {
            return new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
        }

        /// <summary>
        /// Iterator that passes elements of the source through while invoking synchronous callbacks.
        /// </summary>
        private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            private readonly Action onCompleted;
            private readonly Action<Exception> onError;
            private readonly Action<TSource> onNext;
            private readonly IAsyncEnumerable<TSource> source;

            // Created on the first MoveNextCore call; reset to null once disposed.
            private IAsyncEnumerator<TSource> enumerator;

            public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
            {
                Debug.Assert(source != null);
                Debug.Assert(onNext != null);

                this.source = source;
                this.onNext = onNext;
                this.onError = onError;
                this.onCompleted = onCompleted;
            }

            /// <summary>Creates a new iterator over the same source with the same callbacks.</summary>
            public override AsyncIterator<TSource> Clone()
            {
                return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
            }

            /// <summary>Disposes the source enumerator, if one was obtained, then the base iterator.</summary>
            public override async ValueTask DisposeAsync()
            {
                if (enumerator != null)
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                    enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Advances the source by one element, invoking the callbacks along the way.
            /// </summary>
            /// <param name="cancellationToken">Token passed to the source when its enumerator is obtained.</param>
            /// <returns><c>true</c> when an element was produced; otherwise <c>false</c>.</returns>
            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        // Obtaining the enumerator is outside the try below, so a failure here
                        // is not reported to onError.
                        enumerator = source.GetAsyncEnumerator(cancellationToken);
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        try
                        {
                            if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                            {
                                current = enumerator.Current;

                                // Inside the try: an exception from onNext is reported to onError too.
                                onNext(current);

                                return true;
                            }
                        }
                        catch (OperationCanceledException)
                        {
                            // Cancellation bypasses onError.
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            onError(ex);
                            throw;
                        }

                        // Reached only when the source returned false from MoveNextAsync.
                        onCompleted?.Invoke();

                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }

        /// <summary>
        /// Iterator that passes elements of the source through while awaiting Task-returning callbacks.
        /// </summary>
        private sealed class DoAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
        {
            private readonly Func<Task> onCompleted;
            private readonly Func<Exception, Task> onError;
            private readonly Func<TSource, Task> onNext;
            private readonly IAsyncEnumerable<TSource> source;

            // Created on the first MoveNextCore call; reset to null once disposed.
            private IAsyncEnumerator<TSource> enumerator;

            public DoAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
            {
                Debug.Assert(source != null);
                Debug.Assert(onNext != null);

                this.source = source;
                this.onNext = onNext;
                this.onError = onError;
                this.onCompleted = onCompleted;
            }

            /// <summary>Creates a new iterator over the same source with the same callbacks.</summary>
            public override AsyncIterator<TSource> Clone()
            {
                return new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
            }

            /// <summary>Disposes the source enumerator, if one was obtained, then the base iterator.</summary>
            public override async ValueTask DisposeAsync()
            {
                if (enumerator != null)
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                    enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            /// <summary>
            /// Advances the source by one element, awaiting the callbacks along the way.
            /// </summary>
            /// <param name="cancellationToken">Token passed to the source when its enumerator is obtained.</param>
            /// <returns><c>true</c> when an element was produced; otherwise <c>false</c>.</returns>
            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        // Obtaining the enumerator is outside the try below, so a failure here
                        // is not reported to onError.
                        enumerator = source.GetAsyncEnumerator(cancellationToken);
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        try
                        {
                            if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                            {
                                current = enumerator.Current;

                                // Inside the try: an exception from onNext is reported to onError too.
                                await onNext(current).ConfigureAwait(false);

                                return true;
                            }
                        }
                        catch (OperationCanceledException)
                        {
                            // Cancellation bypasses onError.
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            await onError(ex).ConfigureAwait(false);
                            throw;
                        }

                        // Reached only when the source returned false from MoveNextAsync.
                        if (onCompleted != null)
                        {
                            await onCompleted().ConfigureAwait(false);
                        }

                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }
    }
}