Oren Novotny il y a 9 ans
Parent
commit
f9adb73794
1 fichiers modifiés avec 99 ajouts et 41 suppressions
  1. 99 41
      Ix.NET/Source/System.Interactive.Async/Intersect.cs

+ 99 - 41
Ix.NET/Source/System.Interactive.Async/Intersect.cs

@@ -21,47 +21,7 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException(nameof(comparer));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = first.GetEnumerator();
-
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
-
-                    var mapTask = default(Task<Dictionary<TSource, TSource>>);
-                    var getMapTask = new Func<CancellationToken, Task<Dictionary<TSource, TSource>>>(
-                        ct =>
-                        {
-                            if (mapTask == null)
-                                mapTask = second.ToDictionary(x => x, comparer, ct);
-                            return mapTask;
-                        });
-
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (await e.MoveNext(ct)
-                                       .Zip(getMapTask(ct), (b, _) => b)
-                                       .ConfigureAwait(false))
-                            {
-                                // Note: Result here is safe because the task
-                                // was completed in the Zip() call above
-                                if (mapTask.Result.ContainsKey(e.Current))
-                                    return true;
-                                return await f(ct)
-                                           .ConfigureAwait(false);
-                            }
-                            return false;
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => e.Current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            return new IntersectAsyncIterator<TSource>(first, second, comparer);
         }
 
 
@@ -74,5 +34,103 @@ namespace System.Linq
 
             return first.Intersect(second, EqualityComparer<TSource>.Default);
         }
+
+        private sealed class IntersectAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly IEqualityComparer<TSource> comparer;
+            private readonly IAsyncEnumerable<TSource> first;
+            private readonly IAsyncEnumerable<TSource> second;
+
+            private Task fillSetTask;
+
+            private IAsyncEnumerator<TSource> firstEnumerator;
+            private Set<TSource> set;
+
+            private bool setFilled;
+
+            public IntersectAsyncIterator(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
+            {
+                this.first = first;
+                this.second = second;
+                this.comparer = comparer;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new IntersectAsyncIterator<TSource>(first, second, comparer);
+            }
+
+            public override void Dispose()
+            {
+                if (firstEnumerator != null)
+                {
+                    firstEnumerator.Dispose();
+                    firstEnumerator = null;
+                }
+
+                set = null;
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        firstEnumerator = first.GetEnumerator();
+                        set = new Set<TSource>(comparer);
+                        setFilled = false;
+                        fillSetTask = FillSet(cancellationToken);
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+
+                        bool moveNext;
+                        if (!setFilled)
+                        {
+                            // This is here so we don't need to call Task.WhenAll each time after the set is filled
+                            var moveNextTask = firstEnumerator.MoveNext(cancellationToken);
+                            await Task.WhenAll(moveNextTask, fillSetTask)
+                                      .ConfigureAwait(false);
+                            setFilled = true;
+                            moveNext = moveNextTask.Result;
+                        }
+                        else
+                        {
+                            moveNext = await firstEnumerator.MoveNext(cancellationToken)
+                                                            .ConfigureAwait(false);
+                        }
+
+                        if (moveNext)
+                        {
+                            var item = firstEnumerator.Current;
+                            if (set.Remove(item))
+                            {
+                                current = item;
+                                return true;
+                            }
+                            goto case AsyncIteratorState.Iterating; // loop
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
+
+            private async Task FillSet(CancellationToken cancellationToken)
+            {
+                var array = await second.ToArray(cancellationToken)
+                                        .ConfigureAwait(false);
+                for (var i = 0; i < array.Length; i++)
+                {
+                    set.Add(array[i]);
+                }
+            }
+        }
     }
 }