1
0
Oren Novotny 9 жил өмнө
parent
commit
04b3be8f96

+ 115 - 0
Ix.NET/Source/System.Interactive.Async/AsyncEnumerableHelpers.cs

@@ -0,0 +1,115 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Collections.Generic
+{
+    // Based on https://github.com/dotnet/corefx/blob/ec2685715b01d12f16b08d0dfa326649b12db8ec/src/Common/src/System/Collections/Generic/EnumerableHelpers.cs
+    internal static class AsyncEnumerableHelpers
+    {
+        internal static async Task<T[]> ToArray<T>(IAsyncEnumerable<T> source, CancellationToken cancellationToken)
+        {
+            var result = await ToArrayWithLength(source, cancellationToken)
+                             .ConfigureAwait(false);
+            Array.Resize(ref result.array, result.length);
+            return result.array;
+        }
+
+        internal static async Task<ArrayWithLength<T>> ToArrayWithLength<T>(IAsyncEnumerable<T> source, CancellationToken cancellationToken)
+        {
+            var result = new ArrayWithLength<T>();
+            // Check for short circuit optimizations. This one is very unlikely
+            // but could be here as a group
+            var ic = source as ICollection<T>;
+            if (ic != null)
+            {
+                var count = ic.Count;
+                if (count != 0)
+                {
+                    // Allocate an array of the desired size, then copy the elements into it. Note that this has the same 
+                    // issue regarding concurrency as other existing collections like List<T>. If the collection size 
+                    // concurrently changes between the array allocation and the CopyTo, we could end up either getting an 
+                    // exception from overrunning the array (if the size went up) or we could end up not filling as many 
+                    // items as 'count' suggests (if the size went down).  This is only an issue for concurrent collections 
+                    // that implement ICollection<T>, which as of .NET 4.6 is just ConcurrentDictionary<TKey, TValue>.
+                    result.array = new T[count];
+                    ic.CopyTo(result.array, 0);
+                    result.length = count;
+                    return result;
+                }
+            }
+            else
+            {
+                using (var en = source.GetEnumerator())
+                {
+                    if (await en.MoveNext(cancellationToken)
+                                .ConfigureAwait(false))
+                    {
+                        const int DefaultCapacity = 4;
+                        var arr = new T[DefaultCapacity];
+                        arr[0] = en.Current;
+                        var count = 1;
+
+                        while (await en.MoveNext(cancellationToken)
+                                       .ConfigureAwait(false))
+                        {
+                            if (count == arr.Length)
+                            {
+                                // MaxArrayLength is defined in Array.MaxArrayLength and in gchelpers in CoreCLR.
+                                // It represents the maximum number of elements that can be in an array where
+                                // the size of the element is greater than one byte; a separate, slightly larger constant,
+                                // is used when the size of the element is one.
+                                const int MaxArrayLength = 0x7FEFFFFF;
+
+                                // This is the same growth logic as in List<T>:
+                                // If the array is currently empty, we make it a default size.  Otherwise, we attempt to 
+                                // double the size of the array.  Doubling will overflow once the size of the array reaches
+                                // 2^30, since doubling to 2^31 is 1 larger than Int32.MaxValue.  In that case, we instead 
+                                // constrain the length to be MaxArrayLength (this overflow check works because of the 
+                                // cast to uint).  Because a slightly larger constant is used when T is one byte in size, we 
+                                // could then end up in a situation where arr.Length is MaxArrayLength or slightly larger, such 
+                                // that we constrain newLength to be MaxArrayLength but the needed number of elements is actually 
+                                // larger than that.  For that case, we then ensure that the newLength is large enough to hold 
+                                // the desired capacity.  This does mean that in the very rare case where we've grown to such a 
+                                // large size, each new element added after MaxArrayLength will end up doing a resize.
+                                var newLength = count << 1;
+                                if ((uint)newLength > MaxArrayLength)
+                                {
+                                    newLength = MaxArrayLength <= count ? count + 1 : MaxArrayLength;
+                                }
+
+                                Array.Resize(ref arr, newLength);
+                            }
+
+                            arr[count++] = en.Current;
+                        }
+
+                        result.length = count;
+                        result.array = arr;
+                        return result;
+                    }
+                }
+            }
+
+            result.length = 0;
+#if NO_ARRAY_EMPTY
+            result.array = EmptyArray<T>.Value;
+#else
+            result.array = Array.Empty<T>();
+#endif
+            return result;
+        }
+
+        internal struct ArrayWithLength<T>
+        {
+            public T[] array;
+            public int length;
+        }
+    }
+}

+ 241 - 84
Ix.NET/Source/System.Interactive.Async/Concatenate.cs

@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
@@ -19,47 +20,10 @@ namespace System.Linq
             if (second == null)
                 throw new ArgumentNullException(nameof(second));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var switched = false;
