Explorar o código

Use async iterators in Catch.

Bart De Smet %!s(int64=6) %!d(string=hai) anos
pai
achega
c80b4f1987

+ 190 - 0
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Catch.cs

@@ -12,6 +12,11 @@ namespace System.Linq
 {
     public static partial class AsyncEnumerableEx
     {
+        // REVIEW: All Catch operators may catch OperationCanceledException due to cancellation of the enumeration
+        //         of the source. Should we explicitly avoid handling this? E.g. as follows:
+        //
+        //         catch (TException ex) when(!(ex is OperationCanceledException oce && oce.CancellationToken == cancellationToken))
+
         public static IAsyncEnumerable<TSource> Catch<TSource, TException>(this IAsyncEnumerable<TSource> source, Func<TException, IAsyncEnumerable<TSource>> handler)
             where TException : Exception
         {
@@ -20,7 +25,54 @@ namespace System.Linq
             if (handler == null)
                 throw Error.ArgumentNull(nameof(handler));
 
+#if USE_ASYNC_ITERATOR
+            return Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                // REVIEW: This implementation mirrors the Ix implementation, which does not protect GetEnumerator
+                //         using the try statement either. A more trivial implementation would use await foreach
+                //         and protect the entire loop using a try statement, with two breaking changes:
+                //
+                //         - Also protecting the call to GetAsyncEnumerator by the try statement.
+                //         - Invocation of the handler after disposal of the failed first sequence.
+
+                var err = default(IAsyncEnumerable<TSource>);
+
+                await using (var e = source.GetAsyncEnumerator(cancellationToken).ConfigureAwait(false))
+                {
+                    while (true)
+                    {
+                        var c = default(TSource);
+
+                        try
+                        {
+                            if (!await e.MoveNextAsync())
+                                break;
+
+                            c = e.Current;
+                        }
+                        catch (TException ex)
+                        {
+                            err = handler(ex);
+                            break;
+                        }
+
+                        yield return c;
+                    }
+                }
+
+                if (err != null)
+                {
+                    await foreach (var item in err.WithCancellation(cancellationToken).ConfigureAwait(false))
+                    {
+                        yield return item;
+                    }
+                }
+            }
+#else
             return new CatchAsyncIterator<TSource, TException>(source, handler);
+#endif
         }
 
         public static IAsyncEnumerable<TSource> Catch<TSource, TException>(this IAsyncEnumerable<TSource> source, Func<TException, ValueTask<IAsyncEnumerable<TSource>>> handler)
@@ -31,7 +83,54 @@ namespace System.Linq
             if (handler == null)
                 throw Error.ArgumentNull(nameof(handler));
 
+#if USE_ASYNC_ITERATOR
+            return Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                // REVIEW: This implementation mirrors the Ix implementation, which does not protect GetEnumerator
+                //         using the try statement either. A more trivial implementation would use await foreach
+                //         and protect the entire loop using a try statement, with two breaking changes:
+                //
+                //         - Also protecting the call to GetAsyncEnumerator by the try statement.
+                //         - Invocation of the handler after disposal of the failed first sequence.
+
+                var err = default(IAsyncEnumerable<TSource>);
+
+                await using (var e = source.GetAsyncEnumerator(cancellationToken).ConfigureAwait(false))
+                {
+                    while (true)
+                    {
+                        var c = default(TSource);
+
+                        try
+                        {
+                            if (!await e.MoveNextAsync())
+                                break;
+
+                            c = e.Current;
+                        }
+                        catch (TException ex)
+                        {
+                            err = await handler(ex).ConfigureAwait(false);
+                            break;
+                        }
+
+                        yield return c;
+                    }
+                }
+
+                if (err != null)
+                {
+                    await foreach (var item in err.WithCancellation(cancellationToken).ConfigureAwait(false))
+                    {
+                        yield return item;
+                    }
+                }
+            }
+#else
             return new CatchAsyncIteratorWithTask<TSource, TException>(source, handler);
+#endif
         }
 
 #if !NO_DEEP_CANCELLATION
@@ -43,7 +142,54 @@ namespace System.Linq
             if (handler == null)
                 throw Error.ArgumentNull(nameof(handler));
 
+#if USE_ASYNC_ITERATOR
+            return Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                // REVIEW: This implementation mirrors the Ix implementation, which does not protect GetEnumerator
+                //         using the try statement either. A more trivial implementation would use await foreach
+                //         and protect the entire loop using a try statement, with two breaking changes:
+                //
+                //         - Also protecting the call to GetAsyncEnumerator by the try statement.
+                //         - Invocation of the handler after disposal of the failed first sequence.
+
+                var err = default(IAsyncEnumerable<TSource>);
+
+                await using (var e = source.GetAsyncEnumerator(cancellationToken).ConfigureAwait(false))
+                {
+                    while (true)
+                    {
+                        var c = default(TSource);
+
+                        try
+                        {
+                            if (!await e.MoveNextAsync())
+                                break;
+
+                            c = e.Current;
+                        }
+                        catch (TException ex)
+                        {
+                            err = await handler(ex, cancellationToken).ConfigureAwait(false);
+                            break;
+                        }
+
+                        yield return c;
+                    }
+                }
+
+                if (err != null)
+                {
+                    await foreach (var item in err.WithCancellation(cancellationToken).ConfigureAwait(false))
+                    {
+                        yield return item;
+                    }
+                }
+            }
+#else
             return new CatchAsyncIteratorWithTaskAndCancellation<TSource, TException>(source, handler);
+#endif
         }
 #endif
 
@@ -75,9 +221,52 @@ namespace System.Linq
 
         private static IAsyncEnumerable<TSource> CatchCore<TSource>(IEnumerable<IAsyncEnumerable<TSource>> sources)
         {
+#if USE_ASYNC_ITERATOR
+            return Create(Core);
+
+            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
+            {
+                var error = default(ExceptionDispatchInfo);
+
+                foreach (var source in sources)
+                {
+                    await using (var e = source.GetAsyncEnumerator(cancellationToken).ConfigureAwait(false))
+                    {
+                        error = null;
+
+                        while (true)
+                        {
+                            var c = default(TSource);
+
+                            try
+                            {
+                                if (!await e.MoveNextAsync())
+                                    break;
+
+                                c = e.Current;
+                            }
+                            catch (Exception ex)
+                            {
+                                error = ExceptionDispatchInfo.Capture(ex);
+                                break;
+                            }
+
+                            yield return c;
+                        }
+
+                        if (error == null)
+                            break;
+                    }
+                }
+
+                error?.Throw();
+            }
+#else
             return new CatchAsyncIterator<TSource>(sources);
+#endif
         }
 
+#if !USE_ASYNC_ITERATOR
         private sealed class CatchAsyncIterator<TSource, TException> : AsyncIterator<TSource> where TException : Exception
         {
             private readonly Func<TException, IAsyncEnumerable<TSource>> _handler;
@@ -453,4 +642,5 @@ namespace System.Linq
             }
         }
     }
+#endif
 }