浏览代码

Optimize grouping

Oren Novotny 9 年之前
父节点
当前提交
26c841fe99

+ 1 - 1
Ix.NET/Source/System.Interactive.Async/GroupJoin.cs

@@ -127,7 +127,7 @@ namespace System.Linq
                     }
 
                     var item = _outer.Current;
-                    Current = _resultSelector(item, new AsyncEnumerableAdapter<TInner>(_lookup[_outerKeySelector(item)]));
+                    Current = _resultSelector(item, _lookup[_outerKeySelector(item)].ToAsyncEnumerable());
                     return true;
                 }
 

+ 148 - 192
Ix.NET/Source/System.Interactive.Async/Grouping.cs

@@ -25,127 +25,7 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException(nameof(comparer));
 
-            return CreateEnumerable(() =>
-                          {
-                              var gate = new object();
-
-                              var e = source.GetEnumerator();
-                              var count = 1;
-
-                              var map = new Dictionary<TKey, AsyncGrouping<TKey, TElement>>(comparer);
-                              var list = new List<IAsyncGrouping<TKey, TElement>>();
-
-                              var index = 0;
-
-                              var current = default(IAsyncGrouping<TKey, TElement>);
-                              var faulted = default(ExceptionDispatchInfo);
-
-                              var res = default(bool?);
-
-                              var cts = new CancellationTokenDisposable();
-                              var refCount = new Disposable(
-                                  () =>
-                                  {
-                                      if (Interlocked.Decrement(ref count) == 0)
-                                          e.Dispose();
-                                  }
-                              );
-                              var d = Disposable.Create(cts, refCount);
-
-                              var iterateSource = default(Func<CancellationToken, Task<bool>>);
-                              iterateSource = async ct =>
-                                              {
-                                                  lock (gate)
-                                                  {
-                                                      if (res != null)
-                                                      {
-                                                          return res.Value;
-                                                      }
-                                                      res = null;
-                                                  }
-
-                                                  faulted?.Throw();
-
-                                                  try
-                                                  {
-                                                      res = await e.MoveNext(ct)
-                                                                   .ConfigureAwait(false);
-                                                      if (res == true)
-                                                      {
-                                                          var key = default(TKey);
-                                                          var element = default(TElement);
-
-                                                          var cur = e.Current;
-                                                          try
-                                                          {
-                                                              key = keySelector(cur);
-                                                              element = elementSelector(cur);
-                                                          }
-                                                          catch (Exception exception)
-                                                          {
-                                                              foreach (var v in map.Values)
-                                                                  v.Error(exception);
-
-                                                              throw;
-                                                          }
-
-                                                          var group = default(AsyncGrouping<TKey, TElement>);
-                                                          if (!map.TryGetValue(key, out group))
-                                                          {
-                                                              group = new AsyncGrouping<TKey, TElement>(key, iterateSource, refCount);
-                                                              map.Add(key, group);
-                                                              lock (list)
-                                                                  list.Add(group);
-
-                                                              Interlocked.Increment(ref count);
-                                                          }
-                                                          group.Add(element);
-                                                      }
-
-                                                      return res.Value;
-                                                  }
-                                                  catch (Exception ex)
-                                                  {
-                                                      foreach (var v in map.Values)
-                                                          v.Error(ex);
-
-                                                      faulted = ExceptionDispatchInfo.Capture(ex);
-                                                      throw;
-                                                  }
-                                                  finally
-                                                  {
-                                                      res = null;
-                                                  }
-                                              };
-
-                              var f = default(Func<CancellationToken, Task<bool>>);
-                              f = async ct =>
-                                  {
-                                      var result = await iterateSource(ct)
-                                                       .ConfigureAwait(false);
-
-                                      current = null;
-                                      lock (list)
-                                      {
-                                          if (index < list.Count)
-                                              current = list[index++];
-                                      }
-
-                                      if (current != null)
-                                      {
-                                          return true;
-                                      }
-                                      return result && await f(ct)
-                                                 .ConfigureAwait(false);
-                                  };
-
-                              return CreateEnumerator(
-                                  f,
-                                  () => current,
-                                  d.Dispose,
-                                  e
-                              );
-                          });
+            return new GroupedAsyncEnumerable<TSource, TKey, TElement>(source, keySelector, elementSelector, comparer);
         }
 
         public static IAsyncEnumerable<IAsyncGrouping<TKey, TElement>> GroupBy<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector)
