Ver Fonte

Cancel in-flight TaskCompletionSources in ToAsyncEnumerable(Observable).

Daniel C. Weber há 5 anos atrás
pai
commit
fa1c97f3e1

+ 29 - 0
Ix.NET/Source/System.Linq.Async.Tests/System/Linq/Operators/ToAsyncEnumerable.cs

@@ -393,6 +393,35 @@ namespace Tests
             stop.WaitOne();
         }
 
+        [Fact]
+        public async Task ToAsyncEnumerable_Observable_Cancel_InFlight()
+        {
+            var xs = new MyObservable<int>(obs =>
+            {
+                var cts = new CancellationTokenSource();
+
+                Task.Run(async () =>
+                {
+                    for (var i = 0; !cts.IsCancellationRequested; i++)
+                    {
+                        await Task.Delay(10);
+                        obs.OnNext(i);
+                    }
+                });
+
+                return new MyDisposable(cts.Cancel);
+            }).ToAsyncEnumerable();
+
+            using var c = new CancellationTokenSource();
+
+            await using var e = xs.GetAsyncEnumerator(c.Token);
+
+            var task = e.MoveNextAsync();
+            c.Cancel();
+
+            await AssertThrowsAsync<TaskCanceledException>(task.AsTask());
+        }
+
         [Fact]
         public async Task ToAsyncEnumerable_Observable6_Async()
         {

+ 26 - 1
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/ToAsyncEnumerable.Observable.cs

@@ -169,7 +169,32 @@ namespace System.Linq
 
             private void DisposeSubscription() => Interlocked.Exchange(ref _subscription, null)?.Dispose();
 
-            private void OnCanceled(object? state) => Dispose();
+            private void OnCanceled(object? state)
+            {
+                var cancelledTcs = default(TaskCompletionSource<bool>);
+
+                Dispose();
+
+                while (true)
+                {
+                    var signal = Volatile.Read(ref _signal);
+
+                    if (signal != null)
+                    {
+                        if (signal.TrySetCanceled(_cancellationToken))
+                            return;
+                    }
+
+                    if (cancelledTcs == null)
+                    {
+                        cancelledTcs = new TaskCompletionSource<bool>();
+                        cancelledTcs.TrySetCanceled(_cancellationToken);
+                    }
+
+                    if (Interlocked.CompareExchange(ref _signal, cancelledTcs, signal) == signal)
+                        return;
+                }
+            }
 
             private Task Resume()
             {