Browse Source

4.x: Dedicated Amb implementation for arrays and enumerables (#545)

David Karnok 7 years ago
parent
commit
68acb4b737

+ 216 - 0
Rx.NET/Source/src/System.Reactive/Linq/Observable/AmbMany.cs

@@ -0,0 +1,216 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Collections.Generic;
+using System.Reactive.Disposables;
+using System.Text;
+using System.Threading;
+using System.Linq;
+
+namespace System.Reactive.Linq.ObservableImpl
+{
+    internal sealed class AmbManyArray<T> : BasicProducer<T>
+    {
+        readonly IObservable<T>[] sources;
+
+        public AmbManyArray(IObservable<T>[] sources)
+        {
+            this.sources = sources;
+        }
+
+        protected override IDisposable Run(IObserver<T> observer)
+        {
+            return AmbCoordinator<T>.Create(observer, sources);
+        }
+    }
+
+    internal sealed class AmbManyEnumerable<T> : BasicProducer<T>
+    {
+        readonly IEnumerable<IObservable<T>> sources;
+
+        public AmbManyEnumerable(IEnumerable<IObservable<T>> sources)
+        {
+            this.sources = sources;
+        }
+
+        protected override IDisposable Run(IObserver<T> observer)
+        {
+            var sourcesEnumerable = this.sources;
+            var sources = default(IObservable<T>[]);
+
+            try
+            {
+                sources = sourcesEnumerable.ToArray();
+            }
+            catch (Exception ex)
+            {
+                observer.OnError(ex);
+                return Disposable.Empty;
+            }
+
+            return AmbCoordinator<T>.Create(observer, sources);
+        }
+    }
+
+    internal sealed class AmbCoordinator<T> : IDisposable
+    {
+        readonly IObserver<T> downstream;
+
+        readonly InnerObserver[] observers;
+
+        int winner;
+
+        internal AmbCoordinator(IObserver<T> downstream, int n)
+        {
+            this.downstream = downstream;
+            var o = new InnerObserver[n];
+            for (int i = 0; i < n; i++)
+            {
+                o[i] = new InnerObserver(this, i);
+            }
+            observers = o;
+            Volatile.Write(ref winner, -1);
+        }
+
+        internal static IDisposable Create(IObserver<T> observer, IObservable<T>[] sources)
+        {
+            var n = sources.Length;
+            if (n == 0)
+            {
+                observer.OnCompleted();
+                return Disposable.Empty;
+            }
+
+            if (n == 1)
+            {
+                return sources[0].Subscribe(observer);
+            }
+
+            var parent = new AmbCoordinator<T>(observer, n);
+
+            parent.Subscribe(sources);
+
+            return parent;
+        }
+
+        internal void Subscribe(IObservable<T>[] sources)
+        {
+            for (var i = 0; i < observers.Length; i++)
+            {
+                var inner = Volatile.Read(ref observers[i]);
+                if (inner == null)
+                {
+                    break;
+                }
+                inner.OnSubscribe(sources[i].Subscribe(inner));
+            }
+        }
+
+        public void Dispose()
+        {
+            for (var i = 0; i < observers.Length; i++)
+            {
+                Interlocked.Exchange(ref observers[i], null)?.Dispose();
+            }
+        }
+
+        bool TryWin(int index)
+        {
+            if (Volatile.Read(ref winner) == -1 && Interlocked.CompareExchange(ref winner, index, -1) == -1)
+            {
+                for (var i = 0; i < observers.Length; i++)
+                {
+                    if (index != i)
+                    {
+                        Interlocked.Exchange(ref observers[i], null)?.Dispose();
+                    }
+                }
+                return true;
+            }
+            return false;
+        }
+
+        internal sealed class InnerObserver : IObserver<T>, IDisposable
+        {
+            readonly IObserver<T> downstream;
+
+            readonly AmbCoordinator<T> parent;
+
+            readonly int index;
+
+            IDisposable upstream;
+
+            bool won;
+
+            public InnerObserver(AmbCoordinator<T> parent, int index)
+            {
+                this.downstream = parent.downstream;
+                this.parent = parent;
+                this.index = index;
+            }
+
+            public void Dispose()
+            {
+                Interlocked.Exchange(ref upstream, BooleanDisposable.True)?.Dispose();
+            }
+
+            public void OnCompleted()
+            {
+                if (won)
+                {
+                    downstream.OnCompleted();
+                }
+                else
+                if (parent.TryWin(index))
+                {
+                    won = true;
+                    downstream.OnCompleted();
+                }
+                Dispose();
+            }
+
+            public void OnError(Exception error)
+            {
+                if (won)
+                {
+                    downstream.OnError(error);
+                }
+                else
+                if (parent.TryWin(index))
+                {
+                    won = true;
+                    downstream.OnError(error);
+                }
+                Dispose();
+            }
+
+            public void OnNext(T value)
+            {
+                if (won)
+                {
+                    downstream.OnNext(value);
+                }
+                else
+                if (parent.TryWin(index))
+                {
+                    won = true;
+                    downstream.OnNext(value);
+                } else
+                {
+                    Dispose();
+                }
+            }
+
+            internal void OnSubscribe(IDisposable d)
+            {
+                if (Interlocked.CompareExchange(ref upstream, d, null) != null)
+                {
+                    d?.Dispose();
+                }
+            }
+        }
+
+    }
+}

+ 3 - 13
Rx.NET/Source/src/System.Reactive/Linq/QueryLanguage.Multiple.cs

@@ -18,27 +18,17 @@ namespace System.Reactive.Linq
 
         public virtual IObservable<TSource> Amb<TSource>(IObservable<TSource> first, IObservable<TSource> second)
         {
-            return Amb_(first, second);
+            return new Amb<TSource>(first, second);
         }
 
         public virtual IObservable<TSource> Amb<TSource>(params IObservable<TSource>[] sources)
         {
-            return Amb_(sources);
+            return new AmbManyArray<TSource>(sources);
         }
 
         public virtual IObservable<TSource> Amb<TSource>(IEnumerable<IObservable<TSource>> sources)
         {
-            return Amb_(sources);
-        }
-
-        private static IObservable<TSource> Amb_<TSource>(IEnumerable<IObservable<TSource>> sources)
-        {
-            return sources.Aggregate(Observable.Never<TSource>(), (previous, current) => previous.Amb(current));
-        }
-
-        private static IObservable<TSource> Amb_<TSource>(IObservable<TSource> leftSource, IObservable<TSource> rightSource)
-        {
-            return new Amb<TSource>(leftSource, rightSource);
+            return new AmbManyEnumerable<TSource>(sources);
         }
 
         #endregion

+ 320 - 0
Rx.NET/Source/tests/Tests.System.Reactive/Tests/Linq/Observable/AmbTest.cs

@@ -376,5 +376,325 @@ namespace ReactiveTests.Tests
             );
         }
 
