Browse Source

4.x: Add the RetryWhen operator (#486)

* 4.x: Add the RetryWhen operator

* Add null check to the handler's return

* Add RetryWhen to Qbservable.Generated manually
David Karnok 7 năm trước cách đây
mục cha
commit
d1f88cc7e3

+ 1 - 0
Rx.NET/Source/src/System.Reactive/Linq/IQueryLanguage.cs

@@ -581,6 +581,7 @@ namespace System.Reactive.Linq
         IObservable<TSource> Repeat<TSource>(IObservable<TSource> source, int repeatCount);
         IObservable<TSource> Retry<TSource>(IObservable<TSource> source);
         IObservable<TSource> Retry<TSource>(IObservable<TSource> source, int retryCount);
+        IObservable<TSource> RetryWhen<TSource, TSignal>(IObservable<TSource> source, Func<IObservable<Exception>, IObservable<TSignal>> handler);
         IObservable<TAccumulate> Scan<TSource, TAccumulate>(IObservable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> accumulator);
         IObservable<TSource> Scan<TSource>(IObservable<TSource> source, Func<TSource, TSource, TSource> accumulator);
         IObservable<TSource> SkipLast<TSource>(IObservable<TSource> source, int count);

+ 25 - 0
Rx.NET/Source/src/System.Reactive/Linq/Observable.Single.cs

@@ -415,6 +415,31 @@ namespace System.Reactive.Linq
             return s_impl.Retry<TSource>(source, retryCount);
         }
 
+        /// <summary>
+        /// Retries (resubscribes to) the source observable after a failure and when the observable
+        /// returned by a handler produces an arbitrary item.
+        /// </summary>
+        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
+        /// <typeparam name="TSignal">The arbitrary element type signaled by the handler observable.</typeparam>
+        /// <param name="source">Observable sequence to repeat until it successfully terminates.</param>
+        /// <param name="handler">The function that is called for each observer and takes an observable sequence of
+        /// errors. It should return an observable of arbitrary items that should signal that arbitrary item in
+        /// response to receiving the failure Exception from the source observable. If this observable signals
+        /// a terminal event, the sequence is terminated with that signal instead.</param>
+        /// <returns>An observable sequence producing the elements of the given sequence repeatedly until it terminates successfully.</returns>
+        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="handler"/> is null.</exception>
+        public static IObservable<TSource> RetryWhen<TSource, TSignal>(this IObservable<TSource> source, Func<IObservable<Exception>, IObservable<TSignal>> handler)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (handler == null)
+                throw new ArgumentNullException(nameof(handler));
+
+            return s_impl.RetryWhen(source, handler);
+        }
+
+
         #endregion
 
         #region + Scan +

+ 318 - 0
Rx.NET/Source/src/System.Reactive/Linq/Observable/RetryWhen.cs

