1
0
Эх сурвалжийг харах

Add support for IIListProvider-based optimizations

Oren Novotny 9 жил өмнө
parent
commit
1d3c387596

+ 6 - 0
Ix.NET/Source/System.Interactive.Async/Count.cs

@@ -17,6 +17,12 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
+            var listProv = source as IIListProvider<TSource>;
+            if (listProv != null)
+            {
+                return listProv.GetCountAsync(onlyIfCheap: false, cancellationToken: cancellationToken);
+            }
+
             return source.Aggregate(0, (c, _) => checked(c + 1), cancellationToken);
         }
 

+ 35 - 1
Ix.NET/Source/System.Interactive.Async/DefaultIfEmpty.cs

@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information. 
 
 using System;
+using System.Collections;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
@@ -29,7 +30,7 @@ namespace System.Linq
             return DefaultIfEmpty(source, default(TSource));
         }
 
-        private sealed class DefaultIfEmptyAsyncIterator<TSource> : AsyncIterator<TSource>
+        private sealed class DefaultIfEmptyAsyncIterator<TSource> : AsyncIterator<TSource>, IIListProvider<TSource>
         {
             private readonly IAsyncEnumerable<TSource> source;
             private readonly TSource defaultValue;
@@ -90,6 +91,39 @@ namespace System.Linq
                 Dispose();
                 return false;
             }
+
+            public async Task<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                var array = await source.ToArray(cancellationToken).ConfigureAwait(false);
+                return array.Length == 0 ? new[] { defaultValue } : array;
+            }
+
+            public async Task<List<TSource>> ToListAsync(CancellationToken cancellationToken)
+            {
+                var list = await source.ToList(cancellationToken).ConfigureAwait(false);
+                if (list.Count == 0)
+                {
+                    list.Add(defaultValue);
+                }
+
+                return list;
+            }
+
+            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                int count;
+                if (!onlyIfCheap || source is ICollection<TSource> || source is ICollection)
+                {
+                    count = await source.Count(cancellationToken).ConfigureAwait(false);
+                }
+                else
+                {
+                    var listProv = source as IIListProvider<TSource>;
+                    count = listProv == null ? -1 : await listProv.GetCountAsync(onlyIfCheap: true, cancellationToken: cancellationToken).ConfigureAwait(false);
+                }
+
+                return count == 0 ? 1 : count;
+            }
         }
     }
 }

+ 1 - 1
Ix.NET/Source/System.Interactive.Async/GroupJoin.cs

@@ -48,7 +48,7 @@ namespace System.Linq
             return outer.GroupJoin(inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
         }
 
-        private sealed class AsyncEnumerableAdapter<T> : IAsyncEnumerable<T>
+        internal sealed class AsyncEnumerableAdapter<T> : IAsyncEnumerable<T>
         {
             private readonly IEnumerable<T> _source;
 

+ 34 - 0
Ix.NET/Source/System.Interactive.Async/IIListProvider.cs

@@ -0,0 +1,34 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Linq
+{
+    /// <summary>
+    /// An iterator that can produce an array or <see cref="List{TElement}"/> through an optimized path.
+    /// </summary>
+    interface IIListProvider<TElement> : IAsyncEnumerable<TElement>
+    {
+        /// <summary>
+        /// Produce an array of the sequence through an optimized path.
+        /// </summary>
+        /// <returns>The array.</returns>
+        Task<TElement[]> ToArrayAsync(CancellationToken cancellationToken);
+
+        /// <summary>
+        /// Produce a <see cref="List{TElement}"/> of the sequence through an optimized path.
+        /// </summary>
+        /// <returns>The <see cref="List{TElement}"/>.</returns>
+        Task<List<TElement>> ToListAsync(CancellationToken cancellationToken);
+
+        /// <summary>
+        /// Returns the count of elements in the sequence.
+        /// </summary>
+        /// <param name="onlyIfCheap">If true then the count should only be calculated if doing
+        /// so is quick (sure or likely to be constant time), otherwise -1 should be returned.</param>
+        /// <returns>The number of elements.</returns>
+        Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken);
+    }
+}

+ 52 - 6
Ix.NET/Source/System.Interactive.Async/Lookup.cs

@@ -120,7 +120,7 @@ namespace System.Linq
 
 namespace System.Linq.Internal
 {
-    internal class Lookup<TKey, TElement> : ILookup<TKey, TElement>
+    internal class Lookup<TKey, TElement> : ILookup<TKey, TElement>, IIListProvider<IGrouping<TKey, TElement>>
     {
         private readonly IEqualityComparer<TKey> _comparer;
         private Grouping<TKey, TElement>[] _groupings;
@@ -157,6 +157,11 @@ namespace System.Linq.Internal
             return GetGrouping(key, create: false) != null;
         }
 
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return GetEnumerator();
+        }
+
         public IEnumerator<IGrouping<TKey, TElement>> GetEnumerator()
         {
             var g = _lastGrouping;
@@ -170,11 +175,6 @@ namespace System.Linq.Internal
             }
         }
 
-        IEnumerator IEnumerable.GetEnumerator()
-        {
-            return GetEnumerator();
-        }
-
         public IEnumerable<TResult> ApplyResultSelector<TResult>(Func<TKey, IEnumerable<TElement>, TResult> resultSelector)
         {
             var g = _lastGrouping;
@@ -377,5 +377,51 @@ namespace System.Linq.Internal
 
             _groupings = newGroupings;
         }
+
+        IAsyncEnumerator<IGrouping<TKey, TElement>> IAsyncEnumerable<IGrouping<TKey, TElement>>.GetEnumerator()
+        {
+            return new AsyncEnumerable.AsyncEnumerableAdapter<IGrouping<TKey, TElement>>(this).GetEnumerator();
+        }
+
+        public Task<IGrouping<TKey, TElement>[]> ToArrayAsync(CancellationToken cancellationToken)
+        {
+            var array = new IGrouping<TKey, TElement>[Count];
+            var index = 0;
+            var g = _lastGrouping;
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next;
+                    array[index] = g;
+                    ++index;
+                }
+                while (g != _lastGrouping);
+            }
+
+            return Task.FromResult(array);
+        }
+
+        public Task<List<IGrouping<TKey, TElement>>> ToListAsync(CancellationToken cancellationToken)
+        {
+            var list = new List<IGrouping<TKey, TElement>>(Count);
+            var g = _lastGrouping;
+            if (g != null)
+            {
+                do
+                {
+                    g = g._next;
+                    list.Add(g);
+                }
+                while (g != _lastGrouping);
+            }
+
+            return Task.FromResult(list);
+        }
+
+        public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+        {
+            return Task.FromResult(Count);
+        }
     }
 }

+ 9 - 0
Ix.NET/Source/System.Interactive.Async/ToCollection.cs

@@ -25,6 +25,11 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
+            var arrayProvider = source as IIListProvider<TSource>;
+            if (arrayProvider != null)
+                return arrayProvider.ToArrayAsync(cancellationToken);
+
+
             return source.Aggregate(new List<TSource>(), (list, x) =>
                                                          {
                                                              list.Add(x);
@@ -145,6 +150,10 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
+            var listProvider = source as IIListProvider<TSource>;
+            if (listProvider != null)
+                return listProvider.ToListAsync(cancellationToken);
+
             return source.Aggregate(new List<TSource>(), (list, x) =>
                                                          {
                                                              list.Add(x);