浏览代码

Add GroupingAsyncEnumerable. Grouping evals the source on first MoveNext

Oren Novotny 9 年之前
父节点
当前提交
ed36022e48

+ 90 - 8
Ix.NET/Source/System.Interactive.Async/Grouping.cs

@@ -169,7 +169,7 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException(nameof(comparer));
 
-            return source.GroupBy(keySelector, x => x, comparer);
+            return new GroupedAsyncEnumerable<TSource, TKey>(source, keySelector, comparer, CancellationToken.None);
         }
 
         public static IAsyncEnumerable<IAsyncGrouping<TKey, TSource>> GroupBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
@@ -179,7 +179,7 @@ namespace System.Linq
             if (keySelector == null)
                 throw new ArgumentNullException(nameof(keySelector));
 
-            return source.GroupBy(keySelector, x => x, EqualityComparer<TKey>.Default);
+            return new GroupedAsyncEnumerable<TSource, TKey>(source, keySelector, EqualityComparer<TKey>.Default, CancellationToken.None);
         }
 
         public static IAsyncEnumerable<TResult> GroupBy<TSource, TKey, TElement, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector, IEqualityComparer<TKey> comparer)
@@ -257,6 +257,82 @@ namespace System.Linq
             }
         }
 
+        internal sealed class GroupedAsyncEnumerable<TSource, TKey> : IIListProvider<IAsyncGrouping<TKey, TSource>>
+        {
+            private readonly IAsyncEnumerable<TSource> source;
+            private readonly Func<TSource, TKey> keySelector;
+            private readonly IEqualityComparer<TKey> comparer;
+            private readonly CancellationToken cancellationToken;
+
+            public GroupedAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
+            {
+                if (source == null) throw new ArgumentNullException(nameof(source));
+                if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));
+
+                this.source = source;
+                this.keySelector = keySelector;
+                this.comparer = comparer;
+                this.cancellationToken = cancellationToken;
+            }
+
+
+            public IAsyncEnumerator<IAsyncGrouping<TKey, TSource>> GetEnumerator()
+            {
+                Internal.Lookup<TKey, TSource> lookup = null;
+                IAsyncGrouping<TKey, TSource> current = null;
+                IEnumerator<IGrouping<TKey, TSource>> enumerator = null;
+
+                return CreateEnumerator(
+                    async ct =>
+                    {
+                        if (lookup == null)
+                        {
+                            lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, ct).ConfigureAwait(false);
+                            enumerator = lookup.GetEnumerator();
+                        }
+
+                        // By the time we get here, the lookup is sync
+                        if (ct.IsCancellationRequested)
+                            return false;
+
+                        return enumerator?.MoveNext() ?? false;
+                    },
+                    () => (IAsyncGrouping<TKey, TSource>)enumerator?.Current,
+                    () =>
+                        {
+                            if (enumerator != null)
+                            {
+                                enumerator.Dispose();
+                                enumerator = null;
+                            }
+                        });
+            }
+
+            public async Task<IAsyncGrouping<TKey, TSource>[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                IIListProvider<IAsyncGrouping<TKey, TSource>> lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
+                return await lookup.ToArrayAsync(cancellationToken).ConfigureAwait(false);
+            }
+
+            public async Task<List<IAsyncGrouping<TKey, TSource>>> ToListAsync(CancellationToken cancellationToken)
+            {
+                IIListProvider<IAsyncGrouping<TKey, TSource>> lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
+                return await lookup.ToListAsync(cancellationToken).ConfigureAwait(false);
+            }
+
+            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                if (onlyIfCheap)
+                {
+                    return -1;
+                }
+
+                var lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
+
+                return lookup.Count;
+            }
+        }
+
         private class AsyncGrouping<TKey, TElement> : IAsyncGrouping<TKey, TElement>
         {
             private readonly List<TElement> elements = new List<TElement>();
@@ -339,7 +415,7 @@ namespace System.Linq.Internal
 {
     /// Adapted from System.Linq.Grouping from .NET Framework
     /// Source: https://github.com/dotnet/corefx/blob/b90532bc97b07234a7d18073819d019645285f1c/src/System.Linq/src/System/Linq/Grouping.cs#L64
-    internal class Grouping<TKey, TElement> : IGrouping<TKey, TElement>, IList<TElement>
+    internal class Grouping<TKey, TElement> : IGrouping<TKey, TElement>, IList<TElement>, IAsyncGrouping<TKey, TElement>
     {
         internal int _count;
         internal TElement[] _elements;
@@ -348,6 +424,11 @@ namespace System.Linq.Internal
         internal TKey _key;
         internal Grouping<TKey, TElement> _next;
 
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return GetEnumerator();
+        }
+
         public IEnumerator<TElement> GetEnumerator()
         {
             for (var i = 0; i < _count; i++)
@@ -356,11 +437,6 @@ namespace System.Linq.Internal
             }
         }
 
-        IEnumerator IEnumerable.GetEnumerator()
-        {
-            return GetEnumerator();
-        }
-
         // DDB195907: implement IGrouping<>.Key implicitly
         // so that WPF binding works on this property.
         public TKey Key
@@ -451,5 +527,11 @@ namespace System.Linq.Internal
                 Array.Resize(ref _elements, _count);
             }
         }