@@ -225,8 +105,7 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException(nameof(comparer));
 
-            return source.GroupBy(keySelector, x => x, comparer)
-                         .Select(g => resultSelector(g.Key, g));
+            return new GroupedResultAsyncEnumerable<TSource, TKey, TResult>(source, keySelector, resultSelector, comparer);
         }
 
         public static IAsyncEnumerable<TResult> GroupBy<TSource, TKey, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, IAsyncEnumerable<TSource>, TResult> resultSelector)
@@ -238,8 +117,7 @@ namespace System.Linq
             if (resultSelector == null)
                 throw new ArgumentNullException(nameof(resultSelector));
 
-            return source.GroupBy(keySelector, x => x, EqualityComparer<TKey>.Default)
-                         .Select(g => resultSelector(g.Key, g));
+            return GroupBy(source, keySelector, resultSelector, EqualityComparer<TKey>.Default);
         }
 
         private static IEnumerable<IGrouping<TKey, TElement>> GroupUntil<TSource, TKey, TElement>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IComparer<TKey> comparer)
@@ -257,27 +135,30 @@ namespace System.Linq
             }
         }
 
-        internal sealed class GroupedAsyncEnumerable<TSource, TKey> : IIListProvider<IAsyncGrouping<TKey, TSource>>
+        internal sealed class GroupedResultAsyncEnumerable<TSource, TKey, TResult> : IIListProvider<TResult>
         {
             private readonly IAsyncEnumerable<TSource> source;
             private readonly Func<TSource, TKey> keySelector;
+            private readonly Func<TKey, IAsyncEnumerable<TSource>, TResult> resultSelector;
             private readonly IEqualityComparer<TKey> comparer;
 
-            public GroupedAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
+            public GroupedResultAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, IAsyncEnumerable<TSource>, TResult> resultSelector, IEqualityComparer<TKey> comparer)
             {
                 if (source == null) throw new ArgumentNullException(nameof(source));
                 if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));
+                if (resultSelector == null) throw new ArgumentNullException(nameof(resultSelector));
 
                 this.source = source;
                 this.keySelector = keySelector;
+                this.resultSelector = resultSelector;
                 this.comparer = comparer;
             }
 
 
-            public IAsyncEnumerator<IAsyncGrouping<TKey, TSource>> GetEnumerator()
+            public IAsyncEnumerator<TResult> GetEnumerator()
             {
                 Internal.Lookup<TKey, TSource> lookup = null;
-                IEnumerator<IGrouping<TKey, TSource>> enumerator = null;
+                IEnumerator<TResult> enumerator = null;
 
                 return CreateEnumerator(
                     async ct =>
@@ -285,7 +166,7 @@ namespace System.Linq
                         if (lookup == null)
                         {
                             lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, ct).ConfigureAwait(false);
-                            enumerator = lookup.GetEnumerator();
+                            enumerator = lookup.ApplyResultSelector(resultSelector).GetEnumerator();
                         }
 
                         // By the time we get here, the lookup is sync
@@ -294,27 +175,27 @@ namespace System.Linq
 
                         return enumerator?.MoveNext() ?? false;
                     },
-                    () => (IAsyncGrouping<TKey, TSource>)enumerator?.Current,
+                    () => enumerator.Current,
                     () =>
+                    {
+                        if (enumerator != null)
                         {
-                            if (enumerator != null)
-                            {
-                                enumerator.Dispose();
-                                enumerator = null;
-                            }
-                        });
+                            enumerator.Dispose();
+                            enumerator = null;
+                        }
+                    });
             }
 
-            public async Task<IAsyncGrouping<TKey, TSource>[]> ToArrayAsync(CancellationToken cancellationToken)
+            public async Task<TResult[]> ToArrayAsync(CancellationToken cancellationToken)
             {
-                IIListProvider<IAsyncGrouping<TKey, TSource>> lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
-                return await lookup.ToArrayAsync(cancellationToken).ConfigureAwait(false);
+                var lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
+                return lookup.ToArray(resultSelector);
             }
 
