瀏覽代碼

Add optimizations for IList as it passes from ToAsyncEnumerable

Oren Novotny 9 年之前
父節點
當前提交
8a8061bf11

+ 6 - 0
Ix.NET/Source/System.Interactive.Async/Count.cs

@@ -17,6 +17,12 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
+            var collectionoft = source as ICollection<TSource>;
+            if (collectionoft != null)
+            {
+                return Task.FromResult(collectionoft.Count);
+            }
+
             var listProv = source as IIListProvider<TSource>;
             if (listProv != null)
             {

+ 6 - 0
Ix.NET/Source/System.Interactive.Async/ElementAt.cs

@@ -51,6 +51,12 @@ namespace System.Linq
 
         private static async Task<TSource> ElementAt_<TSource>(IAsyncEnumerable<TSource> source, int index, CancellationToken cancellationToken)
         {
+            var list = source as IList<TSource>;
+            if (list != null)
+            {
+                return list[index];
+            }
+
             if (index >= 0)
             {
                 using (var e = source.GetEnumerator())

+ 12 - 0
Ix.NET/Source/System.Interactive.Async/First.cs

@@ -88,6 +88,12 @@ namespace System.Linq
 
         private static async Task<TSource> First_<TSource>(IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         {
+            var list = source as IList<TSource>;
+            if (list?.Count > 0)
+            {
+                return list[0];
+            }
+
             using (var e = source.GetEnumerator())
             {
                 if (await e.MoveNext(cancellationToken)
@@ -101,6 +107,12 @@ namespace System.Linq
 
         private static async Task<TSource> FirstOrDefault_<TSource>(IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         {
+            var list = source as IList<TSource>;
+            if (list?.Count > 0)
+            {
+                return list[0];
+            }
+
             using (var e = source.GetEnumerator())
             {
                 if (await e.MoveNext(cancellationToken)

+ 20 - 0
Ix.NET/Source/System.Interactive.Async/Last.cs

@@ -91,6 +91,16 @@ namespace System.Linq
             var last = default(TSource);
             var hasLast = false;
 
+            var list = source as IList<TSource>;
+            if (list != null)
+            {
+                var count = list.Count;
+                if (count > 0)
+                {
+                    return list[count - 1];
+                }
+            }
+
             using (var e = source.GetEnumerator())
             {
                 while (await e.MoveNext(cancellationToken)
@@ -110,6 +120,16 @@ namespace System.Linq
             var last = default(TSource);
             var hasLast = false;
 
+            var list = source as IList<TSource>;
+            if (list != null)
+            {
+                var count = list.Count;
+                if (count > 0)
+                {
+                    return list[count - 1];
+                }
+            }
+
             using (var e = source.GetEnumerator())
             {
                 while (await e.MoveNext(cancellationToken)

+ 65 - 1
Ix.NET/Source/System.Interactive.Async/Select.cs

@@ -27,7 +27,11 @@ namespace System.Linq
                 return iterator.Select(selector);
             }
 
-            // TODO: Can we add optimizations for IList or anything else here?
+            var ilist = source as IList<TSource>;
+            if (ilist != null)
+            {
+               return new SelectIListIterator<TSource, TResult>(ilist, selector);
+            }
 
             return new SelectEnumerableAsyncIterator<TSource, TResult>(source, selector);
         }
@@ -136,5 +140,65 @@ namespace System.Linq
                 return new SelectEnumerableAsyncIterator<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
             }
         }
+
+        internal sealed class SelectIListIterator<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private readonly IList<TSource> source;
+            private readonly Func<TSource, TResult> selector;
+            private IEnumerator<TSource> enumerator;
+
+            public SelectIListIterator(IList<TSource> source, Func<TSource, TResult> selector)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectIListIterator<TSource, TResult>(source, selector);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if ( enumerator.MoveNext())
+                        {
+                            current = selector(enumerator.Current);
+                            return Task.FromResult(true);
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return Task.FromResult(false);
+            }
+
+            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
+            {
+                return new SelectIListIterator<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
+            }
+        }
     }
 }

+ 10 - 0
Ix.NET/Source/System.Interactive.Async/SequenceEqual.cs

@@ -60,6 +60,16 @@ namespace System.Linq
         private static async Task<bool> SequenceEqual_<TSource>(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer,
                                                                 CancellationToken cancellationToken)
         {
+            var firstCol = first as ICollection<TSource>;
+            if (firstCol != null)
+            {
+                var secondCol = second as ICollection<TSource>;
+                if (secondCol != null && firstCol.Count != secondCol.Count)
+                {
+                    return false;
+                }
+            }
+
             using (var e1 = first.GetEnumerator())
             using (var e2 = second.GetEnumerator())
             {

+ 22 - 0
Ix.NET/Source/System.Interactive.Async/Single.cs

@@ -89,6 +89,17 @@ namespace System.Linq
 
         private static async Task<TSource> Single_<TSource>(IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         {
+            var list = source as IList<TSource>;
+            if (list != null)
+            {
+                switch (list.Count)
+                {
+                    case 0: throw new InvalidOperationException(Strings.NO_ELEMENTS);
+                    case 1: return list[0];
+                }
+                throw new InvalidOperationException(Strings.MORE_THAN_ONE_ELEMENT);
+            }
+
             using (var e = source.GetEnumerator())
             {
                 if (!await e.MoveNext(cancellationToken)
@@ -108,6 +119,17 @@ namespace System.Linq
 
         private static async Task<TSource> SingleOrDefault_<TSource>(IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
         {
+            var list = source as IList<TSource>;
+            if (list != null)
+            {
+                switch (list.Count)
+                {
+                    case 0: return default(TSource);
+                    case 1: return list[0];
+                }
+                throw new InvalidOperationException(Strings.MORE_THAN_ONE_ELEMENT);
+            }
+
             using (var e = source.GetEnumerator())
             {
                 if (!await e.MoveNext(cancellationToken)

+ 175 - 1
Ix.NET/Source/System.Interactive.Async/ToAsyncEnumerable.cs

@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information. 
 
 using System;
+using System.Collections;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
@@ -18,6 +19,15 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
+            // optimize these adapters for lists and collections
+            var ilist = source as IList<TSource>;
+            if (ilist != null)
+                return new AsyncIListEnumerableAdapter<TSource>(ilist);
+
+            var icoll = source as ICollection<TSource>;
+            if (icoll != null)
+                return new AsyncICollectionEnumerableAdapter<TSource>(icoll);
+
             return new AsyncEnumerableAdapter<TSource>(source);
         }
 
@@ -137,5 +147,169 @@ namespace System.Linq
                 return Task.FromResult(source.Count());
             }
         }
+
+        internal sealed class AsyncIListEnumerableAdapter<T> : AsyncIterator<T>, IIListProvider<T>, IList<T>
+        {
+            private readonly IList<T> source;
+            private IEnumerator<T> enumerator;
+
+            public AsyncIListEnumerableAdapter(IList<T> source)
+            {
+                Debug.Assert(source != null);
+                this.source = source;
+            }
+
+            public override AsyncIterator<T> Clone()
+            {
+                return new AsyncEnumerableAdapter<T>(source);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (enumerator.MoveNext())
+                        {
+                            current = enumerator.Current;
+                            return Task.FromResult(true);
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return Task.FromResult(false);
+            }
+
+            // These optimizations rely on the Sys.Linq impls from IEnumerable to optimize
+            // and short circuit as appropriate
+            public Task<T[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.ToArray());
+            }
+
+            public Task<List<T>> ToListAsync(CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.ToList());
+            }
+
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.Count());
+            }
+
+            IEnumerator<T> IEnumerable<T>.GetEnumerator() => source.GetEnumerator();
+
+            IEnumerator IEnumerable.GetEnumerator() => source.GetEnumerator();
+
+            void ICollection<T>.Add(T item) => source.Add(item);
+
+            void ICollection<T>.Clear() => source.Clear();
+
+            bool ICollection<T>.Contains(T item) => source.Contains(item);
+
+            void ICollection<T>.CopyTo(T[] array, int arrayIndex) => source.CopyTo(array, arrayIndex);
+
+            bool ICollection<T>.Remove(T item) => source.Remove(item);
+
+            int ICollection<T>.Count => source.Count;
+
+            bool ICollection<T>.IsReadOnly => source.IsReadOnly;
+
+            int IList<T>.IndexOf(T item) => source.IndexOf(item);
+
+            void IList<T>.Insert(int index, T item) => source.Insert(index, item);
+
+            void IList<T>.RemoveAt(int index) => source.RemoveAt(index);
+
+            T IList<T>.this[int index]
+            {
+                get { return source[index]; }
+                set { source[index] = value; }
+            }
+        }
+
+        internal sealed class AsyncICollectionEnumerableAdapter<T> : AsyncIterator<T>, IIListProvider<T>, ICollection<T>
+        {
+            private readonly ICollection<T> source;
+            private IEnumerator<T> enumerator;
+            public AsyncICollectionEnumerableAdapter(ICollection<T> source)
+            {
+                Debug.Assert(source != null);
+                this.source = source;
+            }
+            public override AsyncIterator<T> Clone()
+            {
+                return new AsyncEnumerableAdapter<T>(source);
+            }
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+                base.Dispose();
+            }
+            protected override Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+                    case AsyncIteratorState.Iterating:
+                        if (enumerator.MoveNext())
+                        {
+                            current = enumerator.Current;
+                            return Task.FromResult(true);
+                        }
+                        Dispose();
+                        break;
+                }
+                return Task.FromResult(false);
+            }
+            // These optimizations rely on the Sys.Linq impls from IEnumerable to optimize
+            // and short circuit as appropriate
+            public Task<T[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.ToArray());
+            }
+            public Task<List<T>> ToListAsync(CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.ToList());
+            }
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.Count());
+            }
+            IEnumerator<T> IEnumerable<T>.GetEnumerator() => source.GetEnumerator();
+            IEnumerator IEnumerable.GetEnumerator() => source.GetEnumerator();
+            void ICollection<T>.Add(T item) => source.Add(item);
+            void ICollection<T>.Clear() => source.Clear();
+            bool ICollection<T>.Contains(T item) => source.Contains(item);
+            void ICollection<T>.CopyTo(T[] array, int arrayIndex) => source.CopyTo(array, arrayIndex);
+            bool ICollection<T>.Remove(T item) => source.Remove(item);
+            int ICollection<T>.Count => source.Count;
+            bool ICollection<T>.IsReadOnly => source.IsReadOnly;
+        }
     }
-}
+}