Browse Source

Modify OfType to accept nullable input element type (#2260)

This enables e.g.

```cs
IObservable<string?> s = ...;
IObservable<string> s.OfType<string>()
```

to filter it down to just the non-null strings.

This is consistent with how System.Linq.Async was changed (and how the .NET runtime library replacement for this, System.Linq.AsyncEnumerable` also works).
Ian Griffiths 1 month ago
parent
commit
0aa0a00924

+ 1 - 1
Rx.NET/Source/src/System.Reactive/Linq/IQueryLanguage.cs

@@ -646,7 +646,7 @@ namespace System.Reactive.Linq
         IObservable<IGroupedObservable<TKey, TSource>> GroupByUntil<TSource, TKey, TDuration>(IObservable<TSource> source, Func<TSource, TKey> keySelector, Func<IGroupedObservable<TKey, TSource>, IObservable<TDuration>> durationSelector, int capacity);
         IObservable<IGroupedObservable<TKey, TSource>> GroupByUntil<TSource, TKey, TDuration>(IObservable<TSource> source, Func<TSource, TKey> keySelector, Func<IGroupedObservable<TKey, TSource>, IObservable<TDuration>> durationSelector, int capacity);
         IObservable<TResult> GroupJoin<TLeft, TRight, TLeftDuration, TRightDuration, TResult>(IObservable<TLeft> left, IObservable<TRight> right, Func<TLeft, IObservable<TLeftDuration>> leftDurationSelector, Func<TRight, IObservable<TRightDuration>> rightDurationSelector, Func<TLeft, IObservable<TRight>, TResult> resultSelector);
         IObservable<TResult> GroupJoin<TLeft, TRight, TLeftDuration, TRightDuration, TResult>(IObservable<TLeft> left, IObservable<TRight> right, Func<TLeft, IObservable<TLeftDuration>> leftDurationSelector, Func<TRight, IObservable<TRightDuration>> rightDurationSelector, Func<TLeft, IObservable<TRight>, TResult> resultSelector);
         IObservable<TResult> Join<TLeft, TRight, TLeftDuration, TRightDuration, TResult>(IObservable<TLeft> left, IObservable<TRight> right, Func<TLeft, IObservable<TLeftDuration>> leftDurationSelector, Func<TRight, IObservable<TRightDuration>> rightDurationSelector, Func<TLeft, TRight, TResult> resultSelector);
         IObservable<TResult> Join<TLeft, TRight, TLeftDuration, TRightDuration, TResult>(IObservable<TLeft> left, IObservable<TRight> right, Func<TLeft, IObservable<TLeftDuration>> leftDurationSelector, Func<TRight, IObservable<TRightDuration>> rightDurationSelector, Func<TLeft, TRight, TResult> resultSelector);
-        IObservable<TResult> OfType<TResult>(IObservable<object> source);
+        IObservable<TResult> OfType<TResult>(IObservable<object?> source);
         IObservable<TResult> Select<TSource, TResult>(IObservable<TSource> source, Func<TSource, TResult> selector);
         IObservable<TResult> Select<TSource, TResult>(IObservable<TSource> source, Func<TSource, TResult> selector);
         IObservable<TResult> Select<TSource, TResult>(IObservable<TSource> source, Func<TSource, int, TResult> selector);
         IObservable<TResult> Select<TSource, TResult>(IObservable<TSource> source, Func<TSource, int, TResult> selector);
         IObservable<TOther> SelectMany<TSource, TOther>(IObservable<TSource> source, IObservable<TOther> other);
         IObservable<TOther> SelectMany<TSource, TOther>(IObservable<TSource> source, IObservable<TOther> other);

+ 1 - 1
Rx.NET/Source/src/System.Reactive/Linq/Observable.StandardSequenceOperators.cs

@@ -929,7 +929,7 @@ namespace System.Reactive.Linq
         /// <param name="source">The observable sequence that contains the elements to be filtered.</param>
         /// <param name="source">The observable sequence that contains the elements to be filtered.</param>
         /// <returns>An observable sequence that contains elements from the input sequence of type TResult.</returns>
         /// <returns>An observable sequence that contains elements from the input sequence of type TResult.</returns>
         /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
         /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
-        public static IObservable<TResult> OfType<TResult>(this IObservable<object> source)
+        public static IObservable<TResult> OfType<TResult>(this IObservable<object?> source)
         {
         {
             if (source == null)
             if (source == null)
             {
             {

+ 2 - 2
Rx.NET/Source/src/System.Reactive/Linq/QueryLanguage.StandardSequenceOperators.cs

@@ -188,9 +188,9 @@ namespace System.Reactive.Linq
 
 
         #region + OfType +
         #region + OfType +
 
 
-        public virtual IObservable<TResult> OfType<TResult>(IObservable<object> source)
+        public virtual IObservable<TResult> OfType<TResult>(IObservable<object?> source)
         {
         {
-            return new OfType<object, TResult>(source);
+            return new OfType<object?, TResult>(source);
         }
         }
 
 
         #endregion
         #endregion

+ 1 - 1
Rx.NET/Source/tests/Tests.System.Reactive.ApiApprovals/Api/ApiApprovalTests.Core.verified.cs

@@ -1297,7 +1297,7 @@ namespace System.Reactive.Linq
         public static System.Collections.Generic.IEnumerable<TSource> Next<TSource>(this System.IObservable<TSource> source) { }
         public static System.Collections.Generic.IEnumerable<TSource> Next<TSource>(this System.IObservable<TSource> source) { }
         public static System.IObservable<TSource> ObserveOn<TSource>(this System.IObservable<TSource> source, System.Reactive.Concurrency.IScheduler scheduler) { }
         public static System.IObservable<TSource> ObserveOn<TSource>(this System.IObservable<TSource> source, System.Reactive.Concurrency.IScheduler scheduler) { }
         public static System.IObservable<TSource> ObserveOn<TSource>(this System.IObservable<TSource> source, System.Threading.SynchronizationContext context) { }
         public static System.IObservable<TSource> ObserveOn<TSource>(this System.IObservable<TSource> source, System.Threading.SynchronizationContext context) { }
-        public static System.IObservable<TResult> OfType<TResult>(this System.IObservable<object> source) { }
+        public static System.IObservable<TResult> OfType<TResult>(this System.IObservable<object?> source) { }
         public static System.IObservable<TSource> OnErrorResumeNext<TSource>(this System.Collections.Generic.IEnumerable<System.IObservable<TSource>> sources) { }
         public static System.IObservable<TSource> OnErrorResumeNext<TSource>(this System.Collections.Generic.IEnumerable<System.IObservable<TSource>> sources) { }
         public static System.IObservable<TSource> OnErrorResumeNext<TSource>(params System.IObservable<TSource>[] sources) { }
         public static System.IObservable<TSource> OnErrorResumeNext<TSource>(params System.IObservable<TSource>[] sources) { }
         public static System.IObservable<TSource> OnErrorResumeNext<TSource>(this System.IObservable<TSource> first, System.IObservable<TSource> second) { }
         public static System.IObservable<TSource> OnErrorResumeNext<TSource>(this System.IObservable<TSource> first, System.IObservable<TSource> second) { }

+ 43 - 0
Rx.NET/Source/tests/Tests.System.Reactive/Tests/Linq/Observable/OfTypeTest.cs

@@ -64,6 +64,49 @@ namespace ReactiveTests.Tests
             );
             );
         }
         }
 
 
+#nullable enable
+        [TestMethod]
+        public void OfType_NullableSourceOfTypeNonNull()
+        {
+            var scheduler = new TestScheduler();
+
+            var xs = scheduler.CreateHotObservable(
+                OnNext<A?>(210, new B(0)),
+                OnNext<A?>(220, new A(1)),
+                OnNext<A?>(230, default(A?)),
+                OnNext<A?>(240, new D(3)),
+                OnNext<A?>(250, new C(4)),
+                OnNext<A?>(260, new B(5)),
+                OnNext<A?>(270, new B(6)),
+                OnNext<A?>(280, new D(7)),
+                OnNext<A?>(290, new A(8)),
+                OnNext<A?>(340, new B(10)),
+                OnCompleted<A?>(350)
+            );
+
+            var res = scheduler.Start(() =>
+                xs.OfType<A>()
+            );
+
+            res.Messages.AssertEqual(
+                OnNext<A>(210, new B(0)),
+                OnNext<A>(220, new A(1)),
+                OnNext<A>(240, new D(3)),
+                OnNext<A>(250, new C(4)),
+                OnNext<A>(260, new B(5)),
+                OnNext<A>(270, new B(6)),
+                OnNext<A>(280, new D(7)),
+                OnNext<A>(290, new A(8)),
+                OnNext<A>(340, new B(10)),
+                OnCompleted<A>(350)
+            );
+
+            xs.Subscriptions.AssertEqual(
+                Subscribe(200, 350)
+            );
+        }
+#nullable restore
+
         [TestMethod]
         [TestMethod]
         public void OfType_Error()
         public void OfType_Error()
         {
         {