Browse Source

Address some bugs that were hanging tests

Oren Novotny 9 years ago
parent
commit
abfb2c93da

+ 11 - 3
Ix.NET/Source/System.Interactive.Async/AsyncIterator.cs

@@ -23,6 +23,7 @@ namespace System.Linq
             internal State state = State.New;
             internal TSource current;
             private CancellationTokenSource cancellationTokenSource;
+            private List<CancellationTokenRegistration> moveNextRegistrations;
 
             protected AsyncIterator()
             {
@@ -37,6 +38,7 @@ namespace System.Linq
 
                 enumerator.state = State.Allocated;
                 enumerator.cancellationTokenSource = new CancellationTokenSource();
+                enumerator.moveNextRegistrations = new List<CancellationTokenRegistration>();
                 return enumerator;
             }
 
@@ -48,6 +50,11 @@ namespace System.Linq
                     cancellationTokenSource.Cancel();
                 }
                 cancellationTokenSource.Dispose();
+                foreach (var r in moveNextRegistrations)
+                {
+                    r.Dispose();
+                }
+                moveNextRegistrations.Clear();
                 current = default(TSource);
                 state = State.Disposed;
             }
@@ -59,14 +66,15 @@ namespace System.Linq
                 if (state == State.Disposed)
                     return false;
 
-                using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, cancellationTokenSource.Token))
-                using (cancellationToken.Register(Dispose))
+                // We keep these because cancelling any of these must trigger dispose of the iterator
+                moveNextRegistrations.Add(cancellationToken.Register(Dispose));
 
+                using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, cancellationTokenSource.Token))
                 {
                     try
                     {
                         var result = await MoveNextCore(cts.Token).ConfigureAwait(false);
-
+                        
                         return result;
                     }
                     catch

+ 2 - 0
Ix.NET/Source/System.Interactive.Async/Create.cs

@@ -76,6 +76,8 @@ namespace System.Linq
             return self;
         }
 
+        
+
         private class AnonymousAsyncEnumerable<T> : IAsyncEnumerable<T>
         {
             private readonly Func<IAsyncEnumerator<T>> getEnumerator;

+ 1 - 34
Ix.NET/Source/System.Interactive.Async/GroupJoin.cs

@@ -31,7 +31,6 @@ namespace System.Linq
             return new GroupJoinAsyncEnumerable<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
         }
 
-
         public static IAsyncEnumerable<TResult> GroupJoin<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector)
         {
             if (outer == null)
@@ -48,39 +47,7 @@ namespace System.Linq
             return outer.GroupJoin(inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
         }
 
-        internal sealed class AsyncEnumerableAdapter<T> : IAsyncEnumerable<T>
-        {
-            private readonly IEnumerable<T> _source;
-
-            public AsyncEnumerableAdapter(IEnumerable<T> source)
-            {
-                _source = source;
-            }
-
-            public IAsyncEnumerator<T> GetEnumerator()
-                => new AsyncEnumeratorAdapter(_source.GetEnumerator());
-
-            private sealed class AsyncEnumeratorAdapter : IAsyncEnumerator<T>
-            {
-                private readonly IEnumerator<T> _enumerator;
-
-                public AsyncEnumeratorAdapter(IEnumerator<T> enumerator)
-                {
-                    _enumerator = enumerator;
-                }
-
-                public Task<bool> MoveNext(CancellationToken cancellationToken)
-                {
-                    cancellationToken.ThrowIfCancellationRequested();
-
-                    return Task.FromResult(_enumerator.MoveNext());
-                }
-
-                public T Current => _enumerator.Current;
-
-                public void Dispose() => _enumerator.Dispose();
-            }
-        }
+       
 
 
         private sealed class GroupJoinAsyncEnumerable<TOuter, TInner, TKey, TResult> : IAsyncEnumerable<TResult>

+ 0 - 1
Ix.NET/Source/System.Interactive.Async/Grouping.cs

@@ -279,7 +279,6 @@ namespace System.Linq
             public IAsyncEnumerator<IAsyncGrouping<TKey, TSource>> GetEnumerator()
             {
                 Internal.Lookup<TKey, TSource> lookup = null;
-                IAsyncGrouping<TKey, TSource> current = null;
                 IEnumerator<IGrouping<TKey, TSource>> enumerator = null;
 
                 return CreateEnumerator(

+ 3 - 0
Ix.NET/Source/System.Interactive.Async/IIListProvider.cs

@@ -14,12 +14,14 @@ namespace System.Linq
         /// <summary>
         /// Produce an array of the sequence through an optimized path.
         /// </summary>
+        /// <param name="cancellationToken"></param>
         /// <returns>The array.</returns>
         Task<TElement[]> ToArrayAsync(CancellationToken cancellationToken);
 
         /// <summary>
         /// Produce a <see cref="List{TElement}"/> of the sequence through an optimized path.
         /// </summary>
+        /// <param name="cancellationToken"></param>
         /// <returns>The <see cref="List{TElement}"/>.</returns>
         Task<List<TElement>> ToListAsync(CancellationToken cancellationToken);
 
@@ -28,6 +30,7 @@ namespace System.Linq
         /// </summary>
         /// <param name="onlyIfCheap">If true then the count should only be calculated if doing
         /// so is quick (sure or likely to be constant time), otherwise -1 should be returned.</param>
+        /// <param name="cancellationToken"></param>
         /// <returns>The number of elements.</returns>
         Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken);
     }

+ 4 - 1
Ix.NET/Source/System.Interactive.Async/Take.cs

@@ -27,6 +27,7 @@ namespace System.Linq
 
                     var cts = new CancellationTokenDisposable();
                     var d = Disposable.Create(cts, e);
+                    var current = default(TSource);
 
                     return CreateEnumerator(
                         async ct =>
@@ -38,13 +39,15 @@ namespace System.Linq
                                                 .ConfigureAwait(false);
 
                             --n;
+                            if (result)
+                                current = e.Current;
 
                             if (n == 0)
                                 e.Dispose();
 
                             return result;
                         },
-                        () => e.Current,
+                        () => current,
                         d.Dispose,
                         e
                     );

