Преглед изворни кода

Merge branch 'master' into AsyncIxSomeTests

David Karnok пре 6 година
родитељ
комит
c7ec529ec5

+ 132 - 0
Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs

@@ -0,0 +1,132 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Tests
+{
+    public class Timeout : AsyncEnumerableExTests
+    {
+        [Fact]
+        public async Task Timeout_Never()
+        {
+            var source = AsyncEnumerableEx.Never<int>().Timeout(TimeSpan.FromMilliseconds(100));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Timeout_Double_Never()
+        {
+            var source = AsyncEnumerableEx.Never<int>()
+                .Timeout(TimeSpan.FromMilliseconds(300))
+                .Timeout(TimeSpan.FromMilliseconds(100));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Timeout_Delayed_Main()
+        {
+            var source = AsyncEnumerable.Range(1, 5)
+                .SelectAwait(async v =>
+                {
+                    await Task.Delay(300);
+                    return v;
+                })
+                .Timeout(TimeSpan.FromMilliseconds(100));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Timeout_Delayed_Main_Canceled()
+        {
+            var tcs = new TaskCompletionSource<int>();
+
+            var source = AsyncEnumerable.Range(1, 5)
+                .SelectAwaitWithCancellation(async (v, ct) =>
+                {
+                    try
+                    {
+                        await Task.Delay(500, ct);
+                    }
+                    catch (TaskCanceledException)
+                    {
+                        tcs.SetResult(0);
+                    }
+                    return v;
+                })
+                .Timeout(TimeSpan.FromMilliseconds(250));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+
+            Assert.Equal(0, await tcs.Task);
+        }
+    }
+}

+ 12 - 2
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs

@@ -31,6 +31,8 @@ namespace System.Linq
 
             private Task? _loserTask;
 
+            private CancellationTokenSource? _sourceCTS;
+
             public TimeoutAsyncIterator(IAsyncEnumerable<TSource> source, TimeSpan timeout)
             {
                 _source = source;
@@ -55,6 +57,11 @@ namespace System.Linq
                     await _enumerator.DisposeAsync().ConfigureAwait(false);
                     _enumerator = null;
                 }
+                if (_sourceCTS != null)
+                {
+                    _sourceCTS.Dispose();
+                    _sourceCTS = null;
+                }
 
                 await base.DisposeAsync().ConfigureAwait(false);
             }
@@ -64,7 +71,8 @@ namespace System.Linq
                 switch (_state)
                 {
                     case AsyncIteratorState.Allocated:
-                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
+                        _sourceCTS = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken);
+                        _enumerator = _source.GetAsyncEnumerator(_sourceCTS.Token);
 
                         _state = AsyncIteratorState.Iterating;
                         goto case AsyncIteratorState.Iterating;
@@ -74,7 +82,7 @@ namespace System.Linq
 
                         if (!moveNext.IsCompleted)
                         {
-                            using var delayCts = new CancellationTokenSource();
+                            using var delayCts = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken);
 
                             var delay = Task.Delay(_timeout, delayCts.Token);
 
@@ -98,6 +106,8 @@ namespace System.Linq
 
                                 _loserTask = next.ContinueWith((_, state) => ((IAsyncDisposable)state!).DisposeAsync().AsTask(), _enumerator);
 
+                                _sourceCTS!.Cancel();
+
                                 throw new TimeoutException();
                             }