Oren Novotny 9 years ago
parent
commit
f05ef936ba
1 changed files with 154 additions and 84 deletions
  1. 154 84
      Ix.NET/Source/System.Interactive.Async/Catch.cs

+ 154 - 84
Ix.NET/Source/System.Interactive.Async/Catch.cs

@@ -21,51 +21,7 @@ namespace System.Linq
             if (handler == null)
                 throw new ArgumentNullException(nameof(handler));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = source.GetEnumerator();
-
-                    var cts = new CancellationTokenDisposable();
-                    var a = new AssignableDisposable
-                    {
-                        Disposable = e
-                    };
-                    var d = Disposable.Create(cts, a);
-                    var done = false;
-
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (!done)
-                            {
-                                try
-                                {
-                                    return await e.MoveNext(ct)
-                                                  .ConfigureAwait(false);
-                                }
-                                catch (TException ex)
-                                {
-                                    var err = handler(ex)
-                                        .GetEnumerator();
-                                    e = err;
-                                    a.Disposable = e;
-                                    done = true;
-                                    return await f(ct)
-                                               .ConfigureAwait(false);
-                                }
-                            }
-                            return await e.MoveNext(ct)
-                                          .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => e.Current,
-                        d.Dispose,
-                        a
-                    );
-                });
+            return new CatchAsyncIterator<TSource, TException>(source, handler);
         }
 
         public static IAsyncEnumerable<TSource> Catch<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
@@ -96,60 +52,174 @@ namespace System.Linq
 
         private static IAsyncEnumerable<TSource> Catch_<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
         {
-            return CreateEnumerable(
-                () =>
+            return new CatchAsyncIterator<TSource>(sources);
+        }
+
+        private sealed class CatchAsyncIterator<TSource, TException> : AsyncIterator<TSource> where TException : Exception
+        {
+            private readonly Func<TException, IAsyncEnumerable<TSource>> handler;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private IAsyncEnumerator<TSource> enumerator;
+            private bool isDone;
+
+            public CatchAsyncIterator(IAsyncEnumerable<TSource> source, Func<TException, IAsyncEnumerable<TSource>> handler)
+            {
+                this.source = source;
+                this.handler = handler;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new CatchAsyncIterator<TSource, TException>(source, handler);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
                 {
-                    var se = sources.GetEnumerator();
-                    var e = default(IAsyncEnumerator<TSource>);
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
 
-                    var cts = new CancellationTokenDisposable();
-                    var a = new AssignableDisposable();
-                    var d = Disposable.Create(cts, se, a);
+                base.Dispose();
+            }
 
-                    var error = default(ExceptionDispatchInfo);
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        isDone = false;
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (!isDone)
                         {
-                            if (e == null)
+                            try
                             {
-                                if (se.MoveNext())
-                                {
-                                    e = se.Current.GetEnumerator();
-                                }
-                                else
+                                if (await enumerator.MoveNext(cancellationToken)
+                                                    .ConfigureAwait(false))
                                 {
-                                    error?.Throw();
-                                    return false;
+                                    current = enumerator.Current;
+                                    return true;
                                 }
+                            }
+                            catch (TException ex)
+                            {
+                                var err = handler(ex)
+                                    .GetEnumerator();
+                                enumerator?.Dispose();
+                                enumerator = err;
+                                isDone = true;
+                                goto case AsyncIteratorState.Iterating; // loop so we hit the catch state
+                            }
+                        }
 
-                                error = null;
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
+                        {
+                            current = enumerator.Current;
+                            return true;
+                        }
 
-                                a.Disposable = e;
-                            }
+                        break;
+                }
 
-                            try
+                Dispose();
+                return false;
+            }
+        }
+
+        private sealed class CatchAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly IEnumerable<IAsyncEnumerable<TSource>> sources;
+            private IAsyncEnumerator<TSource> enumerator;
+            private ExceptionDispatchInfo error;
+
+            private IEnumerator<IAsyncEnumerable<TSource>> sourcesEnumerator;
+
+            public CatchAsyncIterator(IEnumerable<IAsyncEnumerable<TSource>> sources)
+            {
+                this.sources = sources;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new CatchAsyncIterator<TSource>(sources);
+            }
+
+            public override void Dispose()
+            {
+                if (sourcesEnumerator != null)
+                {
+                    sourcesEnumerator.Dispose();
+                    sourcesEnumerator = null;
+                }
+
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                error = null;
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        sourcesEnumerator = sources.GetEnumerator();
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (enumerator == null)
+                        {
+                            if (!sourcesEnumerator.MoveNext())
                             {
-                                return await e.MoveNext(ct)
-                                              .ConfigureAwait(false);
+                                // only throw if we have an error on the last one
+                                error?.Throw();
+                                break; // done, nothing else to do
                             }
-                            catch (Exception exception)
+
+                            error = null;
+                            enumerator = sourcesEnumerator.Current.GetEnumerator();
+                        }
+
+                        try
+                        {
+                            if (await enumerator.MoveNext(cancellationToken)
+                                                .ConfigureAwait(false))
                             {
-                                e.Dispose();
-                                e = null;
-                                error = ExceptionDispatchInfo.Capture(exception);
-                                return await f(ct)
-                                           .ConfigureAwait(false);
+                                current = enumerator.Current;
+                                return true;
                             }
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => e.Current,
-                        d.Dispose,
-                        a
-                    );
-                });
+                        }
+                        catch (Exception ex)
+                        {
+                            // Done with the current one, go to the next
+                            enumerator.Dispose();
+                            enumerator = null;
+                            error = ExceptionDispatchInfo.Capture(ex);
+                            goto case AsyncIteratorState.Iterating;
+                        }
+
+                        break;
+                }
+
+
+                Dispose();
+                return false;
+            }
         }
     }
 }