Do.cs 11 KB


  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  namespace System.Linq
  {
      public static partial class AsyncEnumerableEx
      {
          /// <summary>
          /// Invokes <paramref name="onNext"/> for each element of <paramref name="source"/> as the
          /// returned sequence is enumerated; elements are passed through unchanged.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="onNext">Called with each element before that element is yielded.</param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (onNext == null)
                  throw Error.ArgumentNull(nameof(onNext));

              return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
          }

          /// <summary>
          /// Invokes <paramref name="onNext"/> for each element of <paramref name="source"/> and
          /// <paramref name="onCompleted"/> when the source completes successfully.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="onNext">Called with each element before that element is yielded.</param>
          /// <param name="onCompleted">Called once after the source is exhausted, before its enumerator is disposed.</param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (onNext == null)
                  throw Error.ArgumentNull(nameof(onNext));
              if (onCompleted == null)
                  throw Error.ArgumentNull(nameof(onCompleted));

              return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
          }

          /// <summary>
          /// Invokes <paramref name="onNext"/> for each element of <paramref name="source"/> and
          /// <paramref name="onError"/> if enumeration fails.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="onNext">Called with each element before that element is yielded.</param>
          /// <param name="onError">
          /// Called with any exception other than <see cref="OperationCanceledException"/> raised while
          /// advancing the source or by <paramref name="onNext"/>; the exception is rethrown afterwards.
          /// </param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (onNext == null)
                  throw Error.ArgumentNull(nameof(onNext));
              if (onError == null)
                  throw Error.ArgumentNull(nameof(onError));

              return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
          }

          /// <summary>
          /// Invokes <paramref name="onNext"/> per element, <paramref name="onError"/> on failure and
          /// <paramref name="onCompleted"/> on successful completion of <paramref name="source"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="onNext">Called with each element before that element is yielded.</param>
          /// <param name="onError">
          /// Called with any exception other than <see cref="OperationCanceledException"/> raised while
          /// advancing the source or by <paramref name="onNext"/>; the exception is rethrown afterwards.
          /// </param>
          /// <param name="onCompleted">Called once after the source is exhausted, before its enumerator is disposed.</param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (onNext == null)
                  throw Error.ArgumentNull(nameof(onNext));
              if (onError == null)
                  throw Error.ArgumentNull(nameof(onError));
              if (onCompleted == null)
                  throw Error.ArgumentNull(nameof(onCompleted));

              return DoCore(source, onNext, onError, onCompleted);
          }

          /// <summary>
          /// Awaits <paramref name="onNext"/> for each element of <paramref name="source"/> as the
          /// returned sequence is enumerated; elements are passed through unchanged.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="onNext">Asynchronous callback per element; its task is awaited before the element is yielded.</param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (onNext == null)
                  throw Error.ArgumentNull(nameof(onNext));

              return DoCore(source, onNext: onNext, onError: null, onCompleted: null);
          }

          /// <summary>
          /// Awaits <paramref name="onNext"/> for each element of <paramref name="source"/> and
          /// <paramref name="onCompleted"/> when the source completes successfully.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="onNext">Asynchronous callback per element; its task is awaited before the element is yielded.</param>
          /// <param name="onCompleted">Awaited once after the source is exhausted, before its enumerator is disposed.</param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Task> onCompleted)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (onNext == null)
                  throw Error.ArgumentNull(nameof(onNext));
              if (onCompleted == null)
                  throw Error.ArgumentNull(nameof(onCompleted));

              return DoCore(source, onNext: onNext, onError: null, onCompleted: onCompleted);
          }

          /// <summary>
          /// Awaits <paramref name="onNext"/> for each element of <paramref name="source"/> and
          /// <paramref name="onError"/> if enumeration fails.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="onNext">Asynchronous callback per element; its task is awaited before the element is yielded.</param>
          /// <param name="onError">
          /// Awaited with any exception other than <see cref="OperationCanceledException"/> raised while
          /// advancing the source or by <paramref name="onNext"/>; the exception is rethrown afterwards.
          /// </param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (onNext == null)
                  throw Error.ArgumentNull(nameof(onNext));
              if (onError == null)
                  throw Error.ArgumentNull(nameof(onError));

              return DoCore(source, onNext: onNext, onError: onError, onCompleted: null);
          }

          /// <summary>
          /// Awaits <paramref name="onNext"/> per element, <paramref name="onError"/> on failure and
          /// <paramref name="onCompleted"/> on successful completion of <paramref name="source"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="onNext">Asynchronous callback per element; its task is awaited before the element is yielded.</param>
          /// <param name="onError">
          /// Awaited with any exception other than <see cref="OperationCanceledException"/> raised while
          /// advancing the source or by <paramref name="onNext"/>; the exception is rethrown afterwards.
          /// </param>
          /// <param name="onCompleted">Awaited once after the source is exhausted, before its enumerator is disposed.</param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (onNext == null)
                  throw Error.ArgumentNull(nameof(onNext));
              if (onError == null)
                  throw Error.ArgumentNull(nameof(onError));
              if (onCompleted == null)
                  throw Error.ArgumentNull(nameof(onCompleted));

              return DoCore(source, onNext, onError, onCompleted);
          }

          /// <summary>
          /// Forwards the elements, failure and completion of <paramref name="source"/> to
          /// <paramref name="observer"/> as the returned sequence is enumerated.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence to observe; it is not enumerated until the result is.</param>
          /// <param name="observer">
          /// Receives <c>OnNext</c> for each element, <c>OnError</c> for any exception other than
          /// <see cref="OperationCanceledException"/> (which is then rethrown), and <c>OnCompleted</c>
          /// once the source is exhausted.
          /// </param>
          /// <returns>A sequence that yields the same elements as <paramref name="source"/>.</returns>
          /// <remarks>Arguments are null-checked eagerly, at call time, via <c>Error.ArgumentNull</c>.</remarks>
          public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
          {
              if (source == null)
                  throw Error.ArgumentNull(nameof(source));
              if (observer == null)
                  throw Error.ArgumentNull(nameof(observer));

              // Explicit delegate types select the synchronous DoCore overload.
              return DoCore(source, new Action<TSource>(observer.OnNext), new Action<Exception>(observer.OnError), new Action(observer.OnCompleted));
          }

          // Shared implementation for the synchronous-callback overloads; onError/onCompleted may be null.
          private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
          {
              return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
          }

          // Shared implementation for the Task-returning-callback overloads; onError/onCompleted may be null.
          private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
          {
              return new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
          }

          // Pass-through iterator that runs synchronous side-effect callbacks around a source sequence.
          private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
          {
              private readonly Action _onCompleted;        // may be null
              private readonly Action<Exception> _onError; // may be null
              private readonly Action<TSource> _onNext;
              private readonly IAsyncEnumerable<TSource> _source;

              // Created on the first MoveNext; cleared when the iterator is disposed.
              private IAsyncEnumerator<TSource> _enumerator;

              public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
              {
                  Debug.Assert(source != null);
                  Debug.Assert(onNext != null);

                  _source = source;
                  _onNext = onNext;
                  _onError = onError;
                  _onCompleted = onCompleted;
              }

              // Creates a fresh iterator over the same source and callbacks.
              public override AsyncIterator<TSource> Clone()
              {
                  return new DoAsyncIterator<TSource>(_source, _onNext, _onError, _onCompleted);
              }

              public override async ValueTask DisposeAsync()
              {
                  // Detach the enumerator before disposing it so that a repeated DisposeAsync call
                  // (e.g. after the source's own DisposeAsync threw) never disposes it twice.
                  var enumerator = _enumerator;
                  _enumerator = null;

                  try
                  {
                      if (enumerator != null)
                      {
                          await enumerator.DisposeAsync().ConfigureAwait(false);
                      }
                  }
                  finally
                  {
                      // Always complete the base cleanup, even if the source's disposal throws.
                      await base.DisposeAsync().ConfigureAwait(false);
                  }
              }

              protected override async ValueTask<bool> MoveNextCore()
              {
                  switch (state)
                  {
                      case AsyncIteratorState.Allocated:
                          _enumerator = _source.GetAsyncEnumerator(cancellationToken);
                          state = AsyncIteratorState.Iterating;
                          goto case AsyncIteratorState.Iterating;

                      case AsyncIteratorState.Iterating:
                          try
                          {
                              if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                              {
                                  current = _enumerator.Current;
                                  _onNext(current);
                                  return true;
                              }
                          }
                          catch (OperationCanceledException)
                          {
                              // Cancellation bypasses onError and propagates unchanged.
                              throw;
                          }
                          catch (Exception ex) when (_onError != null)
                          {
                              // Covers failures from the source and from onNext alike.
                              _onError(ex);
                              throw;
                          }

                          try
                          {
                              _onCompleted?.Invoke();
                          }
                          finally
                          {
                              // Release the source (and mark this iterator disposed) even if
                              // onCompleted throws, so completion cannot be reported twice.
                              await DisposeAsync().ConfigureAwait(false);
                          }

                          break;
                  }

                  return false;
              }
          }

          // Pass-through iterator that awaits Task-returning side-effect callbacks around a source sequence.
          private sealed class DoAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
          {
              private readonly Func<Task> _onCompleted;        // may be null
              private readonly Func<Exception, Task> _onError; // may be null
              private readonly Func<TSource, Task> _onNext;
              private readonly IAsyncEnumerable<TSource> _source;

              // Created on the first MoveNext; cleared when the iterator is disposed.
              private IAsyncEnumerator<TSource> _enumerator;

              public DoAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
              {
                  Debug.Assert(source != null);
                  Debug.Assert(onNext != null);

                  _source = source;
                  _onNext = onNext;
                  _onError = onError;
                  _onCompleted = onCompleted;
              }

              // Creates a fresh iterator over the same source and callbacks.
              public override AsyncIterator<TSource> Clone()
              {
                  return new DoAsyncIteratorWithTask<TSource>(_source, _onNext, _onError, _onCompleted);
              }

              public override async ValueTask DisposeAsync()
              {
                  // Detach the enumerator before disposing it so that a repeated DisposeAsync call
                  // (e.g. after the source's own DisposeAsync threw) never disposes it twice.
                  var enumerator = _enumerator;
                  _enumerator = null;

                  try
                  {
                      if (enumerator != null)
                      {
                          await enumerator.DisposeAsync().ConfigureAwait(false);
                      }
                  }
                  finally
                  {
                      // Always complete the base cleanup, even if the source's disposal throws.
                      await base.DisposeAsync().ConfigureAwait(false);
                  }
              }

              protected override async ValueTask<bool> MoveNextCore()
              {
                  switch (state)
                  {
                      case AsyncIteratorState.Allocated:
                          _enumerator = _source.GetAsyncEnumerator(cancellationToken);
                          state = AsyncIteratorState.Iterating;
                          goto case AsyncIteratorState.Iterating;

                      case AsyncIteratorState.Iterating:
                          try
                          {
                              if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                              {
                                  current = _enumerator.Current;
                                  await _onNext(current).ConfigureAwait(false);
                                  return true;
                              }
                          }
                          catch (OperationCanceledException)
                          {
                              // Cancellation bypasses onError and propagates unchanged.
                              throw;
                          }
                          catch (Exception ex) when (_onError != null)
                          {
                              // Covers failures from the source and from onNext alike.
                              await _onError(ex).ConfigureAwait(false);
                              throw;
                          }

                          try
                          {
                              if (_onCompleted != null)
                              {
                                  await _onCompleted().ConfigureAwait(false);
                              }
                          }
                          finally
                          {
                              // Release the source (and mark this iterator disposed) even if
                              // onCompleted throws, so completion cannot be reported twice.
                              await DisposeAsync().ConfigureAwait(false);
                          }

                          break;
                  }

                  return false;
              }
          }
      }
  }