فهرست منبع

Optimize join

Oren Novotny 9 سال پیش
والد
کامیت
c537de2ee2
1فایلهای تغییر یافته به همراه125 افزوده شده و 123 حذف شده
  1. 125 123
      Ix.NET/Source/System.Interactive.Async/Join.cs

+ 125 - 123
Ix.NET/Source/System.Interactive.Async/Join.cs

@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
@@ -27,149 +28,150 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException(nameof(comparer));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var oe = outer.GetEnumerator();
-                    var ie = inner.GetEnumerator();
+            return new JoinAsyncIterator<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
+        }
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, oe, ie);
+        public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector)
+        {
+            if (outer == null)
+                throw new ArgumentNullException(nameof(outer));
+            if (inner == null)
+                throw new ArgumentNullException(nameof(inner));
+            if (outerKeySelector == null)
+                throw new ArgumentNullException(nameof(outerKeySelector));
+            if (innerKeySelector == null)
+                throw new ArgumentNullException(nameof(innerKeySelector));
+            if (resultSelector == null)
+                throw new ArgumentNullException(nameof(resultSelector));
 
-                    var current = default(TResult);
-                    var useOuter = true;
-                    var outerMap = new Dictionary<TKey, List<TOuter>>(comparer);
-                    var innerMap = new Dictionary<TKey, List<TInner>>(comparer);
-                    var q = new Queue<TResult>();
+            return new JoinAsyncIterator<TOuter,TInner,TKey,TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
+        }
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+        internal sealed class JoinAsyncIterator<TOuter, TInner, TKey, TResult> : AsyncIterator<TResult>
+        {
+            private readonly IAsyncEnumerable<TOuter> outer;
+            private readonly IAsyncEnumerable<TInner> inner;
+            private readonly Func<TOuter, TKey> outerKeySelector;
+            private readonly Func<TInner, TKey> innerKeySelector;
+            private readonly Func<TOuter, TInner, TResult> resultSelector;
+            private readonly IEqualityComparer<TKey> comparer;
+
+            private IAsyncEnumerator<TOuter> outerEnumerator;
+            private IAsyncEnumerator<TInner> innerEnumerator;
+            private Mode mode;
+
+            public JoinAsyncIterator(IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector, IEqualityComparer<TKey> comparer)
+            {
+                Debug.Assert(outer != null);
+                Debug.Assert(inner != null);
+                Debug.Assert(outerKeySelector != null);
+                Debug.Assert(innerKeySelector != null);
+                Debug.Assert(resultSelector != null);
+
+                this.outer = outer;
+                this.inner = inner;
+                this.outerKeySelector = outerKeySelector;
+                this.innerKeySelector = innerKeySelector;
+                this.resultSelector = resultSelector;
+                this.comparer = comparer;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new JoinAsyncIterator<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
+            }
+
+            public override void Dispose()
+            {
+                if (outerEnumerator != null)
+                {
+                    outerEnumerator.Dispose();
+                    outerEnumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            private enum Mode
+            {
+                Begin,
+                DoLoop,
+                For,
+                While,
+            }
+
+            // State machine vars
+            Internal.Lookup<TKey, TInner> lookup;
+            int count;
+            TInner[] elements;
+            int index;
+            TOuter item;
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case State.Allocated:
+                        outerEnumerator = outer.GetEnumerator();
+                        mode = Mode.Begin;
+                        state = State.Iterating;
+                        goto case State.Iterating;
+
+                    case State.Iterating:
+                        switch (mode)
                         {
-                            if (q.Count > 0)
-                            {
-                                current = q.Dequeue();
-                                return true;
-                            }
-
-                            var b = useOuter;
-                            if (ie == null && oe == null)
-                            {
-                                return false;
-                            }
-                            if (ie == null)
-                                b = true;
-                            else if (oe == null)
-                                b = false;
-                            useOuter = !useOuter;
-
-                            var enqueue = new Func<TOuter, TInner, bool>(
-                                (o, i) =>
-                                {
-                                    var result = resultSelector(o, i);
-                                    q.Enqueue(result);
-                                    return true;
-                                });
-
-                            if (b)
-                            {
-                                if (await oe.MoveNext(ct)
-                                            .ConfigureAwait(false))
+                            case Mode.Begin:
+                                if (await outerEnumerator.MoveNext(cancellationToken)
+                                                         .ConfigureAwait(false))
                                 {
-                                    var element = oe.Current;
-                                    var key = default(TKey);
-
-                                    key = outerKeySelector(element);
-
-                                    var outerList = default(List<TOuter>);
-                                    if (!outerMap.TryGetValue(key, out outerList))
+                                    lookup = await Internal.Lookup<TKey, TInner>.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken).ConfigureAwait(false);
+                                    if (lookup.Count != 0)
                                     {
-                                        outerList = new List<TOuter>();
-                                        outerMap.Add(key, outerList);
+                                        mode = Mode.DoLoop;
+                                        goto case Mode.DoLoop;   
                                     }
-
-                                    outerList.Add(element);
-
-                                    var innerList = default(List<TInner>);
-                                    if (!innerMap.TryGetValue(key, out innerList))
-                                    {
-                                        innerList = new List<TInner>();
-                                        innerMap.Add(key, innerList);
-                                    }
-
-                                    foreach (var v in innerList)
-                                    {
-                                        if (!enqueue(element, v))
-                                            return false;
-                                    }
-
-                                    return await f(ct)
-                                               .ConfigureAwait(false);
                                 }
-                                oe.Dispose();
-                                oe = null;
-                                return await f(ct)
-                                           .ConfigureAwait(false);
-                            }
-                            if (await ie.MoveNext(ct)
-                                        .ConfigureAwait(false))
-                            {
-                                var element = ie.Current;
-                                var key = innerKeySelector(element);
-
-                                var innerList = default(List<TInner>);
-                                if (!innerMap.TryGetValue(key, out innerList))
+
+                                break;
+                            case Mode.DoLoop:
+                                item = outerEnumerator.Current;
+                                var g = lookup.GetGrouping(outerKeySelector(item), create: false);
+                                if (g != null)
                                 {
-                                    innerList = new List<TInner>();
-                                    innerMap.Add(key, innerList);
+                                    count = g._count;
+                                    elements = g._elements;
+                                    index = 0;
+                                    mode = Mode.For;
+                                    goto case Mode.For;
                                 }
 
-                                innerList.Add(element);
+                                break;
 
-                                var outerList = default(List<TOuter>);
-                                if (!outerMap.TryGetValue(key, out outerList))
+                            case Mode.For:
+                                current = resultSelector(item, elements[index]);
+                                index++;
+                                if (index == count)
                                 {
-                                    outerList = new List<TOuter>();
-                                    outerMap.Add(key, outerList);
+                                    mode = Mode.While;
                                 }
+                                return true;
 
-                                foreach (var v in outerList)
+                            case Mode.While:
+                                var hasNext = await outerEnumerator.MoveNext(cancellationToken).ConfigureAwait(false);
+                                if (hasNext)
                                 {
-                                    if (!enqueue(v, element))
-                                        return false;
+                                    goto case Mode.DoLoop;
                                 }
 
-                                return await f(ct)
-                                           .ConfigureAwait(false);
-                            }
-                            ie.Dispose();
-                            ie = null;
-                            return await f(ct)
-                                       .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => current,
-                        d.Dispose,
-                        ie
-                    );
-                });
-        }
+                                Dispose();
+                                break;
+                        }
 
-        public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector)
-        {
-            if (outer == null)
-                throw new ArgumentNullException(nameof(outer));
-            if (inner == null)
-                throw new ArgumentNullException(nameof(inner));
-            if (outerKeySelector == null)
-                throw new ArgumentNullException(nameof(outerKeySelector));
-            if (innerKeySelector == null)
-                throw new ArgumentNullException(nameof(innerKeySelector));
-            if (resultSelector == null)
-                throw new ArgumentNullException(nameof(resultSelector));
+                        break;
+                }
 
-            return outer.Join(inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
+                return false;
+            }
         }
     }
 }