Sfoglia il codice sorgente

Add WithCancellation operator.

Bart De Smet 7 anni fa
parent
commit
7499e2a070

+ 48 - 0
Ix.NET/Source/System.Linq.Async/System/Linq/AsyncEnumerable.cs

@@ -5,6 +5,7 @@
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Linq
 {
@@ -18,6 +19,14 @@ namespace System.Linq
             return new AnonymousAsyncEnumerable<T>(getEnumerator);
         }
 
+        public static IAsyncEnumerable<T> WithCancellation<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken)
+        {
+            if (source == null)
+                throw Error.ArgumentNull(nameof(source));
+
+            return new WithCancellationAsyncEnumerable<T>(source, cancellationToken);
+        }
+
         private sealed class AnonymousAsyncEnumerable<T> : IAsyncEnumerable<T>
         {
             private readonly Func<CancellationToken, IAsyncEnumerator<T>> _getEnumerator;
@@ -31,5 +40,44 @@ namespace System.Linq
 
             public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken) => _getEnumerator(cancellationToken);
         }
+
+        // REVIEW: Explicit implementation of the interfaces allows for composition with other "modifier operators" such as ConfigureAwait.
+        //         We expect that the "await foreach" statement will bind to the public struct methods, thus avoiding boxing.
+
+        public struct WithCancellationAsyncEnumerable<T> : IAsyncEnumerable<T>
+        {
+            private readonly IAsyncEnumerable<T> _source;
+            private readonly CancellationToken _cancellationToken;
+
+            public WithCancellationAsyncEnumerable(IAsyncEnumerable<T> source, CancellationToken cancellationToken)
+            {
+                _source = source;
+                _cancellationToken = cancellationToken;
+            }
+
+            // REVIEW: Should we simply ignore the second cancellation token or should we link the two?
+
+            public WithCancellationAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken)
+                => new WithCancellationAsyncEnumerator(_source.GetAsyncEnumerator(_cancellationToken));
+
+            IAsyncEnumerator<T> IAsyncEnumerable<T>.GetAsyncEnumerator(CancellationToken cancellationToken)
+                => GetAsyncEnumerator(cancellationToken);
+
+            public struct WithCancellationAsyncEnumerator : IAsyncEnumerator<T>
+            {
+                private readonly IAsyncEnumerator<T> _enumerator;
+
+                public WithCancellationAsyncEnumerator(IAsyncEnumerator<T> enumerator)
+                {
+                    _enumerator = enumerator;
+                }
+
+                public T Current => _enumerator.Current;
+
+                public ValueTask DisposeAsync() => _enumerator.DisposeAsync();
+
+                public ValueTask<bool> MoveNextAsync() => _enumerator.MoveNextAsync();
+            }
+        }
     }
 }