Quellcode durchsuchen

Async variants of Do.

Bart De Smet vor 8 Jahren
Ursprung
Commit
12063f6d82
1 geänderte Dateien mit 135 neuen und 3 gelöschten Zeilen
  1. 135 3
      Ix.NET/Source/System.Interactive.Async/Do.cs

+ 135 - 3
Ix.NET/Source/System.Interactive.Async/Do.cs

@@ -58,6 +58,54 @@ namespace System.Linq
             return DoHelper(source, onNext, onError, onCompleted);
         }
 
+        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (onNext == null)
+                throw new ArgumentNullException(nameof(onNext));
+
+            return DoHelper(source, onNext, null, null);
+        }
+
+        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Task> onCompleted)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (onNext == null)
+                throw new ArgumentNullException(nameof(onNext));
+            if (onCompleted == null)
+                throw new ArgumentNullException(nameof(onCompleted));
+
+            return DoHelper(source, onNext, null, onCompleted);
+        }
+
+        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (onNext == null)
+                throw new ArgumentNullException(nameof(onNext));
+            if (onError == null)
+                throw new ArgumentNullException(nameof(onError));
+
+            return DoHelper(source, onNext, onError, null);
+        }
+
+        public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (onNext == null)
+                throw new ArgumentNullException(nameof(onNext));
+            if (onError == null)
+                throw new ArgumentNullException(nameof(onError));
+            if (onCompleted == null)
+                throw new ArgumentNullException(nameof(onCompleted));
+
+            return DoHelper(source, onNext, onError, onCompleted);
+        }
+
         public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
         {
             if (source == null)
@@ -65,7 +113,7 @@ namespace System.Linq
             if (observer == null)
                 throw new ArgumentNullException(nameof(observer));
 
-            return DoHelper(source, observer.OnNext, observer.OnError, observer.OnCompleted);
+            return DoHelper(source, new Action<TSource>(observer.OnNext), new Action<Exception>(observer.OnError), new Action(observer.OnCompleted));
         }
 
         private static IAsyncEnumerable<TSource> DoHelper<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
@@ -73,6 +121,11 @@ namespace System.Linq
             return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
         }
 
+        private static IAsyncEnumerable<TSource> DoHelper<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
+        {
+            return new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
+        }
+
         private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
         {
             private readonly Action onCompleted;
@@ -133,9 +186,9 @@ namespace System.Linq
                         {
                             throw;
                         }
-                        catch (Exception ex)
+                        catch (Exception ex) when (onError != null)
                         {
-                            onError?.Invoke(ex);
+                            onError(ex);
                             throw;
                         }
 
@@ -148,5 +201,84 @@ namespace System.Linq
                 return false;
             }
         }
+
+        private sealed class DoAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Func<Task> onCompleted;
+            private readonly Func<Exception, Task> onError;
+            private readonly Func<TSource, Task> onNext;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public DoAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(onNext != null);
+
+                this.source = source;
+                this.onNext = onNext;
+                this.onError = onError;
+                this.onCompleted = onCompleted;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        try
+                        {
+                            if (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                            {
+                                current = enumerator.Current;
+                                await onNext(current).ConfigureAwait(false);
+
+                                return true;
+                            }
+                        }
+                        catch (OperationCanceledException)
+                        {
+                            throw;
+                        }
+                        catch (Exception ex) when (onError != null)
+                        {
+                            await onError(ex).ConfigureAwait(false);
+                            throw;
+                        }
+
+                        if (onCompleted != null)
+                        {
+                            await onCompleted().ConfigureAwait(false);
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
+        }
     }
 }