Browse Source

Merge pull request #766 from danielcweber/AvoidLeakingTaskContinuations

Avoid leaking task continuations
Daniel C. Weber 7 years ago
parent
commit
0bf8cffdae

+ 4 - 7
Rx.NET/Source/src/System.Reactive/Concurrency/Scheduler.Async.cs

@@ -23,17 +23,14 @@ namespace System.Reactive.Concurrency
                 action(new CancelableScheduler(self, _cts.Token), s, _cts.Token).ContinueWith(
                     (t, thisObject) =>
                     {
-                        if (!t.IsCanceled)
-                        {
-                            var @this = (AsyncInvocation<TState>)thisObject;
+                        var @this = (AsyncInvocation<TState>)thisObject;
 
-                            t.Exception?.Handle(e => e is OperationCanceledException);
+                        t.Exception?.Handle(e => e is OperationCanceledException);
 
-                            Disposable.SetSingle(ref @this._run, t.Result);
-                        }
+                        Disposable.SetSingle(ref @this._run, t.Result);
                     },
                     this,
-                    TaskContinuationOptions.ExecuteSynchronously);
+                    TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled);
 
                 return this;
             }

+ 10 - 1
Rx.NET/Source/src/System.Reactive/Linq/Observable/Merge.cs

@@ -288,6 +288,7 @@ namespace System.Reactive.Linq.ObservableImpl
                 }
 
                 private readonly object _gate = new object();
+                private readonly CancellationTokenSource _cts = new CancellationTokenSource();
                 private volatile int _count = 1;
 
                 public override void OnNext(Task<TSource> value)
@@ -299,7 +300,7 @@ namespace System.Reactive.Linq.ObservableImpl
                     }
                     else
                     {
-                        value.ContinueWith((t, thisObject) => ((_)thisObject).OnCompletedTask(t), this);
+                        value.ContinueWith((t, thisObject) => ((_)thisObject).OnCompletedTask(t), this, _cts.Token);
                     }
                 }
 
@@ -354,6 +355,14 @@ namespace System.Reactive.Linq.ObservableImpl
                         }
                     }
                 }
+
+                protected override void Dispose(bool disposing)
+                {
+                    if (disposing)
+                        _cts.Cancel();
+
+                    base.Dispose(disposing);
+                }
             }
         }
     }

+ 4 - 4
Rx.NET/Source/src/System.Reactive/Linq/Observable/SelectMany.cs

@@ -597,7 +597,7 @@ namespace System.Reactive.Linq.ObservableImpl
                     //
                     // Separate method to avoid closure in synchronous completion case.
                     //
-                    task.ContinueWith(t => OnCompletedTask(value, t));
+                    task.ContinueWith(t => OnCompletedTask(value, t), _cancel.Token);
                 }
 
                 private void OnCompletedTask(TSource value, Task<TCollection> task)
@@ -758,7 +758,7 @@ namespace System.Reactive.Linq.ObservableImpl
                     //
                     // Separate method to avoid closure in synchronous completion case.
                     //
-                    task.ContinueWith(t => OnCompletedTask(value, index, t));
+                    task.ContinueWith(t => OnCompletedTask(value, index, t), _cancel.Token);
                 }
 
                 private void OnCompletedTask(TSource value, int index, Task<TCollection> task)
@@ -1538,7 +1538,7 @@ namespace System.Reactive.Linq.ObservableImpl
                     }
                     else
                     {
-                        task.ContinueWith((closureTask, thisObject) => ((_)thisObject).OnCompletedTask(closureTask), this);
+                        task.ContinueWith((closureTask, thisObject) => ((_)thisObject).OnCompletedTask(closureTask), this, _cts.Token);
                     }
                 }
 
@@ -1670,7 +1670,7 @@ namespace System.Reactive.Linq.ObservableImpl
                     }
                     else
                     {
-                        task.ContinueWith((closureTask, thisObject) => ((_)thisObject).OnCompletedTask(closureTask), this);
+                        task.ContinueWith((closureTask, thisObject) => ((_)thisObject).OnCompletedTask(closureTask), this, _cts.Token);
                     }
                 }
 

+ 61 - 22
Rx.NET/Source/src/System.Reactive/Threading/Tasks/TaskObservableExtensions.cs