-            public async Task<List<IAsyncGrouping<TKey, TSource>>> ToListAsync(CancellationToken cancellationToken)
+            public async Task<List<TResult>> ToListAsync(CancellationToken cancellationToken)
             {
-                IIListProvider<IAsyncGrouping<TKey, TSource>> lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
-                return await lookup.ToListAsync(cancellationToken).ConfigureAwait(false);
+                var lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
+                return lookup.ToList(resultSelector);
             }
 
             public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
@@ -330,79 +211,155 @@ namespace System.Linq
             }
         }
 
-        private class AsyncGrouping<TKey, TElement> : IAsyncGrouping<TKey, TElement>
+        internal sealed class GroupedAsyncEnumerable<TSource, TKey, TElement> : IIListProvider<IAsyncGrouping<TKey, TElement>>
         {
-            private readonly List<TElement> elements = new List<TElement>();
-            private readonly Func<CancellationToken, Task<bool>> iterateSource;
-            private readonly IDisposable sourceDisposable;
-            private bool done;
-            private ExceptionDispatchInfo exception;
+            private readonly IAsyncEnumerable<TSource> source;
+            private readonly Func<TSource, TKey> keySelector;
+            private readonly Func<TSource, TElement> elementSelector;
+            private readonly IEqualityComparer<TKey> comparer;
 
-            public AsyncGrouping(TKey key, Func<CancellationToken, Task<bool>> iterateSource, IDisposable sourceDisposable)
+            public GroupedAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)
             {
-                this.iterateSource = iterateSource;
-                this.sourceDisposable = sourceDisposable;
-                Key = key;
+                if (source == null) throw new ArgumentNullException(nameof(source));
+                if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));
+                if (elementSelector == null) throw new ArgumentNullException(nameof(elementSelector));
+
+                this.source = source;
+                this.keySelector = keySelector;
+                this.elementSelector = elementSelector;
+                this.comparer = comparer;
             }
 
-            public TKey Key { get; }
 
-            public IAsyncEnumerator<TElement> GetEnumerator()
+            public IAsyncEnumerator<IAsyncGrouping<TKey, TElement>> GetEnumerator()
             {
-                var index = -1;
+                Internal.Lookup<TKey, TElement> lookup = null;
+                IEnumerator<IGrouping<TKey, TElement>> enumerator = null;
 
-                var cts = new CancellationTokenDisposable();
-                var d = Disposable.Create(cts, sourceDisposable);
-
-                var f = default(Func<CancellationToken, Task<bool>>);
-                f = async ct =>
+                return CreateEnumerator(
+                    async ct =>
                     {
-                        var size = 0;
-                        lock (elements)
-                            size = elements.Count;
-
-                        if (index < size)
+                        if (lookup == null)
                         {
-                            return true;
+                            lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, ct).ConfigureAwait(false);
+                            enumerator = lookup.GetEnumerator();
                         }
-                        if (done)
-                        {
-                            exception?.Throw();
+
+                        // By the time we get here, the lookup is sync
+                        if (ct.IsCancellationRequested)
                             return false;
-                        }
-                        if (await iterateSource(ct)
-                                .ConfigureAwait(false))
+
+                        return enumerator?.MoveNext() ?? false;
+                    },
+                    () => (IAsyncGrouping<TKey, TElement>)enumerator?.Current,
+                    () =>
+                    {
+                        if (enumerator != null)
                         {
-                            return await f(ct)
-                                       .ConfigureAwait(false);
+                            enumerator.Dispose();
+                            enumerator = null;
                         }
-                        return false;
-                    };
+                    });
+            }
+
+            public async Task<IAsyncGrouping<TKey, TElement>[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                IIListProvider<IAsyncGrouping<TKey, TElement>> lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false);
+                return await lookup.ToArrayAsync(cancellationToken).ConfigureAwait(false);
+            }
+
+            public async Task<List<IAsyncGrouping<TKey, TElement>>> ToListAsync(CancellationToken cancellationToken)
+            {
+                IIListProvider<IAsyncGrouping<TKey, TElement>> lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false);
+                return await lookup.ToListAsync(cancellationToken).ConfigureAwait(false);
+            }
+
+            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                if (onlyIfCheap)
+                {
+                    return -1;
+                }
+
+                var lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false);
+
+                return lookup.Count;
+            }
+        }
+
+        internal sealed class GroupedAsyncEnumerable<TSource, TKey> : IIListProvider<IAsyncGrouping<TKey, TSource>>
+        {
+            private readonly IAsyncEnumerable<TSource> source;
+            private readonly Func<TSource, TKey> keySelector;
+            private readonly IEqualityComparer<TKey> comparer;
+
+            public GroupedAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
+            {
+                if (source == null) throw new ArgumentNullException(nameof(source));
+                if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));
+
+                this.source = source;
+                this.keySelector = keySelector;
+                this.comparer = comparer;
+            }
+
+
+            public IAsyncEnumerator<IAsyncGrouping<TKey, TSource>> GetEnumerator()
+            {
+                Internal.Lookup<TKey, TSource> lookup = null;
+                IEnumerator<IGrouping<TKey, TSource>> enumerator = null;
 
                 return CreateEnumerator(
-                    ct =>
+                    async ct =>
                     {
-                        ++index;
-                        return f(cts.Token);
+                        if (lookup == null)
+                        {
+                            lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, ct).ConfigureAwait(false);
+                            enumerator = lookup.GetEnumerator();
+                        }
+
+                        // By the time we get here, the lookup is sync
+                        if (ct.IsCancellationRequested)
+                            return false;
+
+                        return enumerator?.MoveNext() ?? false;
                     },
-                    () => elements[index],
-                    d.Dispose,
-                    null
-                );
+                    () => (IAsyncGrouping<TKey, TSource>)enumerator?.Current,
+                    () =>
+                        {
+                            if (enumerator != null)
+                            {
+                                enumerator.Dispose();
+                                enumerator = null;
+                            }
+                        });
             }
 
