Oren Novotny 9 years ago
parent
commit
320f7d66f7
1 changed files with 71 additions and 23 deletions
  1. 71 23
      Ix.NET/Source/System.Interactive.Async/Zip.cs

+ 71 - 23
Ix.NET/Source/System.Interactive.Async/Zip.cs

@@ -5,6 +5,7 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Linq
@@ -20,30 +21,77 @@ namespace System.Linq
             if (selector == null)
                 throw new ArgumentNullException(nameof(selector));
 
-            return CreateEnumerable(
-                () =>
+            return new ZipAsyncIterator<TFirst, TSecond, TResult>(first, second, selector);
+        }
+
+        private sealed class ZipAsyncIterator<TFirst, TSecond, TResult> : AsyncIterator<TResult>
+        {
+            private readonly IAsyncEnumerable<TFirst> first;
+            private readonly IAsyncEnumerable<TSecond> second;
+            private readonly Func<TFirst, TSecond, TResult> selector;
+
+            private IAsyncEnumerator<TFirst> firstEnumerator;
+            private IAsyncEnumerator<TSecond> secondEnumerator;
+
+            public ZipAsyncIterator(IAsyncEnumerable<TFirst> first, IAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> selector)
+            {
+                this.first = first;
+                this.second = second;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new ZipAsyncIterator<TFirst, TSecond, TResult>(first, second, selector);
+            }
+
+            public override void Dispose()
+            {
+                if (firstEnumerator != null)
                 {
-                    var e1 = first.GetEnumerator();
-                    var e2 = second.GetEnumerator();
-                    var current = default(TResult);
-
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e1, e2);
-
-                    return CreateEnumerator(
-                        ct => e1.MoveNext(cts.Token)
-                                .Zip(e2.MoveNext(cts.Token),
-                                     (f, s) =>
-                                     {
-                                         var result = f && s;
-                                         if (result)
-                                             current = selector(e1.Current, e2.Current);
-                                         return result;
-                                     }),
-                        () => current,
-                        d.Dispose
-                    );
-                });
+                    firstEnumerator.Dispose();
+                    firstEnumerator = null;
+                }
+                if (secondEnumerator != null)
+                {
+                    secondEnumerator.Dispose();
+                    secondEnumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        firstEnumerator = first.GetEnumerator();
+                        secondEnumerator = second.GetEnumerator();
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+
+                        // We kick these off and join so they can potentially run in parallel
+                        var ft = firstEnumerator.MoveNext(cancellationToken);
+                        var st = secondEnumerator.MoveNext(cancellationToken);
+                        await Task.WhenAll(ft, st)
+                                  .ConfigureAwait(false);
+
+                        if (ft.Result && st.Result)
+                        {
+                            current = selector(firstEnumerator.Current, secondEnumerator.Current);
+                            return true;
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
         }
     }
 }