Oren Novotny 9 years ago
parent
commit
6f5173469c

+ 92 - 27
Ix.NET/Source/System.Interactive.Async/Except.cs

@@ -31,40 +31,105 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException(nameof(comparer));
 
-            return CreateEnumerable(
-                () =>
+            return new ExceptAsyncIterator<TSource>(first, second, comparer);
+        }
+
+        private sealed class ExceptAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly IEqualityComparer<TSource> comparer;
+            private readonly IAsyncEnumerable<TSource> first;
+            private readonly IAsyncEnumerable<TSource> second;
+
+            private Task fillSetTask;
+
+            private IAsyncEnumerator<TSource> firstEnumerator;
+            private Set<TSource> set;
+
+            private bool setFilled;
+
+            public ExceptAsyncIterator(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
+            {
+                this.first = first;
+                this.second = second;
+                this.comparer = comparer;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new ExceptAsyncIterator<TSource>(first, second, comparer);
+            }
+
+            public override void Dispose()
+            {
+                if (firstEnumerator != null)
+                {
+                    firstEnumerator.Dispose();
+                    firstEnumerator = null;
+                }
+
+                set = null;
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
                 {
-                    var e = first.GetEnumerator();
+                    case AsyncIteratorState.Allocated:
+                        firstEnumerator = first.GetEnumerator();
+                        set = new Set<TSource>(comparer);
+                        setFilled = false;
+                        fillSetTask = FillSet(cancellationToken);
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
 
-                    var mapTask = default(Task<Dictionary<TSource, TSource>>);
-                    var getMapTask = new Func<CancellationToken, Task<Dictionary<TSource, TSource>>>(
-                        ct => mapTask ?? (mapTask = second.ToDictionary(x => x, comparer, ct)));
+                    case AsyncIteratorState.Iterating:
+
+                        bool moveNext;
+                        if (!setFilled)
+                        {
+                            // This is here so we don't need to call Task.WhenAll each time after the set is filled
+                            var moveNextTask = firstEnumerator.MoveNext(cancellationToken);
+                            await Task.WhenAll(moveNextTask, fillSetTask)
+                                      .ConfigureAwait(false);
+                            setFilled = true;
+                            moveNext = moveNextTask.Result;
+                        }
+                        else
+                        {
+                            moveNext = await firstEnumerator.MoveNext(cancellationToken)
+                                                            .ConfigureAwait(false);
+                        }
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+                        if (moveNext)
                         {
-                            if (await e.MoveNext(ct)
-                                       .Zip(getMapTask(ct), (b, _) => b)
-                                       .ConfigureAwait(false))
+                            var item = firstEnumerator.Current;
+                            if (set.Add(item))
                             {
-                                if (!mapTask.Result.ContainsKey(e.Current))
-                                    return true;
-                                return await f(ct)
-                                           .ConfigureAwait(false);
+                                current = item;
+                                return true;
                             }
-                            return false;
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => e.Current,
-                        d.Dispose,
-                        e
-                    );
-                });
+                            goto case AsyncIteratorState.Iterating; // loop
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
+
+            private async Task FillSet(CancellationToken cancellationToken)
+            {
+                var array = await second.ToArray(cancellationToken)
+                                        .ConfigureAwait(false);
+                for (var i = 0; i < array.Length; i++)
+                {
+                    set.Add(array[i]);
+                }
+            }
         }
     }
 }

+ 0 - 43
Ix.NET/Source/System.Interactive.Async/TaskExt.cs

@@ -16,48 +16,5 @@ namespace System.Threading.Tasks
             tcs.TrySetException(exception);
             return tcs.Task;
         }
-
-        public static Task<V> Zip<T, U, V>(this Task<T> t1, Task<U> t2, Func<T, U, V> f)
-        {
-            var tcs = new TaskCompletionSource<V>();
-
-            var i = 2;
-            var complete = new Action<Task>(t =>
-            {
-                if (Interlocked.Decrement(ref i) == 0)
-                {
-                    var exs = new List<Exception>();
-                    if (t1.IsFaulted)
-                        exs.Add(t1.Exception);
-                    if (t2.IsFaulted)
-                        exs.Add(t2.Exception);
-
-                    if (exs.Count > 0)
-                        tcs.TrySetException(exs);
-                    else if (t1.IsCanceled || t2.IsCanceled)
-                        tcs.TrySetCanceled();
-                    else
-                    {
-                        var res = default(V);
-                        try
-                        {
-                            res = f(t1.Result, t2.Result);
-                        }
-                        catch (Exception ex)
-                        {
-                            tcs.TrySetException(ex);
-                            return;
-                        }
-
-                        tcs.TrySetResult(res);
-                    }
-                }
-            });
-
-            t1.ContinueWith(complete);
-            t2.ContinueWith(complete);
-
-            return tcs.Task;
-        }
     }
 }