Browse Source

Merge branch 'master' into ZipWithEnumerableDisposeFix

Daniel C. Weber 6 years ago
parent
commit
b9321cdf9d

+ 1 - 1
Ix.NET/Source/Directory.build.props

@@ -23,7 +23,7 @@
 
   <ItemGroup>
     <PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0-beta2-19554-01" PrivateAssets="All"/>
-    <PackageReference Include="Nerdbank.GitVersioning" Version="3.0.26" PrivateAssets="all" />
+    <PackageReference Include="Nerdbank.GitVersioning" Version="3.0.28" PrivateAssets="all" />
   </ItemGroup>
 
 

+ 342 - 0
Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Amb.cs

@@ -0,0 +1,342 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Tests
+{
+    public class Amb : AsyncEnumerableExTests
+    {
+        [Fact]
+        public void Amb_Null()
+        {
+            Assert.Throws<ArgumentNullException>(() => AsyncEnumerableEx.Amb(default, Return42));
+            Assert.Throws<ArgumentNullException>(() => AsyncEnumerableEx.Amb(Return42, default));
+        }
+
+        [Fact]
+        public async Task Amb_First_Wins()
+        {
+            var source = AsyncEnumerable.Range(1, 5).Amb(AsyncEnumerableEx.Never<int>());
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_First_Wins_Alt()
+        {
+            var source = AsyncEnumerable.Range(1, 5).Amb(AsyncEnumerable.Range(1, 5).SelectAwait(async v =>
+            {
+                await Task.Delay(500);
+                return v;
+            }));
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Second_Wins()
+        {
+            var source = AsyncEnumerableEx.Never<int>().Amb(AsyncEnumerable.Range(1, 5));
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Second_Wins_Alt()
+        {
+            var source = AsyncEnumerable.Range(1, 5).SelectAwait(async v =>
+            {
+                await Task.Delay(500);
+                return v;
+            }).Amb(AsyncEnumerable.Range(6, 5));
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i + 5, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_First_Wins()
+        {
+            var source = AsyncEnumerableEx.Amb(
+                AsyncEnumerable.Range(1, 5),
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerableEx.Never<int>()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_Last_Wins()
+        {
+            var source = AsyncEnumerableEx.Amb(
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerable.Range(1, 5)
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_Enum_First_Wins()
+        {
+            var source = AsyncEnumerableEx.Amb(new[] {
+                    AsyncEnumerable.Range(1, 5),
+                    AsyncEnumerableEx.Never<int>(),
+                    AsyncEnumerableEx.Never<int>()
+                }.AsEnumerable()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_Enum_Last_Wins()
+        {
+            var source = AsyncEnumerableEx.Amb(new[] {
+                    AsyncEnumerableEx.Never<int>(),
+                    AsyncEnumerableEx.Never<int>(),
+                    AsyncEnumerable.Range(1, 5)
+                }.AsEnumerable()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                for (var i = 1; i <= 5; i++)
+                {
+                    Assert.True(await xs.MoveNextAsync());
+                    Assert.Equal(i, xs.Current);
+                }
+
+                Assert.False(await xs.MoveNextAsync());
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+
+        [Fact]
+        public async Task Amb_First_GetAsyncEnumerator_Crashes()
+        {
+            var source = new FailingGetAsyncEnumerator<int>().Amb(AsyncEnumerableEx.Never<int>());
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                await xs.MoveNextAsync();
+
+                Assert.False(true, "Should not have gotten here");
+            }
+            catch (InvalidOperationException)
+            {
+                // we expect this
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Second_GetAsyncEnumerator_Crashes()
+        {
+            var source = AsyncEnumerableEx.Never<int>().Amb(new FailingGetAsyncEnumerator<int>());
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                await xs.MoveNextAsync();
+
+                Assert.False(true, "Should not have gotten here");
+            }
+            catch (InvalidOperationException)
+            {
+                // we expect this
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_First_GetAsyncEnumerator_Crashes()
+        {
+            var source = AsyncEnumerableEx.Amb(
+                new FailingGetAsyncEnumerator<int>(),
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerableEx.Never<int>()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                await xs.MoveNextAsync();
+
+                Assert.False(true, "Should not have gotten here");
+            }
+            catch (InvalidOperationException)
+            {
+                // we expect this
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Amb_Many_Last_GetAsyncEnumerator_Crashes()
+        {
+            var source = AsyncEnumerableEx.Amb(
+                AsyncEnumerableEx.Never<int>(),
+                AsyncEnumerableEx.Never<int>(),
+                new FailingGetAsyncEnumerator<int>()
+            );
+
+            var xs = source.GetAsyncEnumerator();
+
+            try
+            {
+                await xs.MoveNextAsync();
+
+                Assert.False(true, "Should not have gotten here");
+            }
+            catch (InvalidOperationException)
+            {
+                // we expect this
+            }
+            finally
+            {
+                await xs.DisposeAsync();
+            }
+        }
+
+        private class FailingGetAsyncEnumerator<T> : IAsyncEnumerable<T>
+        {
+            public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
+            {
+                throw new InvalidOperationException();
+            }
+        }
+    }
+}

+ 132 - 0
Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Timeout.cs

@@ -0,0 +1,132 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Tests
+{
+    public class Timeout : AsyncEnumerableExTests
+    {
+        [Fact]
+        public async Task Timeout_Never()
+        {
+            var source = AsyncEnumerableEx.Never<int>().Timeout(TimeSpan.FromMilliseconds(100));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Timeout_Double_Never()
+        {
+            var source = AsyncEnumerableEx.Never<int>()
+                .Timeout(TimeSpan.FromMilliseconds(300))
+                .Timeout(TimeSpan.FromMilliseconds(100));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Timeout_Delayed_Main()
+        {
+            var source = AsyncEnumerable.Range(1, 5)
+                .SelectAwait(async v =>
+                {
+                    await Task.Delay(300);
+                    return v;
+                })
+                .Timeout(TimeSpan.FromMilliseconds(100));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+        }
+
+        [Fact]
+        public async Task Timeout_Delayed_Main_Canceled()
+        {
+            var tcs = new TaskCompletionSource<int>();
+
+            var source = AsyncEnumerable.Range(1, 5)
+                .SelectAwaitWithCancellation(async (v, ct) =>
+                {
+                    try
+                    {
+                        await Task.Delay(500, ct);
+                    }
+                    catch (TaskCanceledException)
+                    {
+                        tcs.SetResult(0);
+                    }
+                    return v;
+                })
+                .Timeout(TimeSpan.FromMilliseconds(250));
+
+            var en = source.GetAsyncEnumerator();
+
+            try
+            {
+                await en.MoveNextAsync();
+
+                Assert.False(true, "MoveNextAsync should have thrown");
+            }
+            catch (TimeoutException)
+            {
+                // expected
+            }
+            finally
+            {
+                await en.DisposeAsync();
+            }
+
+            Assert.Equal(0, await tcs.Task);
+        }
+    }
+}

+ 36 - 6
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Amb.cs

@@ -27,6 +27,14 @@ namespace System.Linq
                 Task<bool>? firstMoveNext = null;
                 Task<bool>? secondMoveNext = null;
 
+                //
+                // We need separate tokens for each source so that the non-winner can get disposed and unblocked
+                // i.e., see Never()
+                //
+
+                var firstCancelToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+                var secondCancelToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+
                 try
                 {
                     //
@@ -36,7 +44,7 @@ namespace System.Linq
                     //         adding a WhenAny combinator that does exactly that. We can even avoid calling AsTask.
                     //
 
-                    firstEnumerator = first.GetAsyncEnumerator(cancellationToken);
+                    firstEnumerator = first.GetAsyncEnumerator(firstCancelToken.Token);
                     firstMoveNext = firstEnumerator.MoveNextAsync().AsTask();
 
                     //
@@ -44,11 +52,14 @@ namespace System.Linq
                     //         overload which performs GetAsyncEnumerator/MoveNextAsync in pairs, rather than phased.
                     //
 
-                    secondEnumerator = second.GetAsyncEnumerator(cancellationToken);
+                    secondEnumerator = second.GetAsyncEnumerator(secondCancelToken.Token);
                     secondMoveNext = secondEnumerator.MoveNextAsync().AsTask();
                 }
                 catch
                 {
+                    secondCancelToken.Cancel();
+                    firstCancelToken.Cancel();
+
                     // NB: AwaitMoveNextAsyncAndDispose checks for null for both arguments, reducing the need for many null
                     //     checks over here.
 
@@ -58,6 +69,7 @@ namespace System.Linq
                         AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator)
                     };
 
+
                     await Task.WhenAll(cleanup).ConfigureAwait(false);
 
                     throw;
@@ -83,11 +95,13 @@ namespace System.Linq
                 if (moveNextWinner == firstMoveNext)
                 {
                     winner = firstEnumerator;
+                    secondCancelToken.Cancel();
                     disposeLoser = AwaitMoveNextAsyncAndDispose(secondMoveNext, secondEnumerator);
                 }
                 else
                 {
                     winner = secondEnumerator;
+                    firstCancelToken.Cancel();
                     disposeLoser = AwaitMoveNextAsyncAndDispose(firstMoveNext, firstEnumerator);
                 }
 
@@ -143,12 +157,17 @@ namespace System.Linq
 
                 var enumerators = new IAsyncEnumerator<TSource>[n];
                 var moveNexts = new Task<bool>[n];
+                var individualTokenSources = new CancellationTokenSource[n];
+                for (var i = 0; i < n; i++)
+                {
+                    individualTokenSources[i] = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+                }
 
                 try
                 {
                     for (var i = 0; i < n; i++)
                     {
-                        var enumerator = sources[i].GetAsyncEnumerator(cancellationToken);
+                        var enumerator = sources[i].GetAsyncEnumerator(individualTokenSources[i].Token);
 
                         enumerators[i] = enumerator;
                         moveNexts[i] = enumerator.MoveNextAsync().AsTask();
@@ -158,12 +177,15 @@ namespace System.Linq
                 {
                     var cleanup = new Task[n];
 
-                    for (var i = 0; i < n; i++)
+                    for (var i = n - 1; i >= 0; i--)
                     {
+                        individualTokenSources[i].Cancel();
+
                         cleanup[i] = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i]);
                     }
 
                     await Task.WhenAll(cleanup).ConfigureAwait(false);
+
                     throw;
                 }
 
@@ -185,10 +207,11 @@ namespace System.Linq
 
                 var loserCleanupTasks = new List<Task>(n - 1);
 
-                for (var i = 0; i < n; i++)
+                for (var i = n - 1; i >= 0; i--)
                 {
                     if (i != winnerIndex)
                     {
+                        individualTokenSources[i].Cancel();
                         var loserCleanupTask = AwaitMoveNextAsyncAndDispose(moveNexts[i], enumerators[i]);
                         loserCleanupTasks.Add(loserCleanupTask);
                     }
@@ -236,7 +259,14 @@ namespace System.Linq
                 {
                     if (moveNextAsync != null)
                     {
-                        await moveNextAsync.ConfigureAwait(false);
+                        try
+                        {
+                            await moveNextAsync.ConfigureAwait(false);
+                        }
+                        catch (TaskCanceledException)
+                        {
+                            // ignored because of cancelling the non-winners
+                        }
                     }
                 }
             }

+ 1 - 1
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Never.cs

@@ -49,7 +49,7 @@ namespace System.Linq
 
                     _once = true;
                     var task = new TaskCompletionSource<bool>();
-                    _registration = _token.Register(state => ((TaskCompletionSource<bool>)state!).SetCanceled(), task);
+                    _registration = _token.Register(state => ((TaskCompletionSource<bool>)state).TrySetCanceled(_token), task);
                     return new ValueTask<bool>(task.Task);
                 }
             }

+ 12 - 2
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Timeout.cs

@@ -31,6 +31,8 @@ namespace System.Linq
 
             private Task? _loserTask;
 
+            private CancellationTokenSource? _sourceCTS;
+
             public TimeoutAsyncIterator(IAsyncEnumerable<TSource> source, TimeSpan timeout)
             {
                 _source = source;
@@ -55,6 +57,11 @@ namespace System.Linq
                     await _enumerator.DisposeAsync().ConfigureAwait(false);
                     _enumerator = null;
                 }
+                if (_sourceCTS != null)
+                {
+                    _sourceCTS.Dispose();
+                    _sourceCTS = null;
+                }
 
                 await base.DisposeAsync().ConfigureAwait(false);
             }
@@ -64,7 +71,8 @@ namespace System.Linq
                 switch (_state)
                 {
                     case AsyncIteratorState.Allocated:
-                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
+                        _sourceCTS = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken);
+                        _enumerator = _source.GetAsyncEnumerator(_sourceCTS.Token);
 
                         _state = AsyncIteratorState.Iterating;
                         goto case AsyncIteratorState.Iterating;
@@ -74,7 +82,7 @@ namespace System.Linq
 
                         if (!moveNext.IsCompleted)
                         {
-                            using var delayCts = new CancellationTokenSource();
+                            using var delayCts = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken);
 
                             var delay = Task.Delay(_timeout, delayCts.Token);
 
@@ -98,6 +106,8 @@ namespace System.Linq
 
                                 _loserTask = next.ContinueWith((_, state) => ((IAsyncDisposable)state!).DisposeAsync().AsTask(), _enumerator);
 
+                                _sourceCTS!.Cancel();
+
                                 throw new TimeoutException();
                             }
 

+ 1 - 1
Rx.NET/Source/Directory.build.props

@@ -25,7 +25,7 @@
 
   <ItemGroup>
     <PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0-beta2-19554-01" PrivateAssets="All"/>
-    <PackageReference Include="Nerdbank.GitVersioning" Version="3.0.26" PrivateAssets="all" />
+    <PackageReference Include="Nerdbank.GitVersioning" Version="3.0.28" PrivateAssets="all" />
   </ItemGroup>
 
   <ItemGroup Condition="'$(IsTestProject)' == 'true'">

+ 10 - 2
Rx.NET/Source/src/System.Reactive/Concurrency/EventLoopScheduler.cs

@@ -153,7 +153,7 @@ namespace System.Reactive.Concurrency
             {
                 if (_disposed)
                 {
-                    throw new ObjectDisposedException("");
+                    throw new ObjectDisposedException(nameof(EventLoopScheduler));
                 }
 
                 if (dueTime <= TimeSpan.Zero)
@@ -351,7 +351,15 @@ namespace System.Reactive.Concurrency
                     {
                         if (!item.IsCanceled)
                         {
-                            item.Invoke();
+                            try
+                            {
+                                item.Invoke();
+                            }
+                            catch (ObjectDisposedException ex) when (nameof(EventLoopScheduler).Equals(ex.ObjectName))
+                            {
+                                // Since we are not inside the lock at this point
+                                // the scheduler can be disposed before the item had a chance to run
+                            }
                         }
                     }
                 }

+ 8 - 61
Rx.NET/Source/src/System.Reactive/Linq/Observable/FirstLastBlocking.cs

@@ -2,105 +2,52 @@
 // The .NET Foundation licenses this file to you under the Apache 2.0 License.
 // See the LICENSE file in the project root for more information. 
 
-using System.Reactive.Disposables;
 using System.Threading;
 
 namespace System.Reactive.Linq.ObservableImpl
 {
-    internal abstract class BaseBlocking<T> : CountdownEvent, IObserver<T>
+    internal abstract class BaseBlocking<T> : ManualResetEventSlim, IObserver<T>
     {
-        protected IDisposable _upstream;
-
         internal T _value;
         internal bool _hasValue;
         internal Exception _error;
-        private int _once;
-
-        internal BaseBlocking() : base(1) { }
 
-        internal void SetUpstream(IDisposable d)
-        {
-            Disposable.SetSingle(ref _upstream, d);
-        }
+        internal BaseBlocking() { }
 
-        protected void Unblock()
+        public void OnCompleted()
         {
-            if (Interlocked.CompareExchange(ref _once, 1, 0) == 0)
-            {
-                Signal();
-            }
+            Set();
         }
 
-        public abstract void OnCompleted();
-        public virtual void OnError(Exception error)
+        public void OnError(Exception error)
         {
             _value = default;
             _error = error;
-            Unblock();
+            Set();
         }
-        public abstract void OnNext(T value);
 
-        public new void Dispose()
-        {
-            base.Dispose();
-            if (!Disposable.GetIsDisposed(ref _upstream))
-            {
-                Disposable.TryDispose(ref _upstream);
-            }
-        }
+        public abstract void OnNext(T value);
     }
 
     internal sealed class FirstBlocking<T> : BaseBlocking<T>
     {
-        public override void OnCompleted()
-        {
-            Unblock();
-            if (!Disposable.GetIsDisposed(ref _upstream))
-            {
-                Disposable.TryDispose(ref _upstream);
-            }
-        }
-
-        public override void OnError(Exception error)
-        {
-            base.OnError(error);
-            if (!Disposable.GetIsDisposed(ref _upstream))
-            {
-                Disposable.TryDispose(ref _upstream);
-            }
-        }
-
         public override void OnNext(T value)
         {
             if (!_hasValue)
             {
                 _value = value;
                 _hasValue = true;
-                Disposable.TryDispose(ref _upstream);
-                Unblock();
+                Set();
             }
         }
     }
 
     internal sealed class LastBlocking<T> : BaseBlocking<T>
     {
-        public override void OnCompleted()
-        {
-            Unblock();
-            Disposable.TryDispose(ref _upstream);
-        }
-
-        public override void OnError(Exception error)
-        {
-            base.OnError(error);
-            Disposable.TryDispose(ref _upstream);
-        }
-
         public override void OnNext(T value)
         {
             _value = value;
             _hasValue = true;
         }
-
     }
 }

+ 4 - 14
Rx.NET/Source/src/System.Reactive/Linq/QueryLanguage.Blocking.cs

@@ -69,14 +69,9 @@ namespace System.Reactive.Linq
         {
             using (var consumer = new FirstBlocking<TSource>())
             {
-                using (var d = source.Subscribe(consumer))
+                using (source.Subscribe(consumer))
                 {
-                    consumer.SetUpstream(d);
-
-                    if (consumer.CurrentCount != 0)
-                    {
-                        consumer.Wait();
-                    }
+                    consumer.Wait();
                 }
 
                 consumer._error.ThrowIfNotNull();
@@ -166,14 +161,9 @@ namespace System.Reactive.Linq
             using (var consumer = new LastBlocking<TSource>())
             {
 
-                using (var d = source.Subscribe(consumer))
+                using (source.Subscribe(consumer))
                 {
-                    consumer.SetUpstream(d);
-
-                    if (consumer.CurrentCount != 0)
-                    {
-                        consumer.Wait();
-                    }
+                    consumer.Wait();
                 }
 
                 consumer._error.ThrowIfNotNull();

+ 1 - 1
Rx.NET/Source/tests/Tests.System.Reactive.ApiApprovals/Tests.System.Reactive.ApiApprovals.csproj

@@ -31,7 +31,7 @@
     <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.3.0" />
     <PackageReference Include="xunit" Version="2.4.1" />
     <PackageReference Include="xunit.runner.visualstudio" Version="2.4.1" />
-    <PackageReference Include="ApprovalTests" Version="4.2.2" />
+    <PackageReference Include="ApprovalTests" Version="4.4.0" />
     <PackageReference Include="DiffPlex" Version="1.4.4" />
     <PackageReference Include="PublicApiGenerator" Version="9.3.0" />
   </ItemGroup>

+ 14 - 0
Rx.NET/Source/tests/Tests.System.Reactive/Tests/Concurrency/EventLoopSchedulerTest.cs

@@ -7,6 +7,7 @@ using System.Collections.Generic;
 using System.Diagnostics;
 using System.Reactive.Concurrency;
 using System.Reactive.Disposables;
+using System.Reactive.Linq;
 using System.Threading;
 using Microsoft.Reactive.Testing;
 using Xunit;
@@ -41,6 +42,19 @@ namespace ReactiveTests.Tests
             Assert.True(res.Seconds < 1);
         }
 
+        [Fact]
+        public void EventLoop_DisposeWithInFlightActions()
+        {
+            using (var scheduler = new EventLoopScheduler())
+            using (var subscription = Observable
+                .Range(1, 10)
+                .ObserveOn(scheduler)
+                .Subscribe(_ => Thread.Sleep(50)))
+            {
+                Thread.Sleep(50);
+            }
+        }
+
         [Fact]
         public void EventLoop_ScheduleAction()
         {