@@ -0,0 +1,318 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Reactive.Disposables;
+using System.Reactive.Subjects;
+using System.Text;
+using System.Threading;
+
+namespace System.Reactive.Linq.ObservableImpl
+{
+    internal sealed class RetryWhen<T, U> : IObservable<T>
+    {
+        readonly IObservable<T> source;
+
+        readonly Func<IObservable<Exception>, IObservable<U>> handler;
+
+        internal RetryWhen(IObservable<T> source, Func<IObservable<Exception>, IObservable<U>> handler)
+        {
+            this.source = source;
+            this.handler = handler;
+        }
+
+        public IDisposable Subscribe(IObserver<T> observer)
+        {
+            if (observer == null)
+            {
+                throw new ArgumentNullException(nameof(observer));
+            }
+
+            var errorSignals = new Subject<Exception>();
+            var redo = default(IObservable<U>);
+
+            try
+            {
+                redo = handler(errorSignals);
+                if (redo == null)
+                {
+                    throw new NullReferenceException("The handler returned a null IObservable");
+                }
+            }
+            catch (Exception ex)
+            {
+                observer.OnError(ex);
+                return Disposable.Empty;
+            }
+
+            var parent = new MainObserver(observer, source, new SerializedObserver(errorSignals));
+
+            var d = redo.SubscribeSafe(parent.handlerObserver);
+            parent.handlerObserver.OnSubscribe(d);
+
+            parent.HandlerNext();
+
+            return parent;
+        }
+
+        sealed class MainObserver : IObserver<T>, IDisposable
+        {
+            readonly IObserver<T> downstream;
+
+            readonly IObserver<Exception> errorSignal;
+
+            internal readonly HandlerObserver handlerObserver;
+
+            readonly IObservable<T> source;
+
+            SingleAssignmentDisposable upstream;
+
+            int trampoline;
+
+            int halfSerializer;
+
+            Exception error;
+
+            static readonly SingleAssignmentDisposable DISPOSED;
+
+            static MainObserver()
+            {
+                DISPOSED = new SingleAssignmentDisposable();
+                DISPOSED.Dispose();
+            }
+
+            internal MainObserver(IObserver<T> downstream, IObservable<T> source, IObserver<Exception> errorSignal)
+            {
+                this.downstream = downstream;
+                this.source = source;
+                this.errorSignal = errorSignal;
+                this.handlerObserver = new HandlerObserver(this);
+            }
+
+            public void Dispose()
+            {
+                Interlocked.Exchange(ref upstream, DISPOSED)?.Dispose();
+                handlerObserver.Dispose();
+            }
+
+            public void OnCompleted()
+            {
+                if (Interlocked.Increment(ref halfSerializer) == 1)
+                {
+                    downstream.OnCompleted();
+                    Dispose();
+                }
+            }
+
+            public void OnError(Exception error)
+            {
+                for (; ; )
+                {
+                    var d = Volatile.Read(ref upstream);
+                    if (d == DISPOSED)
+                    {
+                        break;
+                    }
+                    if (Interlocked.CompareExchange(ref upstream, null, d) == d)
+                    {
+                        errorSignal.OnNext(error);
+                        d.Dispose();
+                        break;
+                    }
+                }
+            }
+
+            public void OnNext(T value)
+            {
+                if (Interlocked.CompareExchange(ref halfSerializer, 1, 0) == 0)
+                {
+                    downstream.OnNext(value);
+                    if (Interlocked.Decrement(ref halfSerializer) != 0)
+                    {
+                        var ex = error;
+                        if (ex == null)
+                        {
+                            downstream.OnCompleted();
+                        }
+                        else
+                        {
+                            downstream.OnError(ex);
+                        }
+                        Dispose();
+                    }
+                }
+            }
+
+            internal void HandlerError(Exception error)
+            {
+                this.error = error;
+                if (Interlocked.Increment(ref halfSerializer) == 1)
+                {
+                    downstream.OnError(error);
+                    Dispose();
+                }
+            }
+
+            internal void HandlerComplete()
+            {
+                if (Interlocked.Increment(ref halfSerializer) == 1)
+                {
+                    downstream.OnCompleted();
+                    Dispose();
+                }
+            }
+
+            internal void HandlerNext()
+            {
+                if (Interlocked.Increment(ref trampoline) == 1)
+                {
+                    do
+                    {
+                        var sad = new SingleAssignmentDisposable();
+                        if (Interlocked.CompareExchange(ref upstream, sad, null) != null)
+                        {
+                            return;
+                        }
+
+                        sad.Disposable = source.SubscribeSafe(this);
+                    }
+                    while (Interlocked.Decrement(ref trampoline) != 0);
+                }
+            }
+
+            internal sealed class HandlerObserver : IObserver<U>, IDisposable
+            {
+                readonly MainObserver main;
+
+                IDisposable upstream;
+
+                internal HandlerObserver(MainObserver main)
+                {
+                    this.main = main;
+                }
+
+                internal void OnSubscribe(IDisposable d)
+                {
+                    if (Interlocked.CompareExchange(ref upstream, d, null) != null)
+                    {
+                        d?.Dispose();
+                    }
+                }
+
+                public void Dispose()
+                {
+                    Interlocked.Exchange(ref upstream, BooleanDisposable.True)?.Dispose();
+                }
+
+                public void OnCompleted()
+                {
+                    main.HandlerComplete();
+                    Dispose();
+                }
+
+                public void OnError(Exception error)
+                {
+                    main.HandlerError(error);
+                    Dispose();
+                }
+
+                public void OnNext(U value)
+                {
+                    main.HandlerNext();
+                }
+            }
+        }
+
+        sealed class SerializedObserver : IObserver<Exception>
+        {
+            readonly IObserver<Exception> downstream;
+
+            int wip;
+
+            Exception terminalException;
+
+            static readonly Exception DONE = new Exception();
+
+            static readonly Exception SIGNALED = new Exception();
+
+            readonly ConcurrentQueue<Exception> queue;
+
+            internal SerializedObserver(IObserver<Exception> downstream)
+            {
+                this.downstream = downstream;
+                this.queue = new ConcurrentQueue<Exception>();
+            }
+
+            public void OnCompleted()
+            {
+                if (Interlocked.CompareExchange(ref terminalException, DONE, null) == null)
+                {
+                    Drain();
+                }
+            }
+
+            public void OnError(Exception error)
+            {
+                if (Interlocked.CompareExchange(ref terminalException, error, null) == null)
+                {
+                    Drain();
+                }
+            }
+
+            public void OnNext(Exception value)
+            {
+                queue.Enqueue(value);
+                Drain();
+            }
+
+            void Clear()
+            {
+                while (queue.TryDequeue(out var _)) ;
+            }
+
+            void Drain()
+            {
+                if (Interlocked.Increment(ref wip) != 1)
+                {
+                    return;
+                }
+
+                int missed = 1;
+
+                for (; ; )
+                {
+                    var ex = Volatile.Read(ref terminalException);
+                    if (ex != null)
+                    {
+                        if (ex != SIGNALED)
+                        {
+                            Interlocked.Exchange(ref terminalException, SIGNALED);
+                            if (ex != DONE)
+                            {
+                                downstream.OnError(ex);
+                            }
+                            else
+                            {
+                                downstream.OnCompleted();
+                            }
+                        }
+                        Clear();
+                    }
+                    else
+                    {
+                        while (queue.TryDequeue(out var item))
+                        {
+                            downstream.OnNext(item);
+                        }
+                    }
+                        
+
+                    missed = Interlocked.Add(ref wip, -missed);
+                    if (missed == 0)
+                    {
+                        break;
+                    }
+                }
+            }
+        }
+    }
+}