+
+        IAsyncEnumerator<TElement> IAsyncEnumerable<TElement>.GetEnumerator()
+        {
+            var adapter = new AsyncEnumerable.AsyncEnumerableAdapter<TElement>(this);
+            return adapter.GetEnumerator();
+        }
     }
 }

+ 37 - 1
Ix.NET/Source/System.Interactive.Async/Lookup.cs

@@ -120,7 +120,7 @@ namespace System.Linq
 
 namespace System.Linq.Internal
 {
-    internal class Lookup<TKey, TElement> : ILookup<TKey, TElement>, IIListProvider<IGrouping<TKey, TElement>>
+    internal class Lookup<TKey, TElement> : ILookup<TKey, TElement>, IIListProvider<IGrouping<TKey, TElement>>, IIListProvider<IAsyncGrouping<TKey, TElement>>
     {
         private readonly IEqualityComparer<TKey> _comparer;
         private Grouping<TKey, TElement>[] _groupings;
@@ -240,6 +240,25 @@ namespace System.Linq.Internal
             return lookup;
         }
 
+        internal static async Task<Lookup<TKey, TElement>> CreateAsync(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector,  IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
+        {
+            Debug.Assert(source != null);
+            Debug.Assert(keySelector != null);
+
+            var lookup = new Lookup<TKey, TElement>(comparer);
+            using (var enu = source.GetEnumerator())
+            {
+                while (await enu.MoveNext(cancellationToken)
+                                .ConfigureAwait(false))
+                {
+                    lookup.GetGrouping(keySelector(enu.Current), create: true)
+                          .Add(enu.Current);
+                }
+            }
+
+            return lookup;
+        }
+
         internal static Lookup<TKey, TElement> CreateForJoin(IEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey> comparer)
         {
             var lookup = new Lookup<TKey, TElement>(comparer);
@@ -402,6 +421,7 @@ namespace System.Linq.Internal
             return Task.FromResult(array);
         }
 
+
         public Task<List<IGrouping<TKey, TElement>>> ToListAsync(CancellationToken cancellationToken)
         {
             var list = new List<IGrouping<TKey, TElement>>(Count);
@@ -423,5 +443,21 @@ namespace System.Linq.Internal
         {
             return Task.FromResult(Count);
         }
+
+        IAsyncEnumerator<IAsyncGrouping<TKey, TElement>> IAsyncEnumerable<IAsyncGrouping<TKey, TElement>>.GetEnumerator()
+        {
+            return new AsyncEnumerable.AsyncEnumerableAdapter<IAsyncGrouping<TKey, TElement>>(Enumerable.Cast<IAsyncGrouping<TKey, TElement>>(this)).GetEnumerator();
+        }
+
+        Task<List<IAsyncGrouping<TKey, TElement>>> IIListProvider<IAsyncGrouping<TKey, TElement>>.ToListAsync(CancellationToken cancellationToken)
+        {
+            throw new NotImplementedException();
+        }
+
+        Task<IAsyncGrouping<TKey, TElement>[]> IIListProvider<IAsyncGrouping<TKey, TElement>>.ToArrayAsync(CancellationToken cancellationToken)
+        {
+            throw new NotImplementedException();
+        }
+
     }
 }

+ 40 - 34
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -1589,21 +1589,23 @@ namespace Tests
 
             var e = ys.GetEnumerator();
 
-            Assert.True(e.MoveNext().Result);
-            var g1 = e.Current;
-            Assert.Equal(g1.Key, 42);
-            var g1e = g1.GetEnumerator();
-            HasNext(g1e, 42);
+            AssertThrows<Exception>(() => e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
 
-            Assert.True(e.MoveNext().Result);
-            var g2 = e.Current;
-            Assert.Equal(g2.Key, 43);
-            var g2e = g2.GetEnumerator();
-            HasNext(g2e, 43);
+            //Assert.True(e.MoveNext().Result);
+            //var g1 = e.Current;
+            //Assert.Equal(g1.Key, 42);
+            //var g1e = g1.GetEnumerator();
+            //HasNext(g1e, 42);
 
-            AssertThrows<Exception>(() => e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
-            AssertThrows<Exception>(() => g1e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
-            AssertThrows<Exception>(() => g2e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
+            //Assert.True(e.MoveNext().Result);
+            //var g2 = e.Current;
+            //Assert.Equal(g2.Key, 43);
+            //var g2e = g2.GetEnumerator();
+            //HasNext(g2e, 43);
+
+            
+            //AssertThrows<Exception>(() => g1e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
+            //AssertThrows<Exception>(() => g2e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
         }
 
         [Fact]
@@ -1614,14 +1616,16 @@ namespace Tests
 
             var e = ys.GetEnumerator();
 
-            Assert.True(e.MoveNext().Result);
-            var g1 = e.Current;
-            Assert.Equal(g1.Key, 42);
-            var g1e = g1.GetEnumerator();
-            HasNext(g1e, 42);
-            AssertThrows<Exception>(() => g1e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
+            AssertThrows<Exception>(() => e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
+
+            //Assert.True(e.MoveNext().Result);
+            //var g1 = e.Current;
+            //Assert.Equal(g1.Key, 42);
+            //var g1e = g1.GetEnumerator();
+            //HasNext(g1e, 42);
+            //AssertThrows<Exception>(() => g1e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
 
-            AssertThrows<Exception>(() => e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");           
+                    
         }
 
         static IEnumerable<int> GetXs()
@@ -1651,21 +1655,23 @@ namespace Tests
 
             var e = ys.GetEnumerator();
 
-            Assert.True(e.MoveNext().Result);
-            var g1 = e.Current;
-            Assert.Equal(g1.Key, 1);
-            var g1e = g1.GetEnumerator();
-            HasNext(g1e, 1);
-
-            Assert.True(e.MoveNext().Result);
-            var g2 = e.Current;
-            Assert.Equal(g2.Key, 2);
-            var g2e = g2.GetEnumerator();
-            HasNext(g2e, 2);
-
             AssertThrows<Exception>(() => e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
-            AssertThrows<Exception>(() => g1e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
-            AssertThrows<Exception>(() => g2e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
+
+            //Assert.True(e.MoveNext().Result);
+            //var g1 = e.Current;
+            //Assert.Equal(g1.Key, 1);
+            //var g1e = g1.GetEnumerator();
+            //HasNext(g1e, 1);
+
+            //Assert.True(e.MoveNext().Result);
+            //var g2 = e.Current;
+            //Assert.Equal(g2.Key, 2);
+            //var g2e = g2.GetEnumerator();
+            //HasNext(g2e, 2);
+
+           
+            //AssertThrows<Exception>(() => g1e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
+            //AssertThrows<Exception>(() => g2e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single() == ex);
         }
 
         [Fact]