Browse Source

SelectMany can implement IAsyncIListProvider.

Bart De Smet 6 years ago
parent
commit
5cddc62b53

+ 267 - 4
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/SelectMany.cs

@@ -346,7 +346,7 @@ namespace System.Linq
         }
 #endif
 
-        private sealed class SelectManyAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
+        private sealed class SelectManyAsyncIterator<TSource, TResult> : AsyncIterator<TResult>, IAsyncIListProvider<TResult>
         {
             private const int State_Source = 1;
             private const int State_Result = 2;
@@ -389,6 +389,91 @@ namespace System.Linq
                 await base.DisposeAsync().ConfigureAwait(false);
             }
 
+            public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                if (onlyIfCheap)
+                {
+                    return new ValueTask<int>(-1);
+                }
+
+                return Core(cancellationToken);
+
+                async ValueTask<int> Core(CancellationToken _cancellationToken)
+                {
+                    var count = 0;
+
+#if CSHARP8 && AETOR_HAS_CT // CS0656 Missing compiler required member 'System.Collections.Generic.IAsyncEnumerable`1.GetAsyncEnumerator'
+                    await foreach (var element in _source.WithCancellation(_cancellationToken).ConfigureAwait(false))
+                    {
+                        checked
+                        {
+                            count += await _selector(element).CountAsync().ConfigureAwait(false);
+                        }
+                    }
+#else
+                    var e = _source.GetAsyncEnumerator(_cancellationToken);
+
+                    try
+                    {
+                        while (await e.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            checked
+                            {
+                                count += await _selector(e.Current).CountAsync().ConfigureAwait(false);
+                            }
+                        }
+                    }
+                    finally
+                    {
+                        await e.DisposeAsync().ConfigureAwait(false);
+                    }
+#endif
+
+                    return count;
+                }
+            }
+
+            public async ValueTask<TResult[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                // REVIEW: Substitute for SparseArrayBuilder<T> logic once we have access to that.
+
+                var list = await ToListAsync(cancellationToken).ConfigureAwait(false);
+
+                return list.ToArray();
+            }
+
+            public async ValueTask<List<TResult>> ToListAsync(CancellationToken cancellationToken)
+            {
+                var list = new List<TResult>();
+
+#if CSHARP8 && AETOR_HAS_CT // CS0656 Missing compiler required member 'System.Collections.Generic.IAsyncEnumerable`1.GetAsyncEnumerator'
+                await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
+                {
+                    var items = _selector(element);
+
+                    await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
+                }
+#else
+                var e = _source.GetAsyncEnumerator(cancellationToken);
+
+                try
+                {
+                    while (await e.MoveNextAsync().ConfigureAwait(false))
+                    {
+                        var items = _selector(e.Current);
+
+                        await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
+                    }
+                }
+                finally
+                {
+                    await e.DisposeAsync().ConfigureAwait(false);
+                }
+#endif
+
+                return list;
+            }
+
             protected override async ValueTask<bool> MoveNextCore()
             {
                 switch (_state)
@@ -437,7 +522,7 @@ namespace System.Linq
             }
         }
 
-        private sealed class SelectManyAsyncIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>
+        private sealed class SelectManyAsyncIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>, IAsyncIListProvider<TResult>
         {
             private const int State_Source = 1;
             private const int State_Result = 2;
@@ -480,6 +565,95 @@ namespace System.Linq
                 await base.DisposeAsync().ConfigureAwait(false);
             }
 
+            public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                if (onlyIfCheap)
+                {
+                    return new ValueTask<int>(-1);
+                }
+
+                return Core(cancellationToken);
+
+                async ValueTask<int> Core(CancellationToken _cancellationToken)
+                {
+                    var count = 0;
+
+#if CSHARP8 && AETOR_HAS_CT // CS0656 Missing compiler required member 'System.Collections.Generic.IAsyncEnumerable`1.GetAsyncEnumerator'
+                    await foreach (var element in _source.WithCancellation(_cancellationToken).ConfigureAwait(false))
+                    {
+                        var items = await _selector(element).ConfigureAwait(false);
+
+                        checked
+                        {
+                            count += await items.CountAsync().ConfigureAwait(false);
+                        }
+                    }
+#else
+                    var e = _source.GetAsyncEnumerator(_cancellationToken);
+
+                    try
+                    {
+                        while (await e.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            var items = await _selector(e.Current).ConfigureAwait(false);
+
+                            checked
+                            {
+                                count += await items.CountAsync().ConfigureAwait(false);
+                            }
+                        }
+                    }
+                    finally
+                    {
+                        await e.DisposeAsync().ConfigureAwait(false);
+                    }
+#endif
+
+                    return count;
+                }
+            }
+
+            public async ValueTask<TResult[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                // REVIEW: Substitute for SparseArrayBuilder<T> logic once we have access to that.
+
+                var list = await ToListAsync(cancellationToken).ConfigureAwait(false);
+
+                return list.ToArray();
+            }
+
+            public async ValueTask<List<TResult>> ToListAsync(CancellationToken cancellationToken)
+            {
+                var list = new List<TResult>();
+
+#if CSHARP8 && AETOR_HAS_CT // CS0656 Missing compiler required member 'System.Collections.Generic.IAsyncEnumerable`1.GetAsyncEnumerator'
+                await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
+                {
+                    var items = await _selector(element).ConfigureAwait(false);
+
+                    await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
+                }
+#else
+                var e = _source.GetAsyncEnumerator(cancellationToken);
+
+                try
+                {
+                    while (await e.MoveNextAsync().ConfigureAwait(false))
+                    {
+                        var items = await _selector(e.Current).ConfigureAwait(false);
+
+                        await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
+                    }
+                }
+                finally
+                {
+                    await e.DisposeAsync().ConfigureAwait(false);
+                }
+#endif
+
+                return list;
+            }
+
             protected override async ValueTask<bool> MoveNextCore()
             {
                 switch (_state)
@@ -529,7 +703,7 @@ namespace System.Linq
         }
 
 #if !NO_DEEP_CANCELLATION