+ 37 - 2
Rx.NET/Source/src/System.Reactive/Linq/Qbservable.Generated.cs

@@ -1,4 +1,4 @@
-/*
+/*
  * WARNING: Auto-generated file (5/1/2015 21:21:20)
  * Run Rx's auto-homoiconizer tool to generate this file (in the HomoIcon directory).
  */
@@ -11175,7 +11175,42 @@ namespace System.Reactive.Linq
                 )
             );
         }
-        
+
+        /// <summary>
+        /// Retries (resubscribes to) the source observable after a failure and when the observable
+        /// returned by a handler produces an arbitrary item.
+        /// </summary>
+        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
+        /// <typeparam name="TSignal">The arbitrary element type signaled by the handler observable.</typeparam>
+        /// <param name="source">Observable sequence to repeat until it successfully terminates.</param>
+        /// <param name="handler">The function that is called for each observer and takes an observable sequence of
+        /// errors. It should return an observable of arbitrary items that should signal that arbitrary item in
+        /// response to receiving the failure Exception from the source observable. If this observable signals
+        /// a terminal event, the sequence is terminated with that signal instead.</param>
+        /// <returns>An observable sequence producing the elements of the given sequence repeatedly until it terminates successfully.</returns>
+        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="handler"/> is null.</exception>
+        public static IObservable<TSource> RetryWhen<TSource, TSignal>(this IQbservable<TSource> source, Expression<Func<IObservable<Exception>, IObservable<TSignal>>> handler)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (handler == null)
+                throw new ArgumentNullException(nameof(handler));
+
+            return source.Provider.CreateQuery<TSource>(
+                Expression.Call(
+                    null,
+#if CRIPPLED_REFLECTION
+                    InfoOf(() => Qbservable.RetryWhen<TSource, TSignal>(default(IQbservable<TSource>), default(Expression<Func<IObservable<Exception>, IObservable<TSignal>>>))),
+#else
+                    ((MethodInfo)MethodInfo.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TSignal)),
+#endif
+                    source.Expression,
+                    handler
+                )
+            );
+        }
+
         /// <summary>
         /// Returns an observable sequence that contains a single element.
         /// </summary>

