Răsfoiți Sursa

Simplify GroupJoin based on CoreFX impl

Oren Novotny 9 ani în urmă
părinte
comite
e5693340be

+ 10 - 44
Ix.NET/Source/System.Interactive.Async/AsyncEnumerable.Multiple.cs

@@ -405,7 +405,7 @@ namespace System.Linq
             public IAsyncEnumerator<TResult> GetEnumerator()
                 => new GroupJoinAsyncEnumerator(
                     _outer.GetEnumerator(),
-                    _inner.GetEnumerator(),
+                    _inner,
                     _outerKeySelector,
                     _innerKeySelector,
                     _resultSelector,
@@ -414,17 +414,17 @@ namespace System.Linq
             private sealed class GroupJoinAsyncEnumerator : IAsyncEnumerator<TResult>
             {
                 private readonly IAsyncEnumerator<TOuter> _outer;
-                private readonly IAsyncEnumerator<TInner> _inner;
+                private readonly IAsyncEnumerable<TInner> _inner;
                 private readonly Func<TOuter, TKey> _outerKeySelector;
                 private readonly Func<TInner, TKey> _innerKeySelector;
                 private readonly Func<TOuter, IAsyncEnumerable<TInner>, TResult> _resultSelector;
                 private readonly IEqualityComparer<TKey> _comparer;
 
-                private Dictionary<TKey, List<TInner>> _innerGroups;
+                private Internal.Lookup<TKey, TInner> _lookup;
 
                 public GroupJoinAsyncEnumerator(
                     IAsyncEnumerator<TOuter> outer,
-                    IAsyncEnumerator<TInner> inner,
+                    IAsyncEnumerable<TInner> inner,
                     Func<TOuter, TKey> outerKeySelector,
                     Func<TInner, TKey> innerKeySelector,
                     Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector,
@@ -440,52 +440,19 @@ namespace System.Linq
 
                 public async Task<bool> MoveNext(CancellationToken cancellationToken)
                 {
-                    List<TInner> group;
-
+                    // nothing to do 
                     if (!await _outer.MoveNext(cancellationToken).ConfigureAwait(false))
                     {
                         return false;
                     }
 
-                    if (_innerGroups == null)
+                    if (_lookup == null)
                     {
-                        _innerGroups = new Dictionary<TKey, List<TInner>>(_comparer);
-
-                        while (await _inner.MoveNext(cancellationToken).ConfigureAwait(false))
-                        {
-                            var inner = _inner.Current;
-                            var innerKey = _innerKeySelector(inner);
-
-                            if (innerKey != null)
-                            {
-                                if (!_innerGroups.TryGetValue(innerKey, out group))
-                                {
-                                    _innerGroups.Add(innerKey, group = new List<TInner>());
-                                }
-
-                                group.Add(inner);
-                            }
-                        }
+                        _lookup = await Internal.Lookup<TKey, TInner>.CreateForJoinAsync(_inner, _innerKeySelector, _comparer, cancellationToken).ConfigureAwait(false);
                     }
-
-                    var outer = _outer.Current;
-                    var outerKey = _outerKeySelector(outer);
-
-                    Current
-                        = _resultSelector(
-                            outer,
-                            new AsyncEnumerableAdapter<TInner>(
-                                outerKey != null
-                                && _innerGroups.TryGetValue(outerKey, out group)
-                                    ? (IEnumerable<TInner>)group
-                                    :
-#if NO_ARRAY_EMPTY
-                                    EmptyArray<TInner>.Value
-#else
-                                    Array.Empty<TInner>()
-#endif
-                                    ));
-
+                    
+                    var item = _outer.Current;
+                    Current = _resultSelector(item, new AsyncEnumerableAdapter<TInner>(_lookup[_outerKeySelector(item)]));
                     return true;
                 }
 
@@ -493,7 +460,6 @@ namespace System.Linq
 
                 public void Dispose()
                 {
-                    _inner.Dispose();
                     _outer.Dispose();
                 }
 

+ 131 - 0
Ix.NET/Source/System.Interactive.Async/Grouping.cs

@@ -0,0 +1,131 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+
+namespace System.Linq.Internal
+{
+    internal class Grouping<TKey, TElement> : IGrouping<TKey, TElement>, IList<TElement>
+    {
+        internal TKey _key;
+        internal int _hashCode;
+        internal TElement[] _elements;
+        internal int _count;
+        internal Grouping<TKey, TElement> _hashNext;
+        internal Grouping<TKey, TElement> _next;
+
+        internal Grouping()
+        {
+        }
+
+        internal void Add(TElement element)
+        {
+            if (_elements.Length == _count)
+            {
+                Array.Resize(ref _elements, checked(_count * 2));
+            }
+
+            _elements[_count] = element;
+            _count++;
+        }
+
+        internal void Trim()
+        {
+            if (_elements.Length != _count)
+            {
+                Array.Resize(ref _elements, _count);
+            }
+        }
+
+        public IEnumerator<TElement> GetEnumerator()
+        {
+            for (int i = 0; i < _count; i++)
+            {
+                yield return _elements[i];
+            }
+        }
+
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return GetEnumerator();
+        }
+
+        // DDB195907: implement IGrouping<>.Key implicitly
+        // so that WPF binding works on this property.
+        public TKey Key
+        {
+            get { return _key; }
+        }
+
+        int ICollection<TElement>.Count
+        {
+            get { return _count; }
+        }
+
+        bool ICollection<TElement>.IsReadOnly
+        {
+            get { return true; }
+        }
+
+        void ICollection<TElement>.Add(TElement item)
+        {
+            throw new NotSupportedException(Strings.NOT_SUPPORTED);
+        }
+
+        void ICollection<TElement>.Clear()
+        {
+            throw new NotSupportedException(Strings.NOT_SUPPORTED);
+        }
+
+        bool ICollection<TElement>.Contains(TElement item)
+        {
+            return Array.IndexOf(_elements, item, 0, _count) >= 0;
+        }
+
+        void ICollection<TElement>.CopyTo(TElement[] array, int arrayIndex)
+        {
+            Array.Copy(_elements, 0, array, arrayIndex, _count);
+        }
+
+        bool ICollection<TElement>.Remove(TElement item)
+        {
+            throw new NotSupportedException(Strings.NOT_SUPPORTED);
+        }
+
+        int IList<TElement>.IndexOf(TElement item)
+        {
+            return Array.IndexOf(_elements, item, 0, _count);
+        }
+
+        void IList<TElement>.Insert(int index, TElement item)
+        {
+            throw new NotSupportedException(Strings.NOT_SUPPORTED);
+        }
+
+        void IList<TElement>.RemoveAt(int index)
+        {
+            throw new NotSupportedException(Strings.NOT_SUPPORTED);
+        }
+
+        TElement IList<TElement>.this[int index]
+        {
+            get
+            {
+                if (index < 0 || index >= _count)
+                {
+                    throw new ArgumentOutOfRangeException(nameof(index));
+                }
+
+                return _elements[index];
+            }
+
+            set
+            {
+                throw new NotSupportedException(Strings.NOT_SUPPORTED);
+            }
+        }
+    }
+
+
+}

+ 255 - 0
Ix.NET/Source/System.Interactive.Async/Lookup.cs

@@ -0,0 +1,255 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Linq.Internal
+{
+    internal class Lookup<TKey, TElement> : ILookup<TKey, TElement>
+    {
+        private readonly IEqualityComparer<TKey> _comparer;
+        private Grouping<TKey, TElement>[] _groupings;
+        private Grouping<TKey, TElement> _lastGrouping;
+        private int _count;
+
+        internal static Lookup<TKey, TElement> Create<TSource>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)
+        {
+            Debug.Assert(source != null);
+            Debug.Assert(keySelector != null);
+            Debug.Assert(elementSelector != null);
+
+            Lookup<TKey, TElement> lookup = new Lookup<TKey, TElement>(comparer);
+            foreach (TSource item in source)
+            {
+                lookup.GetGrouping(keySelector(item), create: true).Add(elementSelector(item));
+            }
+
+            return lookup;
+        }
+
+        internal static Lookup<TKey, TElement> Create(IEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey> comparer)
+        {
+            Debug.Assert(source != null);
+            Debug.Assert(keySelector != null);
+
+            Lookup<TKey, TElement> lookup = new Lookup<TKey, TElement>(comparer);
+            foreach (TElement item in source)
+            {
+                lookup.GetGrouping(keySelector(item), create: true).Add(item);
+            }
+
+            return lookup;
+        }
+
+        internal static Lookup<TKey, TElement> CreateForJoin(IEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey> comparer)
+        {
+            Lookup<TKey, TElement> lookup = new Lookup<TKey, TElement>(comparer);
+            foreach (TElement item in source)
+            {
+                TKey key = keySelector(item);
+                if (key != null)
+                {
+                    lookup.GetGrouping(key, create: true).Add(item);
+                }
+            }
+
+            return lookup;
+        }
+
+        internal static async Task<Lookup<TKey, TElement>> CreateForJoinAsync(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
+        {
+            Lookup<TKey, TElement> lookup = new Lookup<TKey, TElement>(comparer);
+            using (var enu = source.GetEnumerator())
+            {
+                while (await enu.MoveNext(cancellationToken)
+                                .ConfigureAwait(false))
+                {
+                    TKey key = keySelector(enu.Current);
+                    if (key != null)
+                    {
+                        lookup.GetGrouping(key, create: true).Add(enu.Current);
+                    }
+                }
+            }
+
+            return lookup;
+        }
+
+        private Lookup(IEqualityComparer<TKey> comparer)
+        {
+            _comparer = comparer ?? EqualityComparer<TKey>.Default;
+            _groupings = new Grouping<TKey, TElement>[7];
+        }
+
+        public int Count
+        {
+            get { return _count; }
+        }
+
+        public IEnumerable<TElement> this[TKey key]
+        {
+            get
+            {
+                Grouping<TKey, TElement> grouping = GetGrouping(key, create: false);
+                if (grouping != null)
+                {
+                    return grouping;
+                }
+
+#if NO_ARRAY_EMPTY
+                return EmptyArray<TElement>.Value;
+#else
+                return Array.Empty<TElement>();
+#endif
+            }
+        }
+
+        public bool Contains(TKey key)
+        {
+            return GetGrouping(key, create: false) != null;
+        }
+
+        public IEnumerator<IGrouping<TKey, TElement>> GetEnumerator()
+        {
+            Grouping<TKey, TElement> g = _lastGrouping;
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next;
+                    yield return g;
+                }
+                while (g != _lastGrouping);
+            }
+        }
+
+        internal TResult[] ToArray<TResult>(Func<TKey, IEnumerable<TElement>, TResult> resultSelector)
+        {
+            TResult[] array = new TResult[_count];
+            int index = 0;
+            Grouping<TKey, TElement> g = _lastGrouping;
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next;
+                    g.Trim();
+                    array[index] = resultSelector(g._key, g._elements);
+                    ++index;
+                }
+                while (g != _lastGrouping);
+            }
+
+            return array;
+        }
+
+
+        internal List<TResult> ToList<TResult>(Func<TKey, IEnumerable<TElement>, TResult> resultSelector)
+        {
+            List<TResult> list = new List<TResult>(_count);
+            Grouping<TKey, TElement> g = _lastGrouping;
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next;
+                    g.Trim();
+                    list.Add(resultSelector(g._key, g._elements));
+                }
+                while (g != _lastGrouping);
+            }
+
+            return list;
+        }
+
+        public IEnumerable<TResult> ApplyResultSelector<TResult>(Func<TKey, IEnumerable<TElement>, TResult> resultSelector)
+        {
+            Grouping<TKey, TElement> g = _lastGrouping;
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next;
+                    g.Trim();
+                    yield return resultSelector(g._key, g._elements);
+                }
+                while (g != _lastGrouping);
+            }
+        }
+
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return GetEnumerator();
+        }
+
+        internal int InternalGetHashCode(TKey key)
+        {
+            // Handle comparer implementations that throw when passed null
+            return (key == null) ? 0 : _comparer.GetHashCode(key) & 0x7FFFFFFF;
+        }
+
+        internal Grouping<TKey, TElement> GetGrouping(TKey key, bool create)
+        {
+            int hashCode = InternalGetHashCode(key);
+            for (Grouping<TKey, TElement> g = _groupings[hashCode % _groupings.Length]; g != null; g = g._hashNext)
+            {
+                if (g._hashCode == hashCode && _comparer.Equals(g._key, key))
+                {
+                    return g;
+                }
+            }
+
+            if (create)
+            {
+                if (_count == _groupings.Length)
+                {
+                    Resize();
+                }
+
+                int index = hashCode % _groupings.Length;
+                Grouping<TKey, TElement> g = new Grouping<TKey, TElement>();
+                g._key = key;
+                g._hashCode = hashCode;
+                g._elements = new TElement[1];
+                g._hashNext = _groupings[index];
+                _groupings[index] = g;
+                if (_lastGrouping == null)
+                {
+                    g._next = g;
+                }
+                else
+                {
+                    g._next = _lastGrouping._next;
+                    _lastGrouping._next = g;
+                }
+
+                _lastGrouping = g;
+                _count++;
+                return g;
+            }
+
+            return null;
+        }
+
+        private void Resize()
+        {
+            int newSize = checked((_count * 2) + 1);
+            Grouping<TKey, TElement>[] newGroupings = new Grouping<TKey, TElement>[newSize];
+            Grouping<TKey, TElement> g = _lastGrouping;
+            do
+            {
+                g = g._next;
+                int index = g._hashCode % newSize;
+                g._hashNext = newGroupings[index];
+                newGroupings[index] = g;
+            }
+            while (g != _lastGrouping);
+
+            _groupings = newGroupings;
+        }
+    }
+
+}

+ 1 - 0
Ix.NET/Source/System.Interactive.Async/Strings.cs

@@ -7,5 +7,6 @@ namespace System.Linq
     {
         public static string NO_ELEMENTS = "Source sequence doesn't contain any elements.";
         public static string MORE_THAN_ONE_ELEMENT = "Source sequence contains more than one element.";
+        public static string NOT_SUPPORTED = "NOT SUPPORTED";
     }
 }