// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Enumerates <paramref name="source"/>. If a <c>MoveNext</c> on it throws an exception of type
        /// <typeparamref name="TException"/>, enumeration continues with the sequence that
        /// <paramref name="handler"/> returns for that exception.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <typeparam name="TException">Exception type to intercept; other exceptions propagate.</typeparam>
        /// <param name="source">Sequence to enumerate first.</param>
        /// <param name="handler">Produces the continuation sequence for a caught exception. It is invoked at most once per enumeration.</param>
        /// <returns>The source elements, followed by the handler's sequence if the source failed with <typeparamref name="TException"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="handler"/> is null.</exception>
        /// <remarks>
        /// Only one switch to the handler happens. Exceptions thrown by the handler itself, or by the
        /// sequence it returns, propagate to the caller.
        /// </remarks>
        public static IAsyncEnumerable<TSource> Catch<TSource, TException>(this IAsyncEnumerable<TSource> source, Func<TException, IAsyncEnumerable<TSource>> handler)
            where TException : Exception
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (handler == null)
                throw new ArgumentNullException(nameof(handler));

            return Create(() =>
            {
                var e = source.GetEnumerator();

                var cts = new CancellationTokenDisposable();
                var a = new AssignableDisposable
                {
                    Disposable = e
                };
                var d = Disposable.Create(cts, a);

                // Set once we have switched over to the handler's sequence; from then on
                // exceptions are no longer intercepted.
                var done = false;

                Func<CancellationToken, Task<bool>> moveNext = async ct =>
                {
                    if (!done)
                    {
                        try
                        {
                            return await e.MoveNext(ct)
                                          .ConfigureAwait(false);
                        }
                        catch (TException ex)
                        {
                            // NOTE(review): the faulted source enumerator is not disposed explicitly
                            // here (Catch_ does call Dispose on a failed enumerator); this relies on
                            // AssignableDisposable releasing the previous value on reassignment — confirm.
                            e = handler(ex).GetEnumerator();
                            a.Disposable = e;
                            done = true;
                        }
                    }

                    // Either the source has already failed over, or we just switched: advance the
                    // handler's enumerator outside the try so its exceptions propagate.
                    return await e.MoveNext(ct)
                                  .ConfigureAwait(false);
                };

                return Create(
                    moveNext,
                    () => e.Current,
                    d.Dispose,
                    a
                );
            });
        }

        /// <summary>
        /// Enumerates each sequence in <paramref name="sources"/> in order. Whenever the current
        /// sequence throws, enumeration moves on to the next one.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="sources">Candidate sequences, tried in order.</param>
        /// <returns>
        /// The elements of the failing sequences, followed by those of the first sequence that completes
        /// normally. If every sequence fails, the last exception is rethrown. An empty
        /// <paramref name="sources"/> yields an empty sequence.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Catch<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
        {
            if (sources == null)
                throw new ArgumentNullException(nameof(sources));

            return sources.Catch_();
        }

        /// <summary>
        /// Enumerates each sequence in <paramref name="sources"/> in order. Whenever the current
        /// sequence throws, enumeration moves on to the next one.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="sources">Candidate sequences, tried in order.</param>
        /// <returns>
        /// The elements of the failing sequences, followed by those of the first sequence that completes
        /// normally. If every sequence fails, the last exception is rethrown.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Catch<TSource>(params IAsyncEnumerable<TSource>[] sources)
        {
            if (sources == null)
                throw new ArgumentNullException(nameof(sources));

            return sources.Catch_();
        }

        /// <summary>
        /// Enumerates <paramref name="first"/>. If it throws, enumeration continues with <paramref name="second"/>.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="first">Sequence to enumerate first.</param>
        /// <param name="second">Fallback sequence used if <paramref name="first"/> throws.</param>
        /// <returns>A sequence with the two-source catch semantics described above.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Catch<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));

            return new[] { first, second }.Catch_();
        }

        /// <summary>
        /// Shared implementation of the multi-source <c>Catch</c> overloads.
        /// </summary>
        /// <remarks>
        /// <para>
        /// <c>sources.GetEnumerator()</c> runs inside the factory passed to <c>Create</c>, so
        /// <paramref name="sources"/> is not touched at call time.
        /// </para>
        /// <para>
        /// Only exceptions from an inner <c>MoveNext</c> are intercepted. Exceptions from advancing
        /// <paramref name="sources"/>, or from calling <c>GetEnumerator</c> on an element, propagate.
        /// When the sources run out after a failure, the most recent failure is rethrown through
        /// <see cref="ExceptionDispatchInfo"/>, which preserves its original stack trace.
        /// </para>
        /// <para>
        /// NOTE(review): <c>catch (Exception)</c> also catches <see cref="OperationCanceledException"/>.
        /// A cancelled <c>MoveNext</c> therefore falls through to the next source instead of
        /// propagating — confirm that is intended.
        /// </para>
        /// </remarks>
        private static IAsyncEnumerable<TSource> Catch_<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
        {
            return Create(() =>
            {
                var se = sources.GetEnumerator();
                var e = default(IAsyncEnumerator<TSource>);

                var cts = new CancellationTokenDisposable();
                var a = new AssignableDisposable();
                var d = Disposable.Create(cts, se, a);

                // Failure of the most recently abandoned source; cleared when a new source is started.
                var error = default(ExceptionDispatchInfo);

                Func<CancellationToken, Task<bool>> moveNext = async ct =>
                {
                    // Iterate rather than recurse, so a long run of failing sources does not build
                    // an ever-deeper chain of nested async calls.
                    while (true)
                    {
                        if (e == null)
                        {
                            if (se.MoveNext())
                            {
                                e = se.Current.GetEnumerator();
                            }
                            else
                            {
                                // Out of sources: surface the last failure, or complete if there was none.
                                error?.Throw();
                                return false;
                            }

                            error = null;

                            a.Disposable = e;
                        }

                        try
                        {
                            return await e.MoveNext(ct)
                                          .ConfigureAwait(false);
                        }
                        catch (Exception exception)
                        {
                            e.Dispose();
                            e = null;

                            error = ExceptionDispatchInfo.Capture(exception);
                        }
                    }
                };

                return Create(
                    moveNext,
                    () => e.Current,
                    d.Dispose,
                    a
                );
            });
        }
    }
}