-            public void Add(TElement element)
+            public async Task<IAsyncGrouping<TKey, TSource>[]> ToArrayAsync(CancellationToken cancellationToken)
             {
-                lock (elements)
-                    elements.Add(element);
+                IIListProvider<IAsyncGrouping<TKey, TSource>> lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
+                return await lookup.ToArrayAsync(cancellationToken).ConfigureAwait(false);
             }
 
-            public void Error(Exception exception)
+            public async Task<List<IAsyncGrouping<TKey, TSource>>> ToListAsync(CancellationToken cancellationToken)
+            {
+                IIListProvider<IAsyncGrouping<TKey, TSource>> lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
+                return await lookup.ToListAsync(cancellationToken).ConfigureAwait(false);
+            }
+
+            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
-                done = true;
-                this.exception = ExceptionDispatchInfo.Capture(exception);
+                if (onlyIfCheap)
+                {
+                    return -1;
+                }
+
+                var lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
+
+                return lookup.Count;
             }
         }
+
     }
 }
 
@@ -527,8 +484,7 @@ namespace System.Linq.Internal
 
         IAsyncEnumerator<TElement> IAsyncEnumerable<TElement>.GetEnumerator()
         {
-            var adapter = new AsyncEnumerable.AsyncEnumerableAdapter<TElement>(this);
-            return adapter.GetEnumerator();
+            return this.ToAsyncEnumerable().GetEnumerator();
         }
     }
 }

+ 52 - 2
Ix.NET/Source/System.Interactive.Async/Lookup.cs

@@ -189,6 +189,20 @@ namespace System.Linq.Internal
             }
         }
 
+        public IEnumerable<TResult> ApplyResultSelector<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector) // Iterator: lazily yields resultSelector(key, elements) once per grouping reachable from _lastGrouping.
+        {
+            var g = _lastGrouping; // Walk starts from the last grouping; null means there is nothing to yield.
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next; // Advance before use, so the first grouping visited is the one after _lastGrouping (presumably a circular list -- confirm).
+                    g.Trim(); // NOTE(review): presumably shrinks the grouping's element storage to its used size -- confirm against the grouping type's Trim.
+                    yield return resultSelector(g._key, g._elements.ToAsyncEnumerable()); // The selector receives the grouping's elements via ToAsyncEnumerable().
+                } while (g != _lastGrouping); // Ends right after _lastGrouping itself has been yielded.
+            }
+        }
+
         internal static Lookup<TKey, TElement> Create<TSource>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)
         {
             Debug.Assert(source != null);
@@ -363,6 +377,25 @@ namespace System.Linq.Internal
             return array;
         }
 
