// Do.cs (17 KB)
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. namespace System.Linq
  9. {
  10. public static partial class AsyncEnumerableEx
  11. {
  12. // REVIEW: Should we convert Task-based overloads to ValueTask?
  /// <summary>
  /// Wraps <paramref name="source"/> so that <paramref name="onNext"/> is invoked for each
  /// element as it is enumerated.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Action invoked with every element; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no error or completion callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      // Only the element callback is supplied; the other two slots stay empty.
      return DoCore(source, onNext, null, null);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> so that <paramref name="onNext"/> is invoked for each
  /// element and <paramref name="onCompleted"/> is invoked once the source is exhausted.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Action invoked with every element; must not be null.</param>
  /// <param name="onCompleted">Action invoked after the last element; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no error callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(source, onNext, null, onCompleted);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> so that <paramref name="onNext"/> is invoked for each
  /// element and <paramref name="onError"/> is invoked when enumeration fails.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Action invoked with every element; must not be null.</param>
  /// <param name="onError">Action invoked with the failure before it is rethrown; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no completion callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      return DoCore(source, onNext, onError, null);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> with callbacks for each element, for failure, and for
  /// normal completion.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Action invoked with every element; must not be null.</param>
  /// <param name="onError">Action invoked with the failure before it is rethrown; must not be null.</param>
  /// <param name="onCompleted">Action invoked after the last element; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(source, onNext: onNext, onError: onError, onCompleted: onCompleted);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> so that the asynchronous <paramref name="onNext"/> callback
  /// is awaited for each element as it is enumerated.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Asynchronous callback awaited for every element; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no error or completion callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      // Only the element callback is supplied; the other two slots stay empty.
      return DoCore(source, onNext, null, null);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> so that <paramref name="onNext"/> is awaited for each
  /// element and <paramref name="onCompleted"/> is awaited once the source is exhausted.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Asynchronous callback awaited for every element; must not be null.</param>
  /// <param name="onCompleted">Asynchronous callback awaited after the last element; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no error callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Task> onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(source, onNext, null, onCompleted);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> so that <paramref name="onNext"/> is awaited for each
  /// element and <paramref name="onError"/> is awaited when enumeration fails.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Asynchronous callback awaited for every element; must not be null.</param>
  /// <param name="onError">Asynchronous callback awaited with the failure before it is rethrown; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no completion callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      return DoCore(source, onNext, onError, null);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> with asynchronous callbacks for each element, for failure,
  /// and for normal completion.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Asynchronous callback awaited for every element; must not be null.</param>
  /// <param name="onError">Asynchronous callback awaited with the failure before it is rethrown; must not be null.</param>
  /// <param name="onCompleted">Asynchronous callback awaited after the last element; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(source, onNext: onNext, onError: onError, onCompleted: onCompleted);
  }
  93. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Wraps <paramref name="source"/> so that the cancellable asynchronous
  /// <paramref name="onNext"/> callback is awaited for each element as it is enumerated.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Asynchronous callback receiving each element and a cancellation token; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no error or completion callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      // Only the element callback is supplied; the other two slots stay empty.
      return DoCore(source, onNext, null, null);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> so that <paramref name="onNext"/> is awaited for each
  /// element and <paramref name="onCompleted"/> is awaited once the source is exhausted;
  /// both callbacks receive a cancellation token.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Asynchronous callback receiving each element and a cancellation token; must not be null.</param>
  /// <param name="onCompleted">Asynchronous callback awaited after the last element; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no error callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<CancellationToken, Task> onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(source, onNext, null, onCompleted);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> so that <paramref name="onNext"/> is awaited for each
  /// element and <paramref name="onError"/> is awaited when enumeration fails; both callbacks
  /// receive a cancellation token.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Asynchronous callback receiving each element and a cancellation token; must not be null.</param>
  /// <param name="onError">Asynchronous callback awaited with the failure before it is rethrown; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>, with no completion callback.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      return DoCore(source, onNext, onError, null);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> with cancellable asynchronous callbacks for each element,
  /// for failure, and for normal completion.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="onNext">Asynchronous callback receiving each element and a cancellation token; must not be null.</param>
  /// <param name="onError">Asynchronous callback awaited with the failure before it is rethrown; must not be null.</param>
  /// <param name="onCompleted">Asynchronous callback awaited after the last element; must not be null.</param>
  /// <returns>The sequence built by <c>DoCore</c>.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(source, onNext: onNext, onError: onError, onCompleted: onCompleted);
  }
  134. #endif
  /// <summary>
  /// Wraps <paramref name="source"/> so that the members of <paramref name="observer"/> are
  /// used as the element, error, and completion callbacks.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to observe; must not be null.</param>
  /// <param name="observer">Observer whose OnNext, OnError and OnCompleted methods are called; must not be null.</param>
  /// <returns>The sequence built by the synchronous <c>DoCore</c> overload.</returns>
  /// <remarks>A null argument raises the exception produced by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (observer is null)
      {
          throw Error.ArgumentNull(nameof(observer));
      }

      // Bind each observer method to an explicitly typed delegate so the synchronous
      // DoCore overload is the one selected.
      Action<TSource> next = observer.OnNext;
      Action<Exception> error = observer.OnError;
      Action completed = observer.OnCompleted;

      return DoCore(source, next, error, completed);
  }
  /// <summary>
  /// Creates the iterator that applies synchronous callbacks to <paramref name="source"/>.
  /// </summary>
  /// <param name="source">Sequence to wrap.</param>
  /// <param name="onNext">Callback for each element.</param>
  /// <param name="onError">Callback for failures; the public overloads pass null when none was requested.</param>
  /// <param name="onCompleted">Callback for completion; the public overloads pass null when none was requested.</param>
  /// <returns>A new <c>DoAsyncIterator</c> over <paramref name="source"/>.</returns>
  private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
      => new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
  /// <summary>
  /// Creates the iterator that applies Task-returning callbacks to <paramref name="source"/>.
  /// </summary>
  /// <param name="source">Sequence to wrap.</param>
  /// <param name="onNext">Asynchronous callback for each element.</param>
  /// <param name="onError">Asynchronous callback for failures; the public overloads pass null when none was requested.</param>
  /// <param name="onCompleted">Asynchronous callback for completion; the public overloads pass null when none was requested.</param>
  /// <returns>A new <c>DoAsyncIteratorWithTask</c> over <paramref name="source"/>.</returns>
  private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
      => new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
  151. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Creates the iterator that applies cancellable Task-returning callbacks to
  /// <paramref name="source"/>.
  /// </summary>
  /// <param name="source">Sequence to wrap.</param>
  /// <param name="onNext">Asynchronous callback for each element.</param>
  /// <param name="onError">Asynchronous callback for failures; the public overloads pass null when none was requested.</param>
  /// <param name="onCompleted">Asynchronous callback for completion; the public overloads pass null when none was requested.</param>
  /// <returns>A new <c>DoAsyncIteratorWithTaskAndCancellation</c> over <paramref name="source"/>.</returns>
  private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
      => new DoAsyncIteratorWithTaskAndCancellation<TSource>(source, onNext, onError, onCompleted);
  156. #endif
  /// <summary>
  /// Iterator that forwards the elements of a source sequence while invoking synchronous
  /// callbacks for each element, for failures, and for normal completion.
  /// </summary>
  /// <remarks>
  /// <c>_state</c>, <c>_current</c> and <c>_cancellationToken</c> are not declared in this
  /// class; presumably they come from the <c>AsyncIterator</c> base type (not visible here).
  /// </remarks>
  private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly Action<TSource> _onNext;
      private readonly Action<Exception> _onError;   // null when no error callback was requested
      private readonly Action _onCompleted;          // null when no completion callback was requested

      // Enumerator over _source; created by the first MoveNextCore call, cleared by DisposeAsync.
      private IAsyncEnumerator<TSource> _enumerator;

      /// <summary>Stores the source and callbacks; only source and onNext are asserted non-null.</summary>
      public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
      {
          Debug.Assert(source != null);
          Debug.Assert(onNext != null);

          _onCompleted = onCompleted;
          _onError = onError;
          _onNext = onNext;
          _source = source;
      }

      /// <summary>Creates a fresh iterator over the same source with the same callbacks.</summary>
      public override AsyncIteratorBase<TSource> Clone()
          => new DoAsyncIterator<TSource>(_source, _onNext, _onError, _onCompleted);

      /// <summary>Disposes the source enumerator, if one was created, then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;

          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances the source by one element and invokes the element callback on it.
      /// </summary>
      /// <returns>
      /// <c>true</c> when an element was produced; <c>false</c> once the source is exhausted or
      /// when the iterator is in neither the Allocated nor the Iterating state.
      /// </returns>
      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the enumerator before switching to the iterating state.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          try
          {
              var advanced = await _enumerator.MoveNextAsync().ConfigureAwait(false);

              if (advanced)
              {
                  var item = _enumerator.Current;
                  _current = item;
                  _onNext(item);
                  return true;
              }
          }
          catch (OperationCanceledException)
          {
              // Cancellation bypasses the error callback.
              throw;
          }
          catch (Exception ex) when (_onError != null)
          {
              // Report the failure (from the source or from onNext), then let it propagate.
              _onError(ex);
              throw;
          }

          // Source exhausted: signal completion, then release the enumerator.
          if (_onCompleted != null)
          {
              _onCompleted();
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Iterator that forwards the elements of a source sequence while awaiting Task-returning
  /// callbacks for each element, for failures, and for normal completion.
  /// </summary>
  /// <remarks>
  /// <c>_state</c>, <c>_current</c> and <c>_cancellationToken</c> are not declared in this
  /// class; presumably they come from the <c>AsyncIterator</c> base type (not visible here).
  /// </remarks>
  private sealed class DoAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly Func<TSource, Task> _onNext;
      private readonly Func<Exception, Task> _onError;   // null when no error callback was requested
      private readonly Func<Task> _onCompleted;          // null when no completion callback was requested

      // Enumerator over _source; created by the first MoveNextCore call, cleared by DisposeAsync.
      private IAsyncEnumerator<TSource> _enumerator;

      /// <summary>Stores the source and callbacks; only source and onNext are asserted non-null.</summary>
      public DoAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
      {
          Debug.Assert(source != null);
          Debug.Assert(onNext != null);

          _onCompleted = onCompleted;
          _onError = onError;
          _onNext = onNext;
          _source = source;
      }

      /// <summary>Creates a fresh iterator over the same source with the same callbacks.</summary>
      public override AsyncIteratorBase<TSource> Clone()
          => new DoAsyncIteratorWithTask<TSource>(_source, _onNext, _onError, _onCompleted);

      /// <summary>Disposes the source enumerator, if one was created, then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;

          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances the source by one element and awaits the element callback on it.
      /// </summary>
      /// <returns>
      /// <c>true</c> when an element was produced; <c>false</c> once the source is exhausted or
      /// when the iterator is in neither the Allocated nor the Iterating state.
      /// </returns>
      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the enumerator before switching to the iterating state.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          try
          {
              var advanced = await _enumerator.MoveNextAsync().ConfigureAwait(false);

              if (advanced)
              {
                  var item = _enumerator.Current;
                  _current = item;
                  await _onNext(item).ConfigureAwait(false);
                  return true;
              }
          }
          catch (OperationCanceledException)
          {
              // Cancellation bypasses the error callback.
              throw;
          }
          catch (Exception ex) when (_onError != null)
          {
              // Report the failure (from the source or from onNext), then let it propagate.
              await _onError(ex).ConfigureAwait(false);
              throw;
          }

          // Source exhausted: signal completion, then release the enumerator.
          var completed = _onCompleted;

          if (completed != null)
          {
              await completed().ConfigureAwait(false);
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  286. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Iterator that forwards the elements of a source sequence while awaiting Task-returning
  /// callbacks for each element, for failures, and for normal completion, passing the
  /// iterator's cancellation token to every callback.
  /// </summary>
  /// <remarks>
  /// <c>_state</c>, <c>_current</c> and <c>_cancellationToken</c> are not declared in this
  /// class; presumably they come from the <c>AsyncIterator</c> base type (not visible here).
  /// </remarks>
  private sealed class DoAsyncIteratorWithTaskAndCancellation<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly Func<TSource, CancellationToken, Task> _onNext;
      private readonly Func<Exception, CancellationToken, Task> _onError;   // null when no error callback was requested
      private readonly Func<CancellationToken, Task> _onCompleted;          // null when no completion callback was requested

      // Enumerator over _source; created by the first MoveNextCore call, cleared by DisposeAsync.
      private IAsyncEnumerator<TSource> _enumerator;

      /// <summary>Stores the source and callbacks; only source and onNext are asserted non-null.</summary>
      public DoAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
      {
          Debug.Assert(source != null);
          Debug.Assert(onNext != null);

          _onCompleted = onCompleted;
          _onError = onError;
          _onNext = onNext;
          _source = source;
      }

      /// <summary>Creates a fresh iterator over the same source with the same callbacks.</summary>
      public override AsyncIteratorBase<TSource> Clone()
          => new DoAsyncIteratorWithTaskAndCancellation<TSource>(_source, _onNext, _onError, _onCompleted);

      /// <summary>Disposes the source enumerator, if one was created, then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;

          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances the source by one element and awaits the element callback on it.
      /// </summary>
      /// <returns>
      /// <c>true</c> when an element was produced; <c>false</c> once the source is exhausted or
      /// when the iterator is in neither the Allocated nor the Iterating state.
      /// </returns>
      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the enumerator before switching to the iterating state.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          try
          {
              var advanced = await _enumerator.MoveNextAsync().ConfigureAwait(false);

              if (advanced)
              {
                  var item = _enumerator.Current;
                  _current = item;
                  await _onNext(item, _cancellationToken).ConfigureAwait(false);
                  return true;
              }
          }
          catch (OperationCanceledException)
          {
              // Cancellation bypasses the error callback.
              throw;
          }
          catch (Exception ex) when (_onError != null)
          {
              // Report the failure (from the source or from onNext), then let it propagate.
              await _onError(ex, _cancellationToken).ConfigureAwait(false);
              throw;
          }

          // Source exhausted: signal completion, then release the enumerator.
          var completed = _onCompleted;

          if (completed != null)
          {
              await completed(_cancellationToken).ConfigureAwait(false);
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  353. #endif
  354. }
  355. }