-                    var e = first.GetEnumerator();
-
-                    var cts = new CancellationTokenDisposable();
-                    var a = new AssignableDisposable
-                    {
-                        Disposable = e
-                    };
-                    var d = Disposable.Create(cts, a);
-
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (await e.MoveNext(ct)
-                                       .ConfigureAwait(false))
-                            {
-                                return true;
-                            }
-                            if (switched)
-                            {
-                                return false;
-                            }
-                            switched = true;
-
-                            e = second.GetEnumerator();
-                            a.Disposable = e;
-
-                            return await f(ct)
-                                       .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => e.Current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            var concatFirst = first as ConcatAsyncIterator<TSource>;
+            return concatFirst != null ?
+                       concatFirst.Concat(second) :
+                       new Concat2AsyncIterator<TSource>(first, second);
         }
 
         public static IAsyncEnumerable<TSource> Concat<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
@@ -80,53 +44,246 @@ namespace System.Linq
 
         private static IAsyncEnumerable<TSource> Concat_<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
         {
-            return CreateEnumerable(
-                () =>
+            using (var e = sources.GetEnumerator())
+            {
+                IAsyncEnumerable<TSource> prev = null;
+                while (e.MoveNext())
+                {
+                    if (prev == null)
+                    {
+                        prev = e.Current;
+                    }
+                    else
+                    {
+                        prev = prev.Concat(e.Current);
+                    }
+                }
+
+                return prev ?? Empty<TSource>();
+            }
+        }
+
+        private sealed class Concat2AsyncIterator<TSource> : ConcatAsyncIterator<TSource>
+        {
+            private readonly IAsyncEnumerable<TSource> first;
+            private readonly IAsyncEnumerable<TSource> second;
+
+            internal Concat2AsyncIterator(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
+            {
+                Debug.Assert(first != null && second != null);
+
+                this.first = first;
+                this.second = second;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new Concat2AsyncIterator<TSource>(first, second);
+            }
+
+            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
+            {
+                return new ConcatNAsyncIterator<TSource>(this, next, 2);
+            }
+
+            internal override IAsyncEnumerable<TSource> GetAsyncEnumerable(int index)
+            {
+                switch (index)
                 {
-                    var se = sources.GetEnumerator();
-                    var e = default(IAsyncEnumerator<TSource>);
+                    case 0:
+                        return first;
+                    case 1:
+                        return second;
+                    default:
+                        return null;
+                }
+            }
+        }
+
+        private abstract class ConcatAsyncIterator<TSource> : AsyncIterator<TSource>, IIListProvider<TSource>
+        {
+            private int counter;
+            private IAsyncEnumerator<TSource> enumerator;
 
-                    var cts = new CancellationTokenDisposable();
-                    var a = new AssignableDisposable();
-                    var d = Disposable.Create(cts, se, a);
+            public Task<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                return AsyncEnumerableHelpers.ToArray(this, cancellationToken);
+            }
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+            public async Task<List<TSource>> ToListAsync(CancellationToken cancellationToken)
+            {
+                var list = new List<TSource>();
+                for (var i = 0;; i++)
+                {
+                    var source = GetAsyncEnumerable(i);
+                    if (source == null)
+                    {
+                        break;
+                    }
+                    using (var e = source.GetEnumerator())
+                    {
+                        while (await e.MoveNext(cancellationToken)
+                                      .ConfigureAwait(false))
                         {
-                            if (e == null)
-                            {
-                                var b = false;
-                                b = se.MoveNext();
-                                if (b)
-                                    e = se.Current.GetEnumerator();
-
-                                if (!b)
-                                {
-                                    return false;
-                                }
-
-                                a.Disposable = e;
-                            }
-
-                            if (await e.MoveNext(ct)
-                                       .ConfigureAwait(false))
-                            {
-                                return true;
-                            }
-                            e.Dispose();
-                            e = null;
-
-                            return await f(ct)
-                                       .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => e.Current,
-                        d.Dispose,
-                        a
-                    );
-                });
+                            list.Add(e.Current);
+                        }
+                    }
+                }
+
+                return list;
+            }
+
+            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                if (onlyIfCheap)
+                {
+                    return -1;
+                }
+
+                var count = 0;
+                for (var i = 0;; i++)
+                {
+                    var source = GetAsyncEnumerable(i);
+                    if (source == null)
+                    {
+                        break;
+                    }
+
+                    checked
+                    {
+                        count += await source.Count(cancellationToken);
+                    }
+                }
+
+                return count;
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                if (state == State.Allocated)
+                {
+                    enumerator = GetAsyncEnumerable(0)
+                        .GetEnumerator();
+                    state = State.Iterating;
+                    counter = 2;
+                }
+
+                if (state == State.Iterating)
+                {
+                    while (true)
+                    {
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
+                        {
+                            current = enumerator.Current;
+                            return true;
+                        }
+                        // note, this is simply to match the logic of 
+                        // https://github.com/dotnet/corefx/blob/ec2685715b01d12f16b08d0dfa326649b12db8ec/src/system.linq/src/system/linq/concatenate.cs#L173-L173
+                        var next = GetAsyncEnumerable(counter++ - 1);
+                        if (next != null)
+                        {
+                            enumerator.Dispose();
+                            enumerator = next.GetEnumerator();
+                            continue;
+                        }
+
+                        Dispose();
+                        break;
+                    }
+                }
+
+                return false;
+            }
+
+            internal abstract ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next);
+
+            internal abstract IAsyncEnumerable<TSource> GetAsyncEnumerable(int index);
+        }
+
+        // To handle chains of >= 3 sources, we chain the concat iterators together and allow
+        // GetEnumerable to fetch enumerables from the previous sources.  This means that rather
+        // than each MoveNext/Current calls having to traverse all of the previous sources, we
+        // only have to traverse all of the previous sources once per chained enumerable.  An
+        // alternative would be to use an array to store all of the enumerables, but this has
+        // a much better memory profile and without much additional run-time cost.
+        private sealed class ConcatNAsyncIterator<TSource> : ConcatAsyncIterator<TSource>
+        {
+            private readonly IAsyncEnumerable<TSource> next;
+            private readonly int nextIndex;
+            private readonly ConcatAsyncIterator<TSource> previousConcat;
+
+            internal ConcatNAsyncIterator(ConcatAsyncIterator<TSource> previousConcat, IAsyncEnumerable<TSource> next, int nextIndex)
+            {
+                Debug.Assert(previousConcat != null);
+                Debug.Assert(next != null);
+                Debug.Assert(nextIndex >= 2);
+
+                this.previousConcat = previousConcat;
+                this.next = next;
+                this.nextIndex = nextIndex;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new ConcatNAsyncIterator<TSource>(previousConcat, next, nextIndex);
+            }
+
+            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
+            {
+                if (nextIndex == int.MaxValue - 2)
+                {
+                    // In the unlikely case of this many concatenations, if we produced a ConcatNIterator
+                    // with int.MaxValue then state would overflow before it matched it's index.
+                    // So we use the naïve approach of just having a left and right sequence.
+                    return new Concat2AsyncIterator<TSource>(this, next);
+                }
+
+                return new ConcatNAsyncIterator<TSource>(this, next, nextIndex + 1);
+            }
+
+            internal override IAsyncEnumerable<TSource> GetAsyncEnumerable(int index)
+            {
+                if (index > nextIndex)
+                {
+                    return null;
+                }
+
+                // Walk back through the chain of ConcatNIterators looking for the one
+                // that has its _nextIndex equal to index.  If we don't find one, then it
+                // must be prior to any of them, so call GetEnumerable on the previous
+                // Concat2Iterator.  This avoids a deep recursive call chain.
+                var current = this;
+                while (true)
+                {
+                    if (index == current.nextIndex)
+                    {
+                        return current.next;
+                    }
+
+                    var prevN = current.previousConcat as ConcatNAsyncIterator<TSource>;
+                    if (prevN != null)
+                    {
+                        current = prevN;
+                        continue;
+                    }
+
+                    Debug.Assert(current.previousConcat is Concat2AsyncIterator<TSource>);
+                    Debug.Assert(index == 0 || index == 1);
+                    return current.previousConcat.GetAsyncEnumerable(index);
+                }
+            }
         }
     }
 }