+        internal TResult[] ToArray<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector) // Eagerly projects every grouping through resultSelector into a new array.
+        {
+            var array = new TResult[Count]; // Sized from Count; assumes Count equals the number of groupings visited below -- TODO confirm.
+            var index = 0;
+            var g = _lastGrouping; // Walk starts from the last grouping; null skips the loop and returns the array unfilled.
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next; // Advance before use, so the first grouping visited is the one after _lastGrouping.
+                    g.Trim(); // NOTE(review): presumably shrinks the grouping's element storage to its used size -- confirm.
+                    array[index] = resultSelector(g._key, g._elements.ToAsyncEnumerable()); // The selector receives the grouping's elements via ToAsyncEnumerable().
+                    ++index;
+                } while (g != _lastGrouping); // Ends right after _lastGrouping itself has been stored.
+            }
+
+            return array;
+        }
+
 
         internal List<TResult> ToList<TResult>(Func<TKey, IEnumerable<TElement>, TResult> resultSelector)
         {
@@ -381,6 +414,23 @@ namespace System.Linq.Internal
             return list;
         }
 
+        internal List<TResult> ToList<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector) // Eagerly projects every grouping through resultSelector into a new list.
+        {
+            var list = new List<TResult>(Count); // Count is used only as the initial capacity.
+            var g = _lastGrouping; // Walk starts from the last grouping; null skips the loop and returns an empty list.
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next; // Advance before use, so the first grouping visited is the one after _lastGrouping.
+                    g.Trim(); // NOTE(review): presumably shrinks the grouping's element storage to its used size -- confirm.
+                    list.Add(resultSelector(g._key, g._elements.ToAsyncEnumerable())); // The selector receives the grouping's elements via ToAsyncEnumerable().
+                } while (g != _lastGrouping); // Ends right after _lastGrouping itself has been added.
+            }
+
+            return list;
+        }
+
         private void Resize()
         {
             var newSize = checked((Count*2) + 1);
@@ -399,7 +449,7 @@ namespace System.Linq.Internal
 
         IAsyncEnumerator<IGrouping<TKey, TElement>> IAsyncEnumerable<IGrouping<TKey, TElement>>.GetEnumerator()
         {
-            return new AsyncEnumerable.AsyncEnumerableAdapter<IGrouping<TKey, TElement>>(this).GetEnumerator();
+            return this.ToAsyncEnumerable<IGrouping<TKey, TElement>>().GetEnumerator();
         }
 
         public Task<IGrouping<TKey, TElement>[]> ToArrayAsync(CancellationToken cancellationToken)
@@ -446,7 +496,7 @@ namespace System.Linq.Internal
 
         IAsyncEnumerator<IAsyncGrouping<TKey, TElement>> IAsyncEnumerable<IAsyncGrouping<TKey, TElement>>.GetEnumerator()
         {
-            return new AsyncEnumerable.AsyncEnumerableAdapter<IAsyncGrouping<TKey, TElement>>(Enumerable.Cast<IAsyncGrouping<TKey, TElement>>(this)).GetEnumerator();
+            return Enumerable.Cast<IAsyncGrouping<TKey, TElement>>(this).ToAsyncEnumerable().GetEnumerator();
         }
 
         Task<List<IAsyncGrouping<TKey, TElement>>> IIListProvider<IAsyncGrouping<TKey, TElement>>.ToListAsync(CancellationToken cancellationToken)

+ 0 - 3
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -1915,9 +1915,6 @@ namespace Tests
 
             e.Dispose();
 
-           // TODO: Do the internal iterators really get cleaned up?
-           // look once this group by method has been updated
-
             HasNext(g1e, 'd');
             HasNext(g1e, 'g');
             HasNext(g1e, 'j');