+ 70 - 24
Ix.NET/Source/System.Interactive.Async/ToAsyncEnumerable.cs

@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
@@ -17,30 +18,7 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = source.GetEnumerator();
-
-                    return CreateEnumerator(
-                        ct => Task.Run(() =>
-                                       {
-                                           var res = false;
-                                           try
-                                           {
-                                               res = e.MoveNext();
-                                           }
-                                           finally
-                                           {
-                                               if (!res)
-                                                   e.Dispose();
-                                           }
-                                           return res;
-                                       }, ct),
-                        () => e.Current,
-                        () => e.Dispose()
-                    );
-                });
+            return new AsyncEnumerableAdapter<TSource>(source);
         }
 
         public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this Task<TSource> task)
@@ -91,5 +69,73 @@ namespace System.Linq
                 }
             }
         }
+
+        internal sealed class AsyncEnumerableAdapter<T> : AsyncIterator<T>, IIListProvider<T>
+        {
+            private readonly IEnumerable<T> source;
+            private IEnumerator<T> enumerator;
+ 
+            public AsyncEnumerableAdapter(IEnumerable<T> source)
+            {
+                Debug.Assert(source != null);
+                this.source = source;
+            }
+
+            public override AsyncIterator<T> Clone()
+            {
+                return new AsyncEnumerableAdapter<T>(source);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case State.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = State.Iterating;
+                        goto case State.Iterating;
+
+                    case State.Iterating:
+                        if (enumerator.MoveNext())
+                        {
+                            current = enumerator.Current;
+                            return Task.FromResult(true);
+                        }
+
+                        Dispose();
+                        break;
+                }
+                
+                return Task.FromResult(false);
+            }
+
+            // These optimizations rely on the Sys.Linq impls from IEnumerable to optimize
+            // and short circuit as appropriate
+            public Task<T[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.ToArray());
+            }
+
+            public Task<List<T>> ToListAsync(CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.ToList());
+            }
+
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                return Task.FromResult(source.Count());
+            }
+        }
     }
 }

+ 1 - 1
Ix.NET/Source/System.Interactive.Async/Using.cs

@@ -64,7 +64,7 @@ namespace System.Linq
                         },
                         () => current,
                         d.Dispose,
-                        null
+                        d
                     );
                 });
         }

+ 22 - 8
Ix.NET/Source/Tests/AsyncTests.Bugs.cs

@@ -146,7 +146,7 @@ namespace Tests
         }
 
         [Fact]
-        public void CorrectCancel()
+        public async Task CorrectCancel()
         {
             var disposed = new TaskCompletionSource<bool>();
 
@@ -179,10 +179,12 @@ namespace Tests
                 // it. This design is chosen because cancelling a MoveNext call leaves
                 // the enumerator in an indeterminate state. Further interactions with
                 // it should be forbidden.
-                Assert.True(disposed.Task.Result);
+
+                var result = await disposed.Task;
+                Assert.True(result);
             }
 
-            Assert.False(e.MoveNext().Result);
+            Assert.False(await e.MoveNext());
         }
 
         [Fact]
@@ -240,21 +242,33 @@ namespace Tests
 
             var e = xs.GetEnumerator();
             var cts = new CancellationTokenSource();
-            var t = e.MoveNext(cts.Token);
+
+
+            Task<bool> t = null;
+            var tMoveNext =Task.Run(
+                () =>
+                {
+                    // This call *will* block
+                    t = e.MoveNext(cts.Token);
+                });
+         
 
             isRunningEvent.WaitOne();
             cts.Cancel();
 
             try
             {
-                t.Wait(0);
+                tMoveNext.Wait(0);
                 Assert.False(t.IsCanceled);
             }
             catch
             {
-                Assert.False(true);
+                // T will still be null
+                Assert.Null(t);
             }
 
+
+            // enable it to finish
             evt.Set();
         }
 
@@ -266,7 +280,7 @@ namespace Tests
         }
 
         [Fact]
-        public void TakeOneFromSelectMany()
+        public async Task TakeOneFromSelectMany()
         {
             var enumerable = AsyncEnumerable
                 .Return(0)
@@ -274,7 +288,7 @@ namespace Tests
                 .Take(1)
                 .Do(_ => { });
 
-            Assert.Equal("Check", enumerable.First().Result);
+            Assert.Equal("Check", await enumerable.First());
         }
 
         [Fact]

+ 3 - 2
Ix.NET/Source/Tests/AsyncTests.Creation.cs

@@ -359,7 +359,7 @@ namespace Tests
         }
 
         [Fact]
-        public void Using6()
+        public async Task Using6()
         {
             var i = 0;
             var disposed = new TaskCompletionSource<bool>();
@@ -394,7 +394,8 @@ namespace Tests
                 ex.Flatten().Handle(inner => inner is TaskCanceledException);
             }
 
-            Assert.True(disposed.Task.Result);
+            Assert.True(disposed.Task.IsCompleted);
+            Assert.True(await disposed.Task);
         }
 
         class MyD : IDisposable