浏览代码

4.x: Make TailRecursiveSink lock-free and have less allocations (#499)

David Karnok 7 年之前
父节点
当前提交
10a44ad8fc

+ 1 - 1
Rx.NET/Source/src/System.Reactive/Internal/ConcatSink.cs

@@ -15,6 +15,6 @@ namespace System.Reactive
 
         protected override IEnumerable<IObservable<TSource>> Extract(IObservable<TSource> source) => (source as IConcatenatable<TSource>)?.GetSources();
 
-        public override void OnCompleted() => _recurse();
+        public override void OnCompleted() => Recurse();
     }
 }

+ 130 - 123
Rx.NET/Source/src/System.Reactive/Internal/TailRecursiveSink.cs

@@ -5,6 +5,7 @@
 using System.Collections.Generic;
 using System.Reactive.Concurrency;
 using System.Reactive.Disposables;
+using System.Threading;
 
 namespace System.Reactive
 {
@@ -15,164 +16,176 @@ namespace System.Reactive
         {
         }
 
-        private bool _isDisposed;
-        private SerialDisposable _subscription;
-        private AsyncLock _gate;
-        private Stack<IEnumerator<IObservable<TSource>>> _stack;
-        private Stack<int?> _length;
-        protected Action _recurse;
+        bool _isDisposed;
+
+        int trampoline;
+
+        IDisposable currentSubscription;
+
+        Stack<IEnumerator<IObservable<TSource>>> stack;
 
         public IDisposable Run(IEnumerable<IObservable<TSource>> sources)
         {
-            _isDisposed = false;
-            _subscription = new SerialDisposable();
-            _gate = new AsyncLock();
-            _stack = new Stack<IEnumerator<IObservable<TSource>>>();
-            _length = new Stack<int?>();
-
-            if (!TryGetEnumerator(sources, out var e))
+            if (!TryGetEnumerator(sources, out var current))
                 return Disposable.Empty;
 
-            _stack.Push(e);
-            _length.Push(Helpers.GetLength(sources));
+            stack = new Stack<IEnumerator<IObservable<TSource>>>();
+            stack.Push(current);
 
-            var cancelable = SchedulerDefaults.TailRecursion.Schedule(self =>
-            {
-                _recurse = self;
-                _gate.Wait(MoveNext);
-            });
+            Drain();
 
-            return StableCompositeDisposable.Create(_subscription, cancelable, Disposable.Create(() => _gate.Wait(Dispose)));
+            return new RecursiveSinkDisposable(this);
         }
 
-        protected abstract IEnumerable<IObservable<TSource>> Extract(IObservable<TSource> source);
-
-        private void MoveNext()
+        sealed class RecursiveSinkDisposable : IDisposable
         {
-            var hasNext = false;
-            var next = default(IObservable<TSource>);
+            readonly TailRecursiveSink<TSource> parent;
 
-            do
+            public RecursiveSinkDisposable(TailRecursiveSink<TSource> parent)
             {
-                if (_stack.Count == 0)
-                    break;
+                this.parent = parent;
+            }
 
-                if (_isDisposed)
-                    return;
+            public void Dispose()
+            {
+                parent.DisposeAll();
+            }
+        }
 
-                var e = _stack.Peek();
-                var l = _length.Peek();
+        void Drain()
+        {
+            if (Interlocked.Increment(ref trampoline) != 1)
+            {
+                return;
+            }
 
-                var current = default(IObservable<TSource>);
-                try
+            for (; ; )
+            {
+                if (Volatile.Read(ref _isDisposed))
                 {
-                    hasNext = e.MoveNext();
-                    if (hasNext)
+                    while (stack.Count != 0)
                     {
-                        current = e.Current;
+                        var enumerator = stack.Pop();
+                        enumerator.Dispose();
+                    }
+                    if (Volatile.Read(ref currentSubscription) != BooleanDisposable.True)
+                    {
+                        Interlocked.Exchange(ref currentSubscription, BooleanDisposable.True)?.Dispose();
                     }
-                }
-                catch (Exception ex)
-                {
-                    e.Dispose();
-
-                    //
-                    // Failure to enumerate the sequence cannot be handled, even by
-                    // operators like Catch, because it'd lead to another attempt at
-                    // enumerating to find the next observable sequence. Therefore,
-                    // we feed those errors directly to the observer.
-                    //
-                    _observer.OnError(ex);
-                    base.Dispose();
-                    return;
-                }
-
-                if (!hasNext)
-                {
-                    e.Dispose();
-
-                    _stack.Pop();
-                    _length.Pop();
                 }
                 else
                 {
-                    var r = l - 1;
-                    _length.Pop();
-                    _length.Push(r);
-
-                    try
+                    if (stack.Count != 0)
                     {
-                        next = Helpers.Unpack(current);
-                    }
-                    catch (Exception exception)
-                    {
-                        //
-                        // Errors from unpacking may produce side-effects that normally
-                        // would occur during a SubscribeSafe operation. Those would feed
-                        // back into the observer and be subject to the operator's error
-                        // handling behavior. For example, Catch would allow to handle
-                        // the error using a handler function.
-                        //
-                        if (!Fail(exception))
+                        var currentEnumerator = stack.Peek();
+
+                        var currentObservable = default(IObservable<TSource>);
+                        var next = default(IObservable<TSource>);
+
+                        try
+                        {
+                            if (currentEnumerator.MoveNext())
+                            {
+                                currentObservable = currentEnumerator.Current;
+                            }
+                        }
+                        catch (Exception ex)
                         {
-                            e.Dispose();
+                            currentEnumerator.Dispose();
+                            _observer.OnError(ex);
+                            base.Dispose();
+                            Volatile.Write(ref _isDisposed, true);
+                            continue;
                         }
 
-                        return;
-                    }
+                        try
+                        {
+                            next = Helpers.Unpack(currentObservable);
 
-                    //
-                    // Tail recursive case; drop the current frame.
-                    //
-                    if (r == 0)
-                    {
-                        e.Dispose();
+                        }
+                        catch (Exception ex)
+                        {
+                            next = null;
+                            if (!Fail(ex))
+                            {
+                                Volatile.Write(ref _isDisposed, true);
+                            }
+                            continue;
+                        }
 
-                        _stack.Pop();
-                        _length.Pop();
+                        if (next != null)
+                        {
+                            var nextSeq = Extract(next);
+                            if (nextSeq != null)
+                            {
+                                if (TryGetEnumerator(nextSeq, out var nextEnumerator))
+                                {
+                                    stack.Push(nextEnumerator);
+                                    continue;
+                                }
+                                else
+                                {
+                                    Volatile.Write(ref _isDisposed, true);
+                                    continue;
+                                }
+                            }
+                            else
+                            {
+                                var sad = new SingleAssignmentDisposable();
+                                if (Interlocked.CompareExchange(ref currentSubscription, sad, null) == null)
+                                {
+                                    sad.Disposable = next.SubscribeSafe(this);
+                                }
+                                else
+                                {
+                                    continue;
+                                }
+                            }
+                        }
+                        else
+                        {
+                            stack.Pop();
+                            currentEnumerator.Dispose();
+                            continue;
+                        }
                     }
-
-                    //
-                    // Flattening of nested sequences. Prevents stack overflow in observers.
-                    //
-                    var nextSeq = Extract(next);
-                    if (nextSeq != null)
+                    else
                     {
-                        if (!TryGetEnumerator(nextSeq, out var nextEnumerator))
-                            return;
-
-                        _stack.Push(nextEnumerator);
-                        _length.Push(Helpers.GetLength(nextSeq));
-
-                        hasNext = false;
+                        Volatile.Write(ref _isDisposed, true);
+                        Done();
                     }
                 }
-            } while (!hasNext);
 
-            if (!hasNext)
-            {
-                Done();
-                return;
+                if (Interlocked.Decrement(ref trampoline) == 0)
+                {
+                    break;
+                }
             }
+        }
 
-            var d = new SingleAssignmentDisposable();
-            _subscription.Disposable = d;
-            d.Disposable = next.SubscribeSafe(this);
+        void DisposeAll()
+        {
+            Volatile.Write(ref _isDisposed, true);
+            // the disposing of currentSubscription is deferred to drain due to some ObservableExTest.Iterate_Complete()
+            // Interlocked.Exchange(ref currentSubscription, BooleanDisposable.True)?.Dispose();
+            Drain();
         }
 
-        private new void Dispose()
+        protected void Recurse()
         {
-            while (_stack.Count > 0)
+            var d = Volatile.Read(ref currentSubscription);
+            if (d != BooleanDisposable.True)
             {
-                var e = _stack.Pop();
-                _length.Pop();
-
-                e.Dispose();
+                d?.Dispose();
+                if (Interlocked.CompareExchange(ref currentSubscription, null, d) == d)
+                {
+                    Drain();
+                }
             }
-
-            _isDisposed = true;
         }
 
+        protected abstract IEnumerable<IObservable<TSource>> Extract(IObservable<TSource> source);
+
         private bool TryGetEnumerator(IEnumerable<IObservable<TSource>> sources, out IEnumerator<IObservable<TSource>> result)
         {
             try
@@ -182,12 +195,6 @@ namespace System.Reactive
             }
             catch (Exception exception)
             {
-                //
-                // Failure to enumerate the sequence cannot be handled, even by
-                // operators like Catch, because it'd lead to another attempt at
-                // enumerating to find the next observable sequence. Therefore,
-                // we feed those errors directly to the observer.
-                //
                 _observer.OnError(exception);
                 base.Dispose();
 

+ 1 - 1
Rx.NET/Source/src/System.Reactive/Linq/Observable/Catch.cs

@@ -45,7 +45,7 @@ namespace System.Reactive.Linq.ObservableImpl
             public override void OnError(Exception error)
             {
                 _lastException = error;
-                _recurse();
+                Recurse();
             }
 
             public override void OnCompleted()

+ 2 - 2
Rx.NET/Source/src/System.Reactive/Linq/Observable/OnErrorResumeNext.cs

@@ -41,12 +41,12 @@ namespace System.Reactive.Linq.ObservableImpl
 
             public override void OnError(Exception error)
             {
-                _recurse();
+                Recurse();
             }
 
             public override void OnCompleted()
             {
-                _recurse();
+                Recurse();
             }
 
             protected override bool Fail(Exception error)