+ 1 - 6
Ix.NET/Source/System.Interactive.Async/ToCollection.cs

@@ -29,12 +29,7 @@ namespace System.Linq
             if (arrayProvider != null)
                 return arrayProvider.ToArrayAsync(cancellationToken);
 
-
-            return source.Aggregate(new List<TSource>(), (list, x) =>
-                                                         {
-                                                             list.Add(x);
-                                                             return list;
-                                                         }, list => list.ToArray(), cancellationToken);
+            return AsyncEnumerableHelpers.ToArray(source, cancellationToken);
         }
 
         public static Task<Dictionary<TKey, TElement>> ToDictionary<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)

+ 1 - 1
Ix.NET/Source/Tests/AsyncTests.Conversions.cs

@@ -324,7 +324,7 @@ namespace Tests
 
             var ae = AsyncEnumerable.CreateEnumerable(
                 () => AsyncEnumerable.CreateEnumerator<int>(
-                    async ct => false,
+                    ct => Task.FromResult(false),
                     () => { throw new InvalidOperationException(); },
                     () => { evt.Set(); }));
 

+ 3 - 9
Ix.NET/Source/Tests/AsyncTests.Multiple.cs

@@ -104,15 +104,9 @@ namespace Tests
         [Fact]
         public void Concat6()
         {
-            var res = AsyncEnumerable.Concat(ConcatXss());
-
-            var e = res.GetEnumerator();
-            HasNext(e, 1);
-            HasNext(e, 2);
-            HasNext(e, 3);
-            HasNext(e, 4);
-            HasNext(e, 5);
-            AssertThrows<Exception>(() => e.MoveNext().Wait(WaitTimeoutMs), ex_ => ((AggregateException)ex_).Flatten().InnerExceptions.Single().Message == "Bang!");
+            // Note: Concat does an eager traverse of the enumerables to build up
+            // its sequences. If the outer enumerable throws it'll trhow here.
+            AssertThrows<Exception>(() => AsyncEnumerable.Concat(ConcatXss()), ex_ => ex_.Message == "Bang!");
         }
 
         static IEnumerable<IAsyncEnumerable<int>> ConcatXss()