// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerableEx
    {
        // Every public Do overload validates its arguments at call time (these are ordinary,
        // non-iterator methods) and then forwards to one of the private DoCore overloads at the
        // bottom of this class. Callbacks an overload does not accept are passed to DoCore as
        // null, and DoCore skips null callbacks.

        // REVIEW: Should we convert Task-based overloads to ValueTask?

        /// <summary>
        /// Invokes a synchronous action for each element of the sequence as it is enumerated,
        /// passing every element through unchanged.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">Invoked with each element before that element is yielded.</param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
        }

        /// <summary>
        /// Invokes a synchronous action for each element and another once the source is exhausted.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">Invoked with each element before that element is yielded.</param>
        /// <param name="onCompleted">Invoked once after the source reports it has no more elements.</param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
        }

        /// <summary>
        /// Invokes a synchronous action for each element and another when enumeration fails.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">Invoked with each element before that element is yielded.</param>
        /// <param name="onError">
        /// Invoked with the exception thrown while advancing the source or by <paramref name="onNext"/>
        /// (except <see cref="OperationCanceledException"/>); the exception is rethrown afterwards.
        /// </param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));

            return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
        }

        /// <summary>
        /// Invokes synchronous actions for each element, on failure, and on completion.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">Invoked with each element before that element is yielded.</param>
        /// <param name="onError">
        /// Invoked with the exception thrown while advancing the source or by <paramref name="onNext"/>
        /// (except <see cref="OperationCanceledException"/>); the exception is rethrown afterwards.
        /// </param>
        /// <param name="onCompleted">Invoked once after the source reports it has no more elements.</param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext, onError, onCompleted);
        }

        /// <summary>
        /// Awaits an asynchronous callback for each element of the sequence as it is enumerated,
        /// passing every element through unchanged.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">Awaited with each element before that element is yielded.</param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
        }

        /// <summary>
        /// Awaits an asynchronous callback for each element and another once the source is exhausted.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">Awaited with each element before that element is yielded.</param>
        /// <param name="onCompleted">Awaited once after the source reports it has no more elements.</param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Task> onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
        }

        /// <summary>
        /// Awaits an asynchronous callback for each element and another when enumeration fails.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">Awaited with each element before that element is yielded.</param>
        /// <param name="onError">
        /// Awaited with the exception thrown while advancing the source or by <paramref name="onNext"/>
        /// (except <see cref="OperationCanceledException"/>); the exception is rethrown afterwards.
        /// </param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));

            return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
        }

        /// <summary>
        /// Awaits asynchronous callbacks for each element, on failure, and on completion.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">Awaited with each element before that element is yielded.</param>
        /// <param name="onError">
        /// Awaited with the exception thrown while advancing the source or by <paramref name="onNext"/>
        /// (except <see cref="OperationCanceledException"/>); the exception is rethrown afterwards.
        /// </param>
        /// <param name="onCompleted">Awaited once after the source reports it has no more elements.</param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext, onError, onCompleted);
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Awaits a cancellation-aware asynchronous callback for each element of the sequence,
        /// passing every element through unchanged.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">
        /// Awaited with each element and the enumeration's cancellation token before the element is yielded.
        /// </param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
        }

        /// <summary>
        /// Awaits cancellation-aware asynchronous callbacks for each element and on completion.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">
        /// Awaited with each element and the enumeration's cancellation token before the element is yielded.
        /// </param>
        /// <param name="onCompleted">
        /// Awaited once, with the enumeration's cancellation token, after the source reports it has no more elements.
        /// </param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<CancellationToken, Task> onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
        }

        /// <summary>
        /// Awaits cancellation-aware asynchronous callbacks for each element and on failure.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">
        /// Awaited with each element and the enumeration's cancellation token before the element is yielded.
        /// </param>
        /// <param name="onError">
        /// Awaited with the exception thrown while advancing the source or by <paramref name="onNext"/>
        /// (except <see cref="OperationCanceledException"/>) and the enumeration's cancellation token;
        /// the exception is rethrown afterwards.
        /// </param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));

            return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
        }

        /// <summary>
        /// Awaits cancellation-aware asynchronous callbacks for each element, on failure, and on completion.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="onNext">
        /// Awaited with each element and the enumeration's cancellation token before the element is yielded.
        /// </param>
        /// <param name="onError">
        /// Awaited with the exception thrown while advancing the source or by <paramref name="onNext"/>
        /// (except <see cref="OperationCanceledException"/>) and the enumeration's cancellation token;
        /// the exception is rethrown afterwards.
        /// </param>
        /// <param name="onCompleted">
        /// Awaited once, with the enumeration's cancellation token, after the source reports it has no more elements.
        /// </param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (onNext == null)
                throw Error.ArgumentNull(nameof(onNext));
            if (onError == null)
                throw Error.ArgumentNull(nameof(onError));
            if (onCompleted == null)
                throw Error.ArgumentNull(nameof(onCompleted));

            return DoCore(source, onNext, onError, onCompleted);
        }
