Browse Source

Async variant of Catch.

Bart De Smet 8 năm trước cách đây
mục cha
commit
91cdd73ba8

+ 106 - 3
Ix.NET/Source/System.Interactive.Async/Catch.cs

@@ -5,7 +5,6 @@
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Runtime.ExceptionServices;
-using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Linq
@@ -23,6 +22,17 @@ namespace System.Linq
             return new CatchAsyncIterator<TSource, TException>(source, handler);
         }
 
+        public static IAsyncEnumerable<TSource> Catch<TSource, TException>(this IAsyncEnumerable<TSource> source, Func<TException, Task<IAsyncEnumerable<TSource>>> handler)
+            where TException : Exception
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (handler == null)
+                throw new ArgumentNullException(nameof(handler));
+
+            return new CatchAsyncIteratorWithTask<TSource, TException>(source, handler);
+        }
+
         public static IAsyncEnumerable<TSource> Catch<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
         {
             if (sources == null)
@@ -116,8 +126,101 @@ namespace System.Linq
                                     // Note: Ideally we'd dipose of the previous enumerator before
                                     // invoking the handler, but we use this order to preserve
                                     // current behavior
-                                    var err = handler(ex)
-                                        .GetAsyncEnumerator();
+                                    var inner = handler(ex);
+                                    var err = inner.GetAsyncEnumerator();
+
+                                    if (enumerator != null)
+                                    {
+                                        await enumerator.DisposeAsync().ConfigureAwait(false);
+                                    }
+
+                                    enumerator = err;
+                                    isDone = true;
+                                    continue; // loop so we hit the catch state
+                                }
+                            }
+
+                            if (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                            {
+                                current = enumerator.Current;
+                                return true;
+                            }
+
+                            break; // while
+                        }
+
+                        break; // case
+                }
+
+                await DisposeAsync().ConfigureAwait(false);
+                return false;
+            }
+        }
+
+        private sealed class CatchAsyncIteratorWithTask<TSource, TException> : AsyncIterator<TSource> where TException : Exception
+        {
+            private readonly Func<TException, Task<IAsyncEnumerable<TSource>>> handler;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private IAsyncEnumerator<TSource> enumerator;
+            private bool isDone;
+
+            public CatchAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TException, Task<IAsyncEnumerable<TSource>>> handler)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(handler != null);
+
+                this.source = source;
+                this.handler = handler;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new CatchAsyncIteratorWithTask<TSource, TException>(source, handler);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        isDone = false;
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        while (true)
+                        {
+                            if (!isDone)
+                            {
+                                try
+                                {
+                                    if (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                                    {
+                                        current = enumerator.Current;
+                                        return true;
+                                    }
+                                }
+                                catch (TException ex)
+                                {
+                                    // Note: Ideally we'd dipose of the previous enumerator before
+                                    // invoking the handler, but we use this order to preserve
+                                    // current behavior
+                                    var inner = await handler(ex).ConfigureAwait(false);
+                                    var err = inner.GetAsyncEnumerator();
 
                                     if (enumerator != null)
                                     {

+ 1 - 3
Ix.NET/Source/Tests/AsyncTests.Exceptions.cs

@@ -5,9 +5,7 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
-using System.Text;
 using Xunit;
-using System.Threading;
 using System.Threading.Tasks;
 
 namespace Tests
@@ -17,7 +15,7 @@ namespace Tests
         [Fact]
         public void Catch_Null()
         {
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Catch<int, Exception>(default(IAsyncEnumerable<int>), x => null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Catch<int, Exception>(default(IAsyncEnumerable<int>), x => default(IAsyncEnumerable<int>)));
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Catch<int, Exception>(AsyncEnumerable.Return(42), default(Func<Exception, IAsyncEnumerable<int>>)));
 
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Catch<int>(default(IAsyncEnumerable<int>), AsyncEnumerable.Return(42)));