// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/> unchanged,
        /// invoking <paramref name="onNext"/> for each element as it is produced.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Action invoked for each element before it is yielded.</param>
        /// <returns>A pass-through sequence over <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/> or <paramref name="onNext"/> is null.
        /// </exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));

            return DoHelper(source, onNext, null, null);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/> unchanged,
        /// invoking <paramref name="onNext"/> for each element and <paramref name="onCompleted"/>
        /// once the source is exhausted.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Action invoked for each element before it is yielded.</param>
        /// <param name="onCompleted">Action invoked when the source reports no further elements.</param>
        /// <returns>A pass-through sequence over <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="onNext"/> or <paramref name="onCompleted"/> is null.
        /// </exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onCompleted == null)
                throw new ArgumentNullException(nameof(onCompleted));

            return DoHelper(source, onNext, null, onCompleted);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/> unchanged,
        /// invoking <paramref name="onNext"/> for each element and <paramref name="onError"/>
        /// if enumeration fails.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Action invoked for each element before it is yielded.</param>
        /// <param name="onError">
        /// Action invoked with the exception thrown while advancing the source or by
        /// <paramref name="onNext"/>; the exception is rethrown afterwards.
        /// <see cref="OperationCanceledException"/> is not reported.
        /// </param>
        /// <returns>A pass-through sequence over <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="onNext"/> or <paramref name="onError"/> is null.
        /// </exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onError == null)
                throw new ArgumentNullException(nameof(onError));

            return DoHelper(source, onNext, onError, null);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/> unchanged,
        /// invoking <paramref name="onNext"/> for each element, <paramref name="onError"/> if
        /// enumeration fails, and <paramref name="onCompleted"/> once the source is exhausted.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="onNext">Action invoked for each element before it is yielded.</param>
        /// <param name="onError">
        /// Action invoked with the exception thrown while advancing the source or by
        /// <paramref name="onNext"/>; the exception is rethrown afterwards.
        /// <see cref="OperationCanceledException"/> is not reported.
        /// </param>
        /// <param name="onCompleted">Action invoked when the source reports no further elements.</param>
        /// <returns>A pass-through sequence over <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (onNext == null)
                throw new ArgumentNullException(nameof(onNext));
            if (onError == null)
                throw new ArgumentNullException(nameof(onError));
            if (onCompleted == null)
                throw new ArgumentNullException(nameof(onCompleted));

            return DoHelper(source, onNext, onError, onCompleted);
        }

        /// <summary>
        /// Returns a sequence that yields the elements of <paramref name="source"/> unchanged,
        /// forwarding each element, failure and completion to <paramref name="observer"/>.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements in the source sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <param name="observer">
        /// Observer whose <c>OnNext</c>, <c>OnError</c> and <c>OnCompleted</c> methods are used as
        /// the element, error and completion callbacks respectively.
        /// </param>
        /// <returns>A pass-through sequence over <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/> or <paramref name="observer"/> is null.
        /// </exception>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (observer == null)
                throw new ArgumentNullException(nameof(observer));

            return DoHelper(source, observer.OnNext, observer.OnError, observer.OnCompleted);
        }

        // Shared factory for all Do overloads; arguments are validated by the callers.
        // onError and onCompleted may be null, meaning "no callback".
        private static IAsyncEnumerable<TSource> DoHelper<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
        {
            return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
        }

        /// <summary>
        /// Iterator that passes the elements of a source sequence through unchanged while
        /// invoking side-effect callbacks for each element, failure and completion.
        /// </summary>
        private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            private readonly Action onCompleted;        // may be null
            private readonly Action<Exception> onError; // may be null
            private readonly Action<TSource> onNext;
            private readonly IAsyncEnumerable<TSource> source;

            // Source enumerator; created lazily on the first MoveNext and released in Dispose.
            private IAsyncEnumerator<TSource> enumerator;

            public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
            {
                Debug.Assert(source != null);
                Debug.Assert(onNext != null);

                this.source = source;
                this.onNext = onNext;
                this.onError = onError;
                this.onCompleted = onCompleted;
            }

            /// <summary>Returns a new iterator over the same source with the same callbacks.</summary>
            public override AsyncIterator<TSource> Clone()
            {
                return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
            }

            /// <summary>Disposes the source enumerator, if one was created, then the base iterator.</summary>
            public override void Dispose()
            {
                enumerator?.Dispose();
                enumerator = null;

                base.Dispose();
            }

            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        // The source is not enumerated until the consumer first asks for an element.
                        enumerator = source.GetEnumerator();
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        try
                        {
                            if (await enumerator.MoveNext(cancellationToken)
                                                .ConfigureAwait(false))
                            {
                                current = enumerator.Current;

                                // Invoked inside the try so that an exception thrown by onNext
                                // is reported to onError just like a source failure.
                                onNext(current);
                                return true;
                            }
                        }
                        catch (OperationCanceledException)
                        {
                            // Cancellation is not an error: propagate without calling onError.
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            // The filter leaves the exception untouched when no callback was supplied.
                            onError(ex);
                            throw;
                        }

                        // Source exhausted: signal completion before releasing the source enumerator.
                        onCompleted?.Invoke();
                        Dispose();
                        break;
                }

                return false;
            }
        }
    }
}