+ 6 - 0
Rx.NET/Source/src/System.Reactive/Linq/QueryLanguage.Single.cs

@@ -186,6 +186,12 @@ namespace System.Reactive.Linq
             return Enumerable.Repeat(source, retryCount).Catch();
         }
 
+        public virtual IObservable<TSource> RetryWhen<TSource, TSignal>(IObservable<TSource> source, Func<IObservable<Exception>, IObservable<TSignal>> handler)
+        {
+            return new RetryWhen<TSource, TSignal>(source, handler);
+        }
+
+
         #endregion
 
         #region + Scan +

+ 231 - 0
Rx.NET/Source/tests/Tests.System.Reactive/Tests/Linq/ObservableRetryWhenTest.cs

@@ -0,0 +1,231 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reactive;
+using System.Reactive.Concurrency;
+using System.Reactive.Linq;
+using Microsoft.Reactive.Testing;
+using Xunit;
+using ReactiveTests.Dummies;
+
+namespace ReactiveTests.Tests
+{
+    public class ObservableRetryWhenTest : ReactiveTest
+    {
+        [Fact]
+        public void RetryWhen_Observable_ArgumentChecking()
+        {
+            ReactiveAssert.Throws<ArgumentNullException>(() => Observable.RetryWhen<int, Exception>(null, v => v));
+            ReactiveAssert.Throws<ArgumentNullException>(() => Observable.RetryWhen<int, Exception>(Observable.Return(1), null));
+            ReactiveAssert.Throws<ArgumentNullException>(() => DummyObservable<int>.Instance.RetryWhen(v => v).Subscribe(null));
+        }
+
+        [Fact]
+        public void RetryWhen_Observable_Basic()
+        {
+            var scheduler = new TestScheduler();
+
+            var xs = scheduler.CreateColdObservable(
+                OnNext(100, 1),
+                OnNext(150, 2),
+                OnNext(200, 3),
+                OnCompleted<int>(250)
+            );
+
+            var res = scheduler.Start(() =>
+                xs.RetryWhen(v => v)
+            );
+
+            res.Messages.AssertEqual(
+                OnNext(300, 1),
+                OnNext(350, 2),
+                OnNext(400, 3),
+                OnCompleted<int>(450)
+            );
+
+            xs.Subscriptions.AssertEqual(
+                Subscribe(200, 450)
+            );
+        }
+
+        [Fact]
+        public void RetryWhen_Observable_Handler_Completes()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var xs = scheduler.CreateColdObservable(
+                OnNext(100, 1),
+                OnNext(150, 2),
+                OnNext(200, 3),
+                OnError<int>(250, ex)
+            );
+
+            var res = scheduler.Start(() =>
+                xs.RetryWhen(v => v.Take(1).Skip(1))
+            );
+
+            res.Messages.AssertEqual(
+                OnNext(300, 1),
+                OnNext(350, 2),
+                OnNext(400, 3),
+                OnCompleted<int>(450)
+            );
+
+            xs.Subscriptions.AssertEqual(
+                Subscribe(200, 450)
+            );
+        }
+
+
+        [Fact]
+        public void RetryWhen_Observable_Handler_Throws()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var res = scheduler.Start(() =>
+                Observable.Return(1).RetryWhen<int, int>(v => { throw ex; })
+            );
+
+            res.Messages.AssertEqual(
+                OnError<int>(200, ex)
+            );
+        }
+
+        [Fact]
+        public void RetryWhen_Observable_Handler_Errors()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+            var ex2 = new Exception();
+
+            var xs = scheduler.CreateColdObservable(
+                OnNext(100, 1),
+                OnNext(150, 2),
+                OnNext(200, 3),
+                OnError<int>(250, ex)
+            );
+
+            var res = scheduler.Start(() =>
+                xs.RetryWhen(v => v.SelectMany(w => Observable.Throw<int>(ex2)))
+            );
+
+            res.Messages.AssertEqual(
+                OnNext(300, 1),
+                OnNext(350, 2),
+                OnNext(400, 3),
+                OnError<int>(450, ex2)
+            );
+
+            xs.Subscriptions.AssertEqual(
+                Subscribe(200, 450)
+            );
+        }
+
+        [Fact]
+        public void RetryWhen_Observable_RetryCount_Basic()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var xs = scheduler.CreateColdObservable(
+                OnNext(5, 1),
+                OnNext(10, 2),
+                OnNext(15, 3),
+                OnError<int>(20, ex)
+            );
+
+            var res = scheduler.Start(() =>
+                xs.RetryWhen(v =>
+                {
+                    int[] count = { 0 };
+                    return v.SelectMany(w => {
+                        int c = ++count[0];
+                        if (c == 3)
+                        {
+                            return Observable.Throw<int>(w);
+                        }
+                        return Observable.Return(1);
+                    });
+                })
+            );
+
+            res.Messages.AssertEqual(
+                OnNext(205, 1),
+                OnNext(210, 2),
+                OnNext(215, 3),
+                OnNext(225, 1),
+                OnNext(230, 2),
+                OnNext(235, 3),
+                OnNext(245, 1),
+                OnNext(250, 2),
+                OnNext(255, 3),
+                OnError<int>(260, ex)
+            );
+
+            xs.Subscriptions.AssertEqual(
+                Subscribe(200, 220),
+                Subscribe(220, 240),
+                Subscribe(240, 260)
+            );
+        }
+
+        [Fact]
+        public void RetryWhen_Observable_RetryCount_Delayed()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var xs = scheduler.CreateColdObservable(
+                OnNext(5, 1),
+                OnNext(10, 2),
+                OnNext(15, 3),
+                OnError<int>(20, ex)
+            );
+
+            var res = scheduler.Start(() =>
+                xs.RetryWhen(v =>
+                {
+                    int[] count = { 0 };
+                    return v.SelectMany(w => {
+                        int c = ++count[0];
+                        if (c == 3)
+                        {
+                            return Observable.Throw<int>(w);
+                        }
+                        return Observable.Return(1).Delay(TimeSpan.FromTicks(c * 100), scheduler);
+                    });
+                })
+            );
+
+            res.Messages.AssertEqual(
+                OnNext(205, 1),
+                OnNext(210, 2),
+                OnNext(215, 3),
+                OnNext(325, 1),
+                OnNext(330, 2),
+                OnNext(335, 3),
+                OnNext(545, 1),
+                OnNext(550, 2),
+                OnNext(555, 3),
+                OnError<int>(560, ex)
+            );
+
+            xs.Subscriptions.AssertEqual(
+                Subscribe(200, 220),
+                Subscribe(320, 340),
+                Subscribe(540, 560)
+            );
+        }
+    }
+}