Do.cs 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. namespace System.Linq
  9. {
  10. public static partial class AsyncEnumerableEx
  11. {
  12. // REVIEW: Should we convert Task-based overloads to ValueTask?
  /// <summary>
  /// Attaches a synchronous side-effect callback that is run for every element of the sequence.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">Callback invoked with each element before it is handed to the consumer; must not be null.</param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without error or completion callbacks.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Action<TSource> onNext)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      // Only the element callback is supplied; the other two hooks stay disabled.
      return DoCore(
          source,
          onNext: onNext,
          onError: null,
          onCompleted: null);
  }
  /// <summary>
  /// Attaches synchronous side-effect callbacks for each element and for successful completion of the sequence.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">Callback invoked with each element before it is handed to the consumer; must not be null.</param>
  /// <param name="onCompleted">Callback invoked once the source reports that it has no more elements; must not be null.</param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without an error callback.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Action<TSource> onNext,
      Action onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      // No error hook for this overload.
      return DoCore(
          source,
          onNext: onNext,
          onError: null,
          onCompleted: onCompleted);
  }
  /// <summary>
  /// Attaches synchronous side-effect callbacks for each element and for failures of the sequence.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">Callback invoked with each element before it is handed to the consumer; must not be null.</param>
  /// <param name="onError">
  /// Callback invoked with an exception (other than <see cref="OperationCanceledException"/>) raised while
  /// advancing the source or running <paramref name="onNext"/>; the exception is rethrown afterwards. Must not be null.
  /// </param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without a completion callback.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Action<TSource> onNext,
      Action<Exception> onError)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      // No completion hook for this overload.
      return DoCore(
          source,
          onNext: onNext,
          onError: onError,
          onCompleted: null);
  }
  /// <summary>
  /// Attaches synchronous side-effect callbacks for each element, for failures, and for successful completion.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">Callback invoked with each element before it is handed to the consumer; must not be null.</param>
  /// <param name="onError">
  /// Callback invoked with an exception (other than <see cref="OperationCanceledException"/>) raised while
  /// advancing the source or running <paramref name="onNext"/>; the exception is rethrown afterwards. Must not be null.
  /// </param>
  /// <param name="onCompleted">Callback invoked once the source reports that it has no more elements; must not be null.</param>
  /// <returns>The sequence produced by <c>DoCore</c> with all three callbacks attached.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Action<TSource> onNext,
      Action<Exception> onError,
      Action onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(
          source,
          onNext: onNext,
          onError: onError,
          onCompleted: onCompleted);
  }
  /// <summary>
  /// Attaches an asynchronous side-effect callback that is awaited for every element of the sequence.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">Task-returning callback awaited with each element before it is handed to the consumer; must not be null.</param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without error or completion callbacks.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Func<TSource, Task> onNext)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      // Only the element callback is supplied; the other two hooks stay disabled.
      return DoCore(
          source,
          onNext: onNext,
          onError: null,
          onCompleted: null);
  }
  /// <summary>
  /// Attaches asynchronous side-effect callbacks for each element and for successful completion of the sequence.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">Task-returning callback awaited with each element before it is handed to the consumer; must not be null.</param>
  /// <param name="onCompleted">Task-returning callback awaited once the source reports that it has no more elements; must not be null.</param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without an error callback.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Func<TSource, Task> onNext,
      Func<Task> onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      // No error hook for this overload.
      return DoCore(
          source,
          onNext: onNext,
          onError: null,
          onCompleted: onCompleted);
  }
  /// <summary>
  /// Attaches asynchronous side-effect callbacks for each element and for failures of the sequence.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">Task-returning callback awaited with each element before it is handed to the consumer; must not be null.</param>
  /// <param name="onError">
  /// Task-returning callback awaited with an exception (other than <see cref="OperationCanceledException"/>)
  /// raised while advancing the source or running <paramref name="onNext"/>; the exception is rethrown afterwards.
  /// Must not be null.
  /// </param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without a completion callback.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Func<TSource, Task> onNext,
      Func<Exception, Task> onError)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      // No completion hook for this overload.
      return DoCore(
          source,
          onNext: onNext,
          onError: onError,
          onCompleted: null);
  }
  /// <summary>
  /// Attaches asynchronous side-effect callbacks for each element, for failures, and for successful completion.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">Task-returning callback awaited with each element before it is handed to the consumer; must not be null.</param>
  /// <param name="onError">
  /// Task-returning callback awaited with an exception (other than <see cref="OperationCanceledException"/>)
  /// raised while advancing the source or running <paramref name="onNext"/>; the exception is rethrown afterwards.
  /// Must not be null.
  /// </param>
  /// <param name="onCompleted">Task-returning callback awaited once the source reports that it has no more elements; must not be null.</param>
  /// <returns>The sequence produced by <c>DoCore</c> with all three callbacks attached.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Func<TSource, Task> onNext,
      Func<Exception, Task> onError,
      Func<Task> onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(
          source,
          onNext: onNext,
          onError: onError,
          onCompleted: onCompleted);
  }
  93. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Attaches an asynchronous, cancellation-aware side-effect callback that is awaited for every element.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">
  /// Task-returning callback awaited with each element and the cancellation token used for the enumeration,
  /// before the element is handed to the consumer; must not be null.
  /// </param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without error or completion callbacks.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Func<TSource, CancellationToken, Task> onNext)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      // Only the element callback is supplied; the other two hooks stay disabled.
      return DoCore(
          source,
          onNext: onNext,
          onError: null,
          onCompleted: null);
  }
  /// <summary>
  /// Attaches asynchronous, cancellation-aware side-effect callbacks for each element and for successful completion.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">
  /// Task-returning callback awaited with each element and the cancellation token used for the enumeration,
  /// before the element is handed to the consumer; must not be null.
  /// </param>
  /// <param name="onCompleted">
  /// Task-returning callback awaited with the enumeration's cancellation token once the source reports that it
  /// has no more elements; must not be null.
  /// </param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without an error callback.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Func<TSource, CancellationToken, Task> onNext,
      Func<CancellationToken, Task> onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      // No error hook for this overload.
      return DoCore(
          source,
          onNext: onNext,
          onError: null,
          onCompleted: onCompleted);
  }
  /// <summary>
  /// Attaches asynchronous, cancellation-aware side-effect callbacks for each element and for failures.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">
  /// Task-returning callback awaited with each element and the cancellation token used for the enumeration,
  /// before the element is handed to the consumer; must not be null.
  /// </param>
  /// <param name="onError">
  /// Task-returning callback awaited with an exception (other than <see cref="OperationCanceledException"/>)
  /// and the enumeration's cancellation token when advancing the source or running <paramref name="onNext"/>
  /// fails; the exception is rethrown afterwards. Must not be null.
  /// </param>
  /// <returns>The sequence produced by <c>DoCore</c>, configured without a completion callback.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Func<TSource, CancellationToken, Task> onNext,
      Func<Exception, CancellationToken, Task> onError)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      // No completion hook for this overload.
      return DoCore(
          source,
          onNext: onNext,
          onError: onError,
          onCompleted: null);
  }
  /// <summary>
  /// Attaches asynchronous, cancellation-aware side-effect callbacks for each element, for failures,
  /// and for successful completion.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="onNext">
  /// Task-returning callback awaited with each element and the cancellation token used for the enumeration,
  /// before the element is handed to the consumer; must not be null.
  /// </param>
  /// <param name="onError">
  /// Task-returning callback awaited with an exception (other than <see cref="OperationCanceledException"/>)
  /// and the enumeration's cancellation token when advancing the source or running <paramref name="onNext"/>
  /// fails; the exception is rethrown afterwards. Must not be null.
  /// </param>
  /// <param name="onCompleted">
  /// Task-returning callback awaited with the enumeration's cancellation token once the source reports that it
  /// has no more elements; must not be null.
  /// </param>
  /// <returns>The sequence produced by <c>DoCore</c> with all three callbacks attached.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      Func<TSource, CancellationToken, Task> onNext,
      Func<Exception, CancellationToken, Task> onError,
      Func<CancellationToken, Task> onCompleted)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (onNext is null)
      {
          throw Error.ArgumentNull(nameof(onNext));
      }

      if (onError is null)
      {
          throw Error.ArgumentNull(nameof(onError));
      }

      if (onCompleted is null)
      {
          throw Error.ArgumentNull(nameof(onCompleted));
      }

      return DoCore(
          source,
          onNext: onNext,
          onError: onError,
          onCompleted: onCompleted);
  }
  134. #endif
  /// <summary>
  /// Mirrors the sequence's elements, failure, and completion onto an <see cref="IObserver{T}"/>.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence whose elements are observed; must not be null.</param>
  /// <param name="observer">
  /// Observer whose <c>OnNext</c>, <c>OnError</c> and <c>OnCompleted</c> methods are used as the
  /// element, error, and completion callbacks; must not be null.
  /// </param>
  /// <returns>The sequence produced by the synchronous <c>DoCore</c> overload with the observer's methods attached.</returns>
  /// <remarks>A null argument is rejected by throwing the exception created by <c>Error.ArgumentNull</c>.</remarks>
  public static IAsyncEnumerable<TSource> Do<TSource>(
      this IAsyncEnumerable<TSource> source,
      IObserver<TSource> observer)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      if (observer is null)
      {
          throw Error.ArgumentNull(nameof(observer));
      }

      // Wrap each observer method in a delegate so the synchronous DoCore overload is selected.
      Action<TSource> forwardNext = observer.OnNext;
      Action<Exception> forwardError = observer.OnError;
      Action forwardCompleted = observer.OnCompleted;

      return DoCore(source, forwardNext, forwardError, forwardCompleted);
  }
  /// <summary>
  /// Builds the sequence behind the <c>Do</c> overloads that take synchronous callbacks.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to enumerate.</param>
  /// <param name="onNext">Invoked with each element before it is yielded.</param>
  /// <param name="onError">
  /// Optional (may be null). Invoked with any exception other than <see cref="OperationCanceledException"/>
  /// thrown while advancing the source or running <paramref name="onNext"/>; the exception is then rethrown.
  /// </param>
  /// <param name="onCompleted">Optional (may be null). Invoked after the source reports no more elements.</param>
  /// <returns>
  /// With <c>USE_ASYNC_ITERATOR</c> defined, a sequence created from a compiler-generated async iterator;
  /// otherwise a <c>DoAsyncIterator</c> instance.
  /// </returns>
  private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
  {
  #if USE_ASYNC_ITERATOR
      return AsyncEnumerable.Create(Core);

      async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
      {
          // The enumerator is released when this local function exits, i.e. after onCompleted has run.
          await using var enumerator = source.GetConfiguredAsyncEnumerator(cancellationToken, false);

          for (; ; )
          {
              TSource element;

              try
              {
                  var hasNext = await enumerator.MoveNextAsync();
                  if (!hasNext)
                  {
                      break;
                  }

                  element = enumerator.Current;
                  onNext(element);
              }
              catch (OperationCanceledException)
              {
                  // Cancellation bypasses the error callback.
                  throw;
              }
              catch (Exception error) when (onError != null)
              {
                  onError(error);
                  throw;
              }

              // The yield sits outside the try block, so only source/onNext failures reach onError.
              yield return element;
          }

          if (onCompleted != null)
          {
              onCompleted();
          }
      }
  #else
      return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
  #endif
  }
  /// <summary>
  /// Builds the sequence behind the <c>Do</c> overloads that take Task-returning callbacks.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to enumerate.</param>
  /// <param name="onNext">Awaited with each element before it is yielded.</param>
  /// <param name="onError">
  /// Optional (may be null). Awaited with any exception other than <see cref="OperationCanceledException"/>
  /// thrown while advancing the source or running <paramref name="onNext"/>; the exception is then rethrown.
  /// </param>
  /// <param name="onCompleted">Optional (may be null). Awaited after the source reports no more elements.</param>
  /// <returns>
  /// With <c>USE_ASYNC_ITERATOR</c> defined, a sequence created from a compiler-generated async iterator;
  /// otherwise a <c>DoAsyncIteratorWithTask</c> instance.
  /// </returns>
  private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
  {
  #if USE_ASYNC_ITERATOR
      return AsyncEnumerable.Create(Core);

      async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
      {
          // The enumerator is released when this local function exits, i.e. after onCompleted has been awaited.
          await using var enumerator = source.GetConfiguredAsyncEnumerator(cancellationToken, false);

          for (; ; )
          {
              TSource element;

              try
              {
                  var hasNext = await enumerator.MoveNextAsync();
                  if (!hasNext)
                  {
                      break;
                  }

                  element = enumerator.Current;
                  await onNext(element).ConfigureAwait(false);
              }
              catch (OperationCanceledException)
              {
                  // Cancellation bypasses the error callback.
                  throw;
              }
              catch (Exception error) when (onError != null)
              {
                  await onError(error).ConfigureAwait(false);
                  throw;
              }

              // The yield sits outside the try block, so only source/onNext failures reach onError.
              yield return element;
          }

          if (onCompleted == null)
          {
              yield break;
          }

          await onCompleted().ConfigureAwait(false);
      }
  #else
      return new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
  #endif
  }
  222. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Builds the sequence behind the <c>Do</c> overloads that take Task-returning, cancellation-aware callbacks.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">Sequence to enumerate.</param>
  /// <param name="onNext">Awaited with each element and the enumeration's cancellation token before the element is yielded.</param>
  /// <param name="onError">
  /// Optional (may be null). Awaited with any exception other than <see cref="OperationCanceledException"/>
  /// thrown while advancing the source or running <paramref name="onNext"/>, together with the enumeration's
  /// cancellation token; the exception is then rethrown.
  /// </param>
  /// <param name="onCompleted">
  /// Optional (may be null). Awaited with the enumeration's cancellation token after the source reports no more elements.
  /// </param>
  /// <returns>
  /// With <c>USE_ASYNC_ITERATOR</c> defined, a sequence created from a compiler-generated async iterator;
  /// otherwise a <c>DoAsyncIteratorWithTaskAndCancellation</c> instance.
  /// </returns>
  private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
  {
  #if USE_ASYNC_ITERATOR
      return AsyncEnumerable.Create(Core);

      async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
      {
          // The enumerator is released when this local function exits, i.e. after onCompleted has been awaited.
          await using var enumerator = source.GetConfiguredAsyncEnumerator(cancellationToken, false);

          for (; ; )
          {
              TSource element;

              try
              {
                  var hasNext = await enumerator.MoveNextAsync();
                  if (!hasNext)
                  {
                      break;
                  }

                  element = enumerator.Current;
                  await onNext(element, cancellationToken).ConfigureAwait(false);
              }
              catch (OperationCanceledException)
              {
                  // Cancellation bypasses the error callback.
                  throw;
              }
              catch (Exception error) when (onError != null)
              {
                  await onError(error, cancellationToken).ConfigureAwait(false);
                  throw;
              }

              // The yield sits outside the try block, so only source/onNext failures reach onError.
              yield return element;
          }

          if (onCompleted == null)
          {
              yield break;
          }

          await onCompleted(cancellationToken).ConfigureAwait(false);
      }
  #else
      return new DoAsyncIteratorWithTaskAndCancellation<TSource>(source, onNext, onError, onCompleted);
  #endif
  }
  264. #endif
  265. #if !USE_ASYNC_ITERATOR
  /// <summary>
  /// Hand-written iterator returned by the synchronous <c>DoCore</c> overload when
  /// <c>USE_ASYNC_ITERATOR</c> is not defined. Runs synchronous callbacks around the wrapped sequence.
  /// </summary>
  /// <remarks>
  /// <c>_state</c>, <c>_current</c> and <c>_cancellationToken</c> are not declared here; they are presumably
  /// provided by the <c>AsyncIterator</c> base class, which is outside this file.
  /// </remarks>
  private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly Action<TSource> _onNext;
      private readonly Action<Exception> _onError;   // may be null
      private readonly Action _onCompleted;          // may be null

      // Enumerator over _source; created on the first MoveNextCore call, cleared on disposal.
      private IAsyncEnumerator<TSource> _enumerator;

      public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
      {
          Debug.Assert(source != null);
          Debug.Assert(onNext != null);

          _source = source;
          _onNext = onNext;
          _onError = onError;
          _onCompleted = onCompleted;
      }

      /// <summary>Creates a fresh iterator over the same source with the same callbacks.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new DoAsyncIterator<TSource>(_source, _onNext, _onError, _onCompleted);

      /// <summary>Disposes the source enumerator, if one was created, and then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var active = _enumerator;
          if (active != null)
          {
              await active.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances the source by one element, running <c>_onNext</c> on it; when the source is exhausted,
      /// runs <c>_onCompleted</c> and disposes the iterator.
      /// </summary>
      /// <returns><c>true</c> when an element was stored in <c>_current</c>; otherwise <c>false</c>.</returns>
      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the enumerator, then continue straight into iteration.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating produces no further elements.
          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          try
          {
              var hasNext = await _enumerator.MoveNextAsync().ConfigureAwait(false);
              if (hasNext)
              {
                  _current = _enumerator.Current;
                  _onNext(_current);
                  return true;
              }
          }
          catch (OperationCanceledException)
          {
              // Cancellation bypasses the error callback.
              throw;
          }
          catch (Exception error) when (_onError != null)
          {
              _onError(error);
              throw;
          }

          // Source exhausted: signal completion, then release resources.
          if (_onCompleted != null)
          {
              _onCompleted();
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Hand-written iterator returned by the Task-based <c>DoCore</c> overload when
  /// <c>USE_ASYNC_ITERATOR</c> is not defined. Awaits Task-returning callbacks around the wrapped sequence.
  /// </summary>
  /// <remarks>
  /// <c>_state</c>, <c>_current</c> and <c>_cancellationToken</c> are not declared here; they are presumably
  /// provided by the <c>AsyncIterator</c> base class, which is outside this file.
  /// </remarks>
  private sealed class DoAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly Func<TSource, Task> _onNext;
      private readonly Func<Exception, Task> _onError;   // may be null
      private readonly Func<Task> _onCompleted;          // may be null

      // Enumerator over _source; created on the first MoveNextCore call, cleared on disposal.
      private IAsyncEnumerator<TSource> _enumerator;

      public DoAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
      {
          Debug.Assert(source != null);
          Debug.Assert(onNext != null);

          _source = source;
          _onNext = onNext;
          _onError = onError;
          _onCompleted = onCompleted;
      }

      /// <summary>Creates a fresh iterator over the same source with the same callbacks.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new DoAsyncIteratorWithTask<TSource>(_source, _onNext, _onError, _onCompleted);

      /// <summary>Disposes the source enumerator, if one was created, and then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var active = _enumerator;
          if (active != null)
          {
              await active.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances the source by one element, awaiting <c>_onNext</c> on it; when the source is exhausted,
      /// awaits <c>_onCompleted</c> and disposes the iterator.
      /// </summary>
      /// <returns><c>true</c> when an element was stored in <c>_current</c>; otherwise <c>false</c>.</returns>
      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the enumerator, then continue straight into iteration.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating produces no further elements.
          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          try
          {
              var hasNext = await _enumerator.MoveNextAsync().ConfigureAwait(false);
              if (hasNext)
              {
                  _current = _enumerator.Current;
                  await _onNext(_current).ConfigureAwait(false);
                  return true;
              }
          }
          catch (OperationCanceledException)
          {
              // Cancellation bypasses the error callback.
              throw;
          }
          catch (Exception error) when (_onError != null)
          {
              await _onError(error).ConfigureAwait(false);
              throw;
          }

          // Source exhausted: signal completion, then release resources.
          var completed = _onCompleted;
          if (completed != null)
          {
              await completed().ConfigureAwait(false);
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  395. #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Hand-written iterator returned by the cancellation-aware <c>DoCore</c> overload when
  /// <c>USE_ASYNC_ITERATOR</c> is not defined. Awaits Task-returning callbacks, passing each one
  /// the iterator's cancellation token.
  /// </summary>
  /// <remarks>
  /// <c>_state</c>, <c>_current</c> and <c>_cancellationToken</c> are not declared here; they are presumably
  /// provided by the <c>AsyncIterator</c> base class, which is outside this file.
  /// </remarks>
  private sealed class DoAsyncIteratorWithTaskAndCancellation<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly Func<TSource, CancellationToken, Task> _onNext;
      private readonly Func<Exception, CancellationToken, Task> _onError;   // may be null
      private readonly Func<CancellationToken, Task> _onCompleted;          // may be null

      // Enumerator over _source; created on the first MoveNextCore call, cleared on disposal.
      private IAsyncEnumerator<TSource> _enumerator;

      public DoAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
      {
          Debug.Assert(source != null);
          Debug.Assert(onNext != null);

          _source = source;
          _onNext = onNext;
          _onError = onError;
          _onCompleted = onCompleted;
      }

      /// <summary>Creates a fresh iterator over the same source with the same callbacks.</summary>
      public override AsyncIteratorBase<TSource> Clone() =>
          new DoAsyncIteratorWithTaskAndCancellation<TSource>(_source, _onNext, _onError, _onCompleted);

      /// <summary>Disposes the source enumerator, if one was created, and then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          var active = _enumerator;
          if (active != null)
          {
              await active.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances the source by one element, awaiting <c>_onNext</c> on it; when the source is exhausted,
      /// awaits <c>_onCompleted</c> and disposes the iterator.
      /// </summary>
      /// <returns><c>true</c> when an element was stored in <c>_current</c>; otherwise <c>false</c>.</returns>
      protected override async ValueTask<bool> MoveNextCore()
      {
          // First call: obtain the enumerator, then continue straight into iteration.
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
              _state = AsyncIteratorState.Iterating;
          }

          // Any state other than Iterating produces no further elements.
          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          try
          {
              var hasNext = await _enumerator.MoveNextAsync().ConfigureAwait(false);
              if (hasNext)
              {
                  _current = _enumerator.Current;
                  await _onNext(_current, _cancellationToken).ConfigureAwait(false);
                  return true;
              }
          }
          catch (OperationCanceledException)
          {
              // Cancellation bypasses the error callback.
              throw;
          }
          catch (Exception error) when (_onError != null)
          {
              await _onError(error, _cancellationToken).ConfigureAwait(false);
              throw;
          }

          // Source exhausted: signal completion, then release resources.
          var completed = _onCompleted;
          if (completed != null)
          {
              await completed(_cancellationToken).ConfigureAwait(false);
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  462. #endif
  463. #endif
  464. }
  465. }