Oren Novotny 9 years ago
parent
commit
05a2f11a67
1 changed files with 75 additions and 42 deletions
  1. 75 42
      Ix.NET/Source/System.Interactive.Async/Do.cs

+ 75 - 42
Ix.NET/Source/System.Interactive.Async/Do.cs

@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
@@ -19,7 +20,7 @@ namespace System.Linq
             if (onNext == null)
                 throw new ArgumentNullException(nameof(onNext));
 
-            return DoHelper(source, onNext, _ => { }, () => { });
+            return DoHelper(source, onNext, null, null);
         }
 
         public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
@@ -31,7 +32,7 @@ namespace System.Linq
             if (onCompleted == null)
                 throw new ArgumentNullException(nameof(onCompleted));
 
-            return DoHelper(source, onNext, _ => { }, onCompleted);
+            return DoHelper(source, onNext, null, onCompleted);
         }
 
         public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
@@ -43,7 +44,7 @@ namespace System.Linq
             if (onError == null)
                 throw new ArgumentNullException(nameof(onError));
 
-            return DoHelper(source, onNext, onError, () => { });
+            return DoHelper(source, onNext, onError, null);
         }
 
         public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
@@ -72,52 +73,84 @@ namespace System.Linq
 
         private static IAsyncEnumerable<TSource> DoHelper<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
         {
-            return CreateEnumerable(
-                () =>
+            return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
+        }
+
+        private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Action onCompleted;
+            private readonly Action<Exception> onError;
+            private readonly Action<TSource> onNext;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(onNext != null);
+
+                this.source = source;
+                this.onNext = onNext;
+                this.onError = onError;
+                this.onCompleted = onCompleted;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
                 {
-                    var e = source.GetEnumerator();
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
+                base.Dispose();
+            }
 
-                    var current = default(TSource);
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+                    case AsyncIteratorState.Iterating:
+                        try
                         {
-                            try
-                            {
-                                var result = await e.MoveNext(ct)
-                                                    .ConfigureAwait(false);
-                                if (!result)
-                                {
-                                    onCompleted();
-                                }
-                                else
-                                {
-                                    current = e.Current;
-                                    onNext(current);
-                                }
-                                return result;
-                            }
-                            catch (OperationCanceledException)
-                            {
-                                throw;
-                            }
-                            catch (Exception ex)
+                            if (await enumerator.MoveNext(cancellationToken)
+                                                .ConfigureAwait(false))
                             {
-                                onError(ex);
-                                throw;
+                                current = enumerator.Current;
+                                onNext(current);
+
+                                return true;
                             }
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+                        }
+                        catch (OperationCanceledException)
+                        {
+                            throw;
+                        }
+                        catch (Exception ex)
+                        {
+                            onError?.Invoke(ex);
+                            throw;
+                        }
+
+                        onCompleted?.Invoke();
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
         }
     }
 }