Browse Source

Ensure cancellation tests have something in-flight to cancel

Oren Novotny 9 years ago
parent
commit
fcb7426fe9
2 changed files with 87 additions and 12 deletions
  1. 86 11
      Ix.NET/Source/Tests/AsyncTests.Bugs.cs
  2. 1 1
      Ix.NET/Source/Tests/AsyncTests.Creation.cs

+ 86 - 11
Ix.NET/Source/Tests/AsyncTests.Bugs.cs

@@ -142,7 +142,8 @@ namespace Tests
             var e = ys.GetEnumerator();
             var e = ys.GetEnumerator();
             await Assert.ThrowsAsync<Exception>(() => e.MoveNext());
             await Assert.ThrowsAsync<Exception>(() => e.MoveNext());
 
 
-            Assert.True(disposed.Task.Result);
+            var result = await disposed.Task;
+            Assert.True(result);
         }
         }
 
 
         [Fact]
         [Fact]
@@ -150,10 +151,10 @@ namespace Tests
         {
         {
             var disposed = new TaskCompletionSource<bool>();
             var disposed = new TaskCompletionSource<bool>();
 
 
-            var xs = new[] { 1, 2, 3 }.WithDispose(() =>
+            var xs = new CancellationTestAsyncEnumerable().WithDispose(() =>
             {
             {
-                disposed.SetResult(true);
-            }).ToAsyncEnumerable();
+                disposed.TrySetResult(true);
+            });
 
 
             var ys = xs.Select(x => x + 1).Where(x => true);
             var ys = xs.Select(x => x + 1).Where(x => true);
 
 
@@ -190,7 +191,7 @@ namespace Tests
         [Fact]
         [Fact]
         public void CanCancelMoveNext()
         public void CanCancelMoveNext()
         {
         {
-            var xs = new CancellationTestEnumerable().Select(x => x).Where(x => true);
+            var xs = new CancellationTestAsyncEnumerable().Select(x => x).Where(x => true);
 
 
             var e = xs.GetEnumerator();
             var e = xs.GetEnumerator();
             var cts = new CancellationTokenSource();
             var cts = new CancellationTokenSource();
@@ -212,27 +213,92 @@ namespace Tests
         /// <summary>
         /// <summary>
         /// Waits WaitTimeoutMs or until cancellation is requested. If cancellation was not requested, MoveNext returns true.
         /// Waits WaitTimeoutMs or until cancellation is requested. If cancellation was not requested, MoveNext returns true.
         /// </summary>
         /// </summary>
-        private sealed class CancellationTestEnumerable : IAsyncEnumerable<object>
+        private sealed class CancellationTestAsyncEnumerable : IAsyncEnumerable<int>
         {
         {
-            public IAsyncEnumerator<object> GetEnumerator() => new TestEnumerator();
+            private readonly int iterationsBeforeDelay;
 
 
-            private sealed class TestEnumerator : IAsyncEnumerator<object>
+            public CancellationTestAsyncEnumerable(int iterationsBeforeDelay = 0)
             {
             {
+                this.iterationsBeforeDelay = iterationsBeforeDelay;
+            }
+            public IAsyncEnumerator<int> GetEnumerator() => new TestEnumerator(iterationsBeforeDelay);
+
+            private sealed class TestEnumerator : IAsyncEnumerator<int>
+            {
+                private readonly int iterationsBeforeDelay;
+
+                public TestEnumerator(int iterationsBeforeDelay)
+                {
+                    this.iterationsBeforeDelay = iterationsBeforeDelay;
+                }
+                int i = -1;
                 public void Dispose()
                 public void Dispose()
                 {
                 {
                 }
                 }
-                
-                public object Current { get; }
+
+                public int Current => i;
                 
                 
                 public async Task<bool> MoveNext(CancellationToken cancellationToken)
                 public async Task<bool> MoveNext(CancellationToken cancellationToken)
                 {
                 {
-                    await Task.Delay(WaitTimeoutMs, cancellationToken);
+                    i++;
+                    if (Current >= iterationsBeforeDelay)
+                    {
+                        await Task.Delay(WaitTimeoutMs, cancellationToken);
+                    }
                     cancellationToken.ThrowIfCancellationRequested();
                     cancellationToken.ThrowIfCancellationRequested();
                     return true;
                     return true;
                 }
                 }
             }
             }
         }
         }
 
 
+        /// <summary>
+        /// Waits WaitTimeoutMs or until cancellation is requested. If cancellation was not requested, MoveNext returns true.
+        /// </summary>
+        private sealed class CancellationTestEnumerable<T> : IEnumerable<T>
+        {
+            private readonly CancellationToken cancellationToken;
+
+            public CancellationTestEnumerable()
+            {
+            }
+            public IEnumerator<T> GetEnumerator() => new TestEnumerator();
+
+            private sealed class TestEnumerator : IEnumerator<T>
+            {
+                private readonly CancellationTokenSource cancellationTokenSource;
+
+                public TestEnumerator()
+                {
+                    cancellationTokenSource = new CancellationTokenSource();
+                }
+                public void Dispose()
+                {
+                    cancellationTokenSource.Cancel();
+                }
+
+                public void Reset()
+                {
+                  
+                }
+
+                object IEnumerator.Current => Current;
+
+                public T Current { get; }
+
+                public bool MoveNext()
+                {
+                    Task.Delay(WaitTimeoutMs, cancellationTokenSource.Token).Wait();
+                    cancellationTokenSource.Token.ThrowIfCancellationRequested();
+                    return true;
+                }
+            }
+
+            IEnumerator IEnumerable.GetEnumerator()
+            {
+                return GetEnumerator();
+            }
+        }
+
         [Fact]
         [Fact]
         public void ToAsyncEnumeratorCannotCancelOnceRunning()
         public void ToAsyncEnumeratorCannotCancelOnceRunning()
         {
         {
@@ -357,6 +423,15 @@ namespace Tests
             });
             });
         }
         }
 
 
+        public static IAsyncEnumerable<T> WithDispose<T>(this IAsyncEnumerable<T> source, Action a)
+        {
+            return AsyncEnumerable.CreateEnumerable<T>(() =>
+            {
+                var e = source.GetEnumerator();
+                return AsyncEnumerable.CreateEnumerator<T>(e.MoveNext, () => e.Current, () => { e.Dispose(); a(); });
+            });
+        }
+
         class Enumerator<T> : IEnumerator<T>
         class Enumerator<T> : IEnumerator<T>
         {
         {
             private readonly Func<bool> _moveNext;
             private readonly Func<bool> _moveNext;

+ 1 - 1
Ix.NET/Source/Tests/AsyncTests.Creation.cs

@@ -370,7 +370,7 @@ namespace Tests
                     i++;
                     i++;
                     return new MyD(() => { disposed.TrySetResult(true); });
                     return new MyD(() => { disposed.TrySetResult(true); });
                 },
                 },
-                _ => AsyncEnumerable.Range(0, 10)
+                _ => new CancellationTestAsyncEnumerable(2) // need to use this to verify we actually cancel
             );
             );
 
 
             Assert.Equal(0, i);
             Assert.Equal(0, i);