Browse Source

IxAsync.Timeout: propagate timeout cancellation to main src

akarnokd 6 years ago
parent
commit
dd6adb6fd3

+ 107 - 0
Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs

@@ -0,0 +1,107 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Tests
+{
+    public class Timeout : AsyncEnumerableExTests
+    {
+        [Fact]
+        public async Task Timeout_Never()
+        {
+            var source = AsyncEnumerableEx.Never<int>().Timeout(TimeSpan.FromMilliseconds(100));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Timeout_Delayed_Main()
+        {
+            var source = AsyncEnumerable.Range(1, 5)
+                .SelectAwait(async v =>
+                {
+                    await Task.Delay(300);
+                    return v;
+                })
+                .Timeout(TimeSpan.FromMilliseconds(100));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Timeout_Delayed_Main_Canceled()
+        {
+            var tcs = new TaskCompletionSource<int>();
+
+            var source = AsyncEnumerable.Range(1, 5)
+                .SelectAwaitWithCancellation(async (v, ct) =>
+                {
+                    try
+                    {
+                        await Task.Delay(500, ct);
+                    }
+                    catch (TaskCanceledException)
+                    {
+                        tcs.SetResult(0);
+                    }
+                    return v;
+                })
+                .Timeout(TimeSpan.FromMilliseconds(250));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+
+            Assert.Equal(0, await tcs.Task);
+        }
+    }
+}

+ 7 - 2
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs

@@ -32,6 +32,8 @@ namespace System.Linq
 
             private Task? _loserTask;
 
+            private CancellationTokenSource? _sourceCTS;
+
             public TimeoutAsyncIterator(IAsyncEnumerable<TSource> source, TimeSpan timeout)
             {
                 Debug.Assert(source != null);
@@ -67,7 +69,8 @@ namespace System.Linq
                 switch (_state)
                 {
                     case AsyncIteratorState.Allocated:
-                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
+                        _sourceCTS = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken);
+                        _enumerator = _source.GetAsyncEnumerator(_sourceCTS.Token);
 
                         _state = AsyncIteratorState.Iterating;
                         goto case AsyncIteratorState.Iterating;
@@ -77,7 +80,7 @@ namespace System.Linq
 
                         if (!moveNext.IsCompleted)
                         {
-                            using var delayCts = new CancellationTokenSource();
+                            using var delayCts = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken);
 
                             var delay = Task.Delay(_timeout, delayCts.Token);
 
@@ -101,6 +104,8 @@ namespace System.Linq
 
                                 _loserTask = next.ContinueWith((_, state) => ((IAsyncDisposable)state).DisposeAsync().AsTask(), _enumerator);
 
+                                _sourceCTS!.Cancel();
+
                                 throw new TimeoutException();
                             }