#endif

        /// <summary>
        /// Forwards each element, failure, and completion of the sequence to an observer,
        /// passing every element through unchanged.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">The sequence to observe.</param>
        /// <param name="observer">
        /// Receives <see cref="IObserver{T}.OnNext"/> for each element,
        /// <see cref="IObserver{T}.OnError"/> when enumeration fails (except on
        /// <see cref="OperationCanceledException"/>), and <see cref="IObserver{T}.OnCompleted"/>
        /// once the source is exhausted.
        /// </param>
        /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (observer == null)
                throw Error.ArgumentNull(nameof(observer));

            return DoCore(source, new Action<TSource>(observer.OnNext), new Action<Exception>(observer.OnError), new Action(observer.OnCompleted));
        }

        /// <summary>
        /// Shared implementation for the synchronous-callback overloads.
        /// </summary>
        /// <remarks>
        /// For each element, MoveNextAsync is awaited, then <paramref name="onNext"/> is invoked, then
        /// the element is yielded. If either MoveNextAsync or <paramref name="onNext"/> throws,
        /// <paramref name="onError"/> (when non-null) is invoked and the exception is rethrown.
        /// <see cref="OperationCanceledException"/> is rethrown without invoking <paramref name="onError"/>.
        /// <paramref name="onCompleted"/> (when non-null) runs after the source reports no more elements,
        /// before the source enumerator is disposed. Exceptions from <paramref name="onCompleted"/> are
        /// not routed to <paramref name="onError"/>.
        /// </remarks>
        private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
        {
            return AsyncEnumerable.Create(Core);

            // NOTE(review): cancellationToken is presumably the token passed to GetAsyncEnumerator on
            // the returned sequence; confirm against AsyncEnumerable.Create.
            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    while (true)
                    {
                        TSource item;

                        // C# does not allow `yield return` inside a try block that has catch clauses,
                        // so the element is captured here and yielded after the try.
                        try
                        {
                            if (!await e.MoveNextAsync())
                            {
                                break;
                            }

                            item = e.Current;

                            onNext(item);
                        }
                        catch (OperationCanceledException)
                        {
                            // Cancellation is not reported to onError.
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            onError(ex);
                            throw;
                        }

                        yield return item;
                    }

                    onCompleted?.Invoke();
                }
            }
        }

        /// <summary>
        /// Shared implementation for the Task-returning callback overloads.
        /// </summary>
        /// <remarks>
        /// Same control flow as the synchronous DoCore, except each callback's task is awaited
        /// (with <c>ConfigureAwait(false)</c>) before continuing.
        /// </remarks>
        private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
        {
            return AsyncEnumerable.Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    while (true)
                    {
                        TSource item;

                        // `yield return` is not permitted inside a try with catch clauses; capture, then yield below.
                        try
                        {
                            if (!await e.MoveNextAsync())
                            {
                                break;
                            }

                            item = e.Current;

                            await onNext(item).ConfigureAwait(false);
                        }
                        catch (OperationCanceledException)
                        {
                            // Cancellation is not reported to onError.
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            await onError(ex).ConfigureAwait(false);
                            throw;
                        }

                        yield return item;
                    }

                    if (onCompleted != null)
                    {
                        await onCompleted().ConfigureAwait(false);
                    }
                }
            }
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Shared implementation for the cancellation-aware callback overloads.
        /// </summary>
        /// <remarks>
        /// Same control flow as the Task-based DoCore, except every callback also receives the
        /// cancellation token that the enumerator was created with.
        /// </remarks>
        private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
        {
            return AsyncEnumerable.Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    while (true)
                    {
                        TSource item;

                        // `yield return` is not permitted inside a try with catch clauses; capture, then yield below.
                        try
                        {
                            if (!await e.MoveNextAsync())
                            {
                                break;
                            }

                            item = e.Current;

                            await onNext(item, cancellationToken).ConfigureAwait(false);
                        }
                        catch (OperationCanceledException)
                        {
                            // Cancellation is not reported to onError.
                            throw;
                        }
                        catch (Exception ex) when (onError != null)
                        {
                            await onError(ex, cancellationToken).ConfigureAwait(false);
                            throw;
                        }

                        yield return item;
                    }

                    if (onCompleted != null)
                    {
                        await onCompleted(cancellationToken).ConfigureAwait(false);
                    }
                }
            }
        }
#endif
    }
}