+        [Fact]
+        public void Amb_Many_Array_OnNext()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var o1 = scheduler.CreateColdObservable(
+                OnNext(150, 1),
+                OnNext(220, 3),
+                OnCompleted<int>(250)
+            );
+
+            var o2 = scheduler.CreateColdObservable(
+                OnNext(150, 2),
+                OnError<int>(210, ex)
+            );
+
+            var o3 = scheduler.CreateColdObservable(
+                OnCompleted<int>(150)
+            );
+
+            var res = scheduler.Start(() =>
+                Observable.Amb(o1, o2, o3)
+            );
+
+            res.Messages.AssertEqual(
+                OnNext(350, 1),
+                OnNext(420, 3),
+                OnCompleted<int>(450)
+            );
+
+            o1.Subscriptions.AssertEqual(
+                Subscribe(200, 450)
+            );
+
+            o2.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o3.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+        }
+
+        [Fact]
+        public void Amb_Many_Array_OnError()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var o1 = scheduler.CreateColdObservable(
+                OnError<int>(150, ex)
+            );
+
+            var o2 = scheduler.CreateColdObservable(
+                OnNext(150, 1),
+                OnNext(220, 3),
+                OnCompleted<int>(250)
+            );
+
+            var o3 = scheduler.CreateColdObservable(
+                OnCompleted<int>(150)
+            );
+
+            var res = scheduler.Start(() =>
+                Observable.Amb(o1, o2, o3)
+            );
+
+            res.Messages.AssertEqual(
+                OnError<int>(350, ex)
+            );
+
+            o1.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o2.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o3.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+        }
+
+        [Fact]
+        public void Amb_Many_Array_OnCompleted()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var o1 = scheduler.CreateColdObservable(
+                OnCompleted<int>(150)
+            );
+
+            var o2 = scheduler.CreateColdObservable(
+                OnNext(150, 1),
+                OnNext(220, 3),
+                OnCompleted<int>(250)
+            );
+
+            var o3 = scheduler.CreateColdObservable(
+                OnNext(150, 2),
+                OnError<int>(210, ex)
+            );
+
+
+            var res = scheduler.Start(() =>
+                Observable.Amb(o1, o2, o3)
+            );
+
+            res.Messages.AssertEqual(
+                OnCompleted<int>(350)
+            );
+
+            o1.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o2.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o3.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+        }
+
+
+        [Fact]
+        public void Amb_Many_Enumerable_OnNext()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var o1 = scheduler.CreateColdObservable(
+                OnNext(150, 1),
+                OnNext(220, 3),
+                OnCompleted<int>(250)
+            );
+
+            var o2 = scheduler.CreateColdObservable(
+                OnNext(150, 2),
+                OnError<int>(210, ex)
+            );
+
+            var o3 = scheduler.CreateColdObservable(
+                OnCompleted<int>(150)
+            );
+
+            var res = scheduler.Start(() =>
+                new[] { o1, o2, o3 }.Amb()
+            );
+
+            res.Messages.AssertEqual(
+                OnNext(350, 1),
+                OnNext(420, 3),
+                OnCompleted<int>(450)
+            );
+
+            o1.Subscriptions.AssertEqual(
+                Subscribe(200, 450)
+            );
+
+            o2.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o3.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+        }
+
+        [Fact]
+        public void Amb_Many_Enumerable_OnError()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var o1 = scheduler.CreateColdObservable(
+                OnError<int>(150, ex)
+            );
+
+            var o2 = scheduler.CreateColdObservable(
+                OnNext(150, 1),
+                OnNext(220, 3),
+                OnCompleted<int>(250)
+            );
+
+            var o3 = scheduler.CreateColdObservable(
+                OnCompleted<int>(150)
+            );
+
+            var res = scheduler.Start(() =>
+                new[] { o1, o2, o3 }.Amb()
+            );
+
+            res.Messages.AssertEqual(
+                OnError<int>(350, ex)
+            );
+
+            o1.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o2.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o3.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+        }
+
+        [Fact]
+        public void Amb_Many_Enumerable_OnCompleted()
+        {
+            var scheduler = new TestScheduler();
+
+            var ex = new Exception();
+
+            var o1 = scheduler.CreateColdObservable(
+                OnCompleted<int>(150)
+            );
+
+            var o2 = scheduler.CreateColdObservable(
+                OnNext(150, 1),
+                OnNext(220, 3),
+                OnCompleted<int>(250)
+            );
+
+            var o3 = scheduler.CreateColdObservable(
+                OnNext(150, 2),
+                OnError<int>(210, ex)
+            );
+
+
+            var res = scheduler.Start(() =>
+                new[] { o1, o2, o3 }.Amb()
+            );
+
+            res.Messages.AssertEqual(
+                OnCompleted<int>(350)
+            );
+
+            o1.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o2.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+
+            o3.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+        }
+
+
+        [Fact]
+        public void Amb_Many_Enumerable_Many_Sources()
+        {
+            for (int i = 0; i < 32; i++)
+            {
+                var sources = new List<IObservable<int>>();
+                for (var j = 0; j < i; j++)
+                {
+                    sources.Add(Observable.Return(j));
+                }
+
+                var result = sources.Amb().ToList().First();
+
+                if (i == 0)
+                {
+                    Assert.Equal(0, result.Count);
+                }
+                else
+                {
+                    Assert.Equal(1, result.Count);
+                    Assert.Equal(0, result[0]);
+                }
+            }
+        }
+
+        [Fact]
+        public void Amb_Many_Enumerable_Many_Sources_NoStackOverflow()
+        {
+            for (int i = 0; i < 100; i++)
+            {
+                var sources = new List<IObservable<int>>();
+                for (var j = 0; j < i; j++)
+                {
+                    if (j == i - 1)
+                    {
+                        sources.Add(Observable.Return(j));
+                    }
+                    else
+                    {
+                        sources.Add(Observable.Never<int>());
+                    }
+                }
+
+                var result = sources.Amb().ToList().First();
+
+                if (i == 0)
+                {
+                    Assert.Equal(0, result.Count);
+                }
+                else
+                {
+                    Assert.Equal(1, result.Count);
+                    Assert.Equal(i - 1, result[0]);
+                }
+            }
+        }
     }
 }