@@ -17,6 +17,65 @@ namespace System.Reactive.Threading.Tasks
     /// </summary>
     public static class TaskObservableExtensions
     {
+        private sealed class SlowTaskObservable : IObservable<Unit>
+        {
+            private readonly Task _task;
+            private readonly IScheduler _scheduler;
+
+            public SlowTaskObservable(Task task, IScheduler scheduler)
+            {
+                _task = task;
+                _scheduler = scheduler;
+            }
+
+            public IDisposable Subscribe(IObserver<Unit> observer)
+            {
+                if (observer == null)
+                {
+                    throw new ArgumentNullException(nameof(observer));
+                }
+
+                var cts = new CancellationDisposable();
+                var options = GetTaskContinuationOptions(_scheduler);
+
+                if (_scheduler == null)
+                    _task.ContinueWith((t, subjectObject) => t.EmitTaskResult((IObserver<Unit>)subjectObject), observer, cts.Token, options, TaskScheduler.Current);
+                else
+                    _task.ContinueWith((t, subjectObject) => _scheduler.ScheduleAction((t, subjectObject), tuple => tuple.t.EmitTaskResult((IObserver<Unit>)tuple.subjectObject)), observer, cts.Token, options, TaskScheduler.Current);
+
+                return cts;
+            }
+        }
+
+        private sealed class SlowTaskObservable<TResult> : IObservable<TResult>
+        {
+            private readonly Task<TResult> _task;
+            private readonly IScheduler _scheduler;
+
+            public SlowTaskObservable(Task<TResult> task, IScheduler scheduler)
+            {
+                _task = task;
+                _scheduler = scheduler;
+            }
+
+            public IDisposable Subscribe(IObserver<TResult> observer)
+            {
+                if (observer == null)
+                {
+                    throw new ArgumentNullException(nameof(observer));
+                }
+
+                var cts = new CancellationDisposable();
+                var options = GetTaskContinuationOptions(_scheduler);
+
+                if (_scheduler == null)
+                    _task.ContinueWith((t, subjectObject) => t.EmitTaskResult((IObserver<TResult>)subjectObject), observer, cts.Token, options, TaskScheduler.Current);
+                else
+                    _task.ContinueWith((t, subjectObject) => _scheduler.ScheduleAction((t, subjectObject), tuple => tuple.t.EmitTaskResult((IObserver<TResult>)tuple.subjectObject)), observer, cts.Token, options, TaskScheduler.Current);
+
+                return cts;
+            }
+        }
         /// <summary>
         /// Returns an observable sequence that signals when the task completes.
         /// </summary>
@@ -74,12 +133,7 @@ namespace System.Reactive.Threading.Tasks
                 return new Return<Unit>(Unit.Default, scheduler);
             }
 
-            var subject = new AsyncSubject<Unit>();
-            var options = GetTaskContinuationOptions(scheduler);
-
-            task.ContinueWith((t, subjectObject) => t.EmitTaskResult((AsyncSubject<Unit>)subjectObject), subject, options);
-
-            return subject.ToObservableResult(scheduler);
+            return new SlowTaskObservable(task, scheduler);
         }
 
         private static void EmitTaskResult(this Task task, IObserver<Unit> subject)
@@ -178,12 +232,7 @@ namespace System.Reactive.Threading.Tasks
                 return new Return<TResult>(task.Result, scheduler);
             }
 
-            var subject = new AsyncSubject<TResult>();
-            var options = GetTaskContinuationOptions(scheduler);
-
-            task.ContinueWith((t, subjectObject) => t.EmitTaskResult((AsyncSubject<TResult>)subjectObject), subject, options);
-
-            return subject.ToObservableResult(scheduler);
+            return new SlowTaskObservable<TResult>(task, scheduler);
         }
 
         private static void EmitTaskResult<TResult>(this Task<TResult> task, IObserver<TResult> subject)
@@ -225,16 +274,6 @@ namespace System.Reactive.Threading.Tasks
             return options;
         }
 
-        private static IObservable<TResult> ToObservableResult<TResult>(this AsyncSubject<TResult> subject, IScheduler scheduler)
-        {
-            if (scheduler != null)
-            {
-                return subject.ObserveOn(scheduler);
-            }
-
-            return subject.AsObservable();
-        }
-
         internal static IDisposable Subscribe<TResult>(this Task<TResult> task, IObserver<TResult> observer)
         {
             if (task.IsCompleted)