-        private sealed class SelectManyAsyncIteratorWithTaskAndCancellation<TSource, TResult> : AsyncIterator<TResult>
+        private sealed class SelectManyAsyncIteratorWithTaskAndCancellation<TSource, TResult> : AsyncIterator<TResult>, IAsyncIListProvider<TResult>
         {
             private const int State_Source = 1;
             private const int State_Result = 2;
@@ -572,6 +746,95 @@ namespace System.Linq
                 await base.DisposeAsync().ConfigureAwait(false);
             }
 
+            public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                if (onlyIfCheap)
+                {
+                    return new ValueTask<int>(-1);
+                }
+
+                return Core(cancellationToken);
+
+                async ValueTask<int> Core(CancellationToken _cancellationToken)
+                {
+                    var count = 0;
+
+#if CSHARP8 && AETOR_HAS_CT // CS0656 Missing compiler required member 'System.Collections.Generic.IAsyncEnumerable`1.GetAsyncEnumerator'
+                    await foreach (var element in _source.WithCancellation(_cancellationToken).ConfigureAwait(false))
+                    {
+                        var items = await _selector(element, _cancellationToken).ConfigureAwait(false);
+
+                        checked
+                        {
+                            count += await items.CountAsync().ConfigureAwait(false);
+                        }
+                    }
+#else
+                    var e = _source.GetAsyncEnumerator(_cancellationToken);
+
+                    try
+                    {
+                        while (await e.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            var items = await _selector(e.Current, _cancellationToken).ConfigureAwait(false);
+
+                            checked
+                            {
+                                count += await items.CountAsync().ConfigureAwait(false);
+                            }
+                        }
+                    }
+                    finally
+                    {
+                        await e.DisposeAsync().ConfigureAwait(false);
+                    }
+#endif
+
+                    return count;
+                }
+            }
+
+            public async ValueTask<TResult[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                // REVIEW: Substitute for SparseArrayBuilder<T> logic once we have access to that.
+
+                var list = await ToListAsync(cancellationToken).ConfigureAwait(false);
+
+                return list.ToArray();
+            }
+
+            public async ValueTask<List<TResult>> ToListAsync(CancellationToken cancellationToken)
+            {
+                var list = new List<TResult>();
+
+#if CSHARP8 && AETOR_HAS_CT // CS0656 Missing compiler required member 'System.Collections.Generic.IAsyncEnumerable`1.GetAsyncEnumerator'
+                await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
+                {
+                    var items = await _selector(element, cancellationToken).ConfigureAwait(false);
+
+                    await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
+                }
+#else
+                var e = _source.GetAsyncEnumerator(cancellationToken);
+
+                try
+                {
+                    while (await e.MoveNextAsync().ConfigureAwait(false))
+                    {
+                        var items = await _selector(e.Current, cancellationToken).ConfigureAwait(false);
+
+                        await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
+                    }
+                }
+                finally
+                {
+                    await e.DisposeAsync().ConfigureAwait(false);
+                }
+#endif
+
+                return list;
+            }
+
             protected override async ValueTask<bool> MoveNextCore()
             {
                 switch (_state)
@@ -621,7 +884,7 @@ namespace System.Linq
         }
 #endif
 
-#if !(CSHARP8 && USE_ASYNC_ITERATOR)
+#if !(CSHARP8 && USE_ASYNC_ITERATOR && ASYNC_ITERATOR_CAN_RETURN_AETOR)
         private sealed class SelectManyAsyncIterator<TSource, TCollection, TResult> : AsyncIterator<TResult>
         {
             private const int State_Source = 1;

+ 63 - 0
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/Utilities.cs

@@ -0,0 +1,63 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Linq
+{
+    internal static class Utilities
+    {
+        public static async ValueTask AddRangeAsync<T>(this List<T> list, IAsyncEnumerable<T> collection, CancellationToken cancellationToken)
+        {
+            if (collection is IEnumerable<T> enumerable)
+            {
+                list.AddRange(enumerable);
+                return;
+            }
+
+            if (collection is IAsyncIListProvider<T> listProvider)
+            {
+                var count = await listProvider.GetCountAsync(onlyIfCheap: true, cancellationToken).ConfigureAwait(false);
+
+                if (count == 0)
+                {
+                    return;
+                }
+
+                if (count > 0)
+                {
+                    var newCount = list.Count + count;
+
+                    if (list.Capacity < newCount)
+                    {
+                        list.Capacity = newCount;
+                    }
+                }
+            }
+
+#if CSHARP8 && AETOR_HAS_CT // CS0656 Missing compiler required member 'System.Collections.Generic.IAsyncEnumerable`1.GetAsyncEnumerator'
+            await foreach (var item in collection.WithCancellation(cancellationToken).ConfigureAwait(false))
+            {
+                list.Add(item);
+            }
+#else
+            var e = collection.GetAsyncEnumerator(cancellationToken);
+
+            try
+            {
+                while (await e.MoveNextAsync().ConfigureAwait(false))
+                {
+                    list.Add(e.Current);
+                }
+            }
+            finally
+            {
+                await e.DisposeAsync().ConfigureAwait(false);
+            }
+#endif
+        }
+    }
+}