1
0

Do.cs 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. namespace System.Linq
  9. {
  // Do: pass-through operators that run side-effecting callbacks while an IAsyncEnumerable<T> is
  // enumerated. Each public overload validates its arguments immediately (it is not an iterator itself)
  // and then delegates to one of the DoCore helpers, which build either a C# async iterator
  // (USE_ASYNC_ITERATOR) or a hand-written AsyncIterator<T> state machine.
  public static partial class AsyncEnumerableEx
  {
      // REVIEW: Should we convert Task-based overloads to ValueTask?

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, invoking <paramref name="onNext"/> on each one.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">Invoked with each element before that element is yielded.</param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          return DoCore(source, onNext, onError: null, onCompleted: null);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, invoking <paramref name="onNext"/> on each one
      /// and <paramref name="onCompleted"/> once the source is exhausted.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">Invoked with each element before that element is yielded.</param>
      /// <param name="onCompleted">Invoked once after <paramref name="source"/> runs out of elements, before its enumerator is disposed.</param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onCompleted is null)
          {
              throw Error.ArgumentNull(nameof(onCompleted));
          }

          return DoCore(source, onNext, onError: null, onCompleted: onCompleted);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, invoking <paramref name="onNext"/> on each one
      /// and <paramref name="onError"/> if enumeration fails.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">Invoked with each element before that element is yielded.</param>
      /// <param name="onError">
      /// Invoked with any exception, other than <see cref="OperationCanceledException"/>, raised while advancing
      /// <paramref name="source"/> or by <paramref name="onNext"/>; the exception is rethrown afterwards.
      /// </param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onError is null)
          {
              throw Error.ArgumentNull(nameof(onError));
          }

          return DoCore(source, onNext, onError: onError, onCompleted: null);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, invoking <paramref name="onNext"/> on each one,
      /// <paramref name="onError"/> if enumeration fails and <paramref name="onCompleted"/> once the source is exhausted.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">Invoked with each element before that element is yielded.</param>
      /// <param name="onError">
      /// Invoked with any exception, other than <see cref="OperationCanceledException"/>, raised while advancing
      /// <paramref name="source"/> or by <paramref name="onNext"/>; the exception is rethrown afterwards.
      /// Exceptions thrown by <paramref name="onCompleted"/> are not routed here.
      /// </param>
      /// <param name="onCompleted">Invoked once after <paramref name="source"/> runs out of elements, before its enumerator is disposed.</param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onError is null)
          {
              throw Error.ArgumentNull(nameof(onError));
          }

          if (onCompleted is null)
          {
              throw Error.ArgumentNull(nameof(onCompleted));
          }

          return DoCore(source, onNext, onError, onCompleted);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, awaiting <paramref name="onNext"/> for each one.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">Invoked with each element; the returned task is awaited before the element is yielded.</param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          return DoCore(source, onNext, onError: null, onCompleted: null);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, awaiting <paramref name="onNext"/> for each one
      /// and <paramref name="onCompleted"/> once the source is exhausted.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">Invoked with each element; the returned task is awaited before the element is yielded.</param>
      /// <param name="onCompleted">Invoked and awaited once after <paramref name="source"/> runs out of elements, before its enumerator is disposed.</param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Task> onCompleted)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onCompleted is null)
          {
              throw Error.ArgumentNull(nameof(onCompleted));
          }

          return DoCore(source, onNext, onError: null, onCompleted: onCompleted);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, awaiting <paramref name="onNext"/> for each one
      /// and <paramref name="onError"/> if enumeration fails.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">Invoked with each element; the returned task is awaited before the element is yielded.</param>
      /// <param name="onError">
      /// Invoked with any exception, other than <see cref="OperationCanceledException"/>, raised while advancing
      /// <paramref name="source"/> or by <paramref name="onNext"/>; the returned task is awaited and the exception is then rethrown.
      /// </param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onError is null)
          {
              throw Error.ArgumentNull(nameof(onError));
          }

          return DoCore(source, onNext, onError: onError, onCompleted: null);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, awaiting <paramref name="onNext"/> for each one,
      /// <paramref name="onError"/> if enumeration fails and <paramref name="onCompleted"/> once the source is exhausted.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">Invoked with each element; the returned task is awaited before the element is yielded.</param>
      /// <param name="onError">
      /// Invoked with any exception, other than <see cref="OperationCanceledException"/>, raised while advancing
      /// <paramref name="source"/> or by <paramref name="onNext"/>; the returned task is awaited and the exception is then rethrown.
      /// Exceptions from <paramref name="onCompleted"/> are not routed here.
      /// </param>
      /// <param name="onCompleted">Invoked and awaited once after <paramref name="source"/> runs out of elements, before its enumerator is disposed.</param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onError is null)
          {
              throw Error.ArgumentNull(nameof(onError));
          }

          if (onCompleted is null)
          {
              throw Error.ArgumentNull(nameof(onCompleted));
          }

          return DoCore(source, onNext, onError, onCompleted);
      }

#if !NO_DEEP_CANCELLATION
      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, awaiting <paramref name="onNext"/> for each one.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">
      /// Invoked with each element and the enumeration's cancellation token; the returned task is awaited before the element is yielded.
      /// </param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          return DoCore(source, onNext, onError: null, onCompleted: null);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, awaiting <paramref name="onNext"/> for each one
      /// and <paramref name="onCompleted"/> once the source is exhausted.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">
      /// Invoked with each element and the enumeration's cancellation token; the returned task is awaited before the element is yielded.
      /// </param>
      /// <param name="onCompleted">
      /// Invoked with the enumeration's cancellation token and awaited once after <paramref name="source"/> runs out of elements,
      /// before its enumerator is disposed.
      /// </param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<CancellationToken, Task> onCompleted)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onCompleted is null)
          {
              throw Error.ArgumentNull(nameof(onCompleted));
          }

          return DoCore(source, onNext, onError: null, onCompleted: onCompleted);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, awaiting <paramref name="onNext"/> for each one
      /// and <paramref name="onError"/> if enumeration fails.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">
      /// Invoked with each element and the enumeration's cancellation token; the returned task is awaited before the element is yielded.
      /// </param>
      /// <param name="onError">
      /// Invoked with the enumeration's cancellation token and any exception, other than <see cref="OperationCanceledException"/>,
      /// raised while advancing <paramref name="source"/> or by <paramref name="onNext"/>; the returned task is awaited and the
      /// exception is then rethrown.
      /// </param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onError is null)
          {
              throw Error.ArgumentNull(nameof(onError));
          }

          return DoCore(source, onNext, onError: onError, onCompleted: null);
      }

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, awaiting <paramref name="onNext"/> for each one,
      /// <paramref name="onError"/> if enumeration fails and <paramref name="onCompleted"/> once the source is exhausted.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="onNext">
      /// Invoked with each element and the enumeration's cancellation token; the returned task is awaited before the element is yielded.
      /// </param>
      /// <param name="onError">
      /// Invoked with the enumeration's cancellation token and any exception, other than <see cref="OperationCanceledException"/>,
      /// raised while advancing <paramref name="source"/> or by <paramref name="onNext"/>; the returned task is awaited and the
      /// exception is then rethrown. Exceptions from <paramref name="onCompleted"/> are not routed here.
      /// </param>
      /// <param name="onCompleted">
      /// Invoked with the enumeration's cancellation token and awaited once after <paramref name="source"/> runs out of elements,
      /// before its enumerator is disposed.
      /// </param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (onNext is null)
          {
              throw Error.ArgumentNull(nameof(onNext));
          }

          if (onError is null)
          {
              throw Error.ArgumentNull(nameof(onError));
          }

          if (onCompleted is null)
          {
              throw Error.ArgumentNull(nameof(onCompleted));
          }

          return DoCore(source, onNext, onError, onCompleted);
      }
#endif

      /// <summary>
      /// Yields the elements of <paramref name="source"/> unchanged, forwarding each element to <c>observer.OnNext</c>,
      /// a failure (other than <see cref="OperationCanceledException"/>) to <c>observer.OnError</c>, and completion
      /// of the source to <c>observer.OnCompleted</c>.
      /// </summary>
      /// <typeparam name="TSource">The element type.</typeparam>
      /// <param name="source">The sequence whose elements are passed through.</param>
      /// <param name="observer">Receives the notifications; after <c>OnError</c> the exception is rethrown to the consumer.</param>
      /// <returns>A sequence producing the same elements as <paramref name="source"/>.</returns>
      public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
      {
          if (source is null)
          {
              throw Error.ArgumentNull(nameof(source));
          }

          if (observer is null)
          {
              throw Error.ArgumentNull(nameof(observer));
          }

          // Typed locals pin the Action-based DoCore overload and bind directly to the observer's methods.
          Action<TSource> forwardNext = observer.OnNext;
          Action<Exception> forwardError = observer.OnError;
          Action forwardCompleted = observer.OnCompleted;

          return DoCore(source, forwardNext, forwardError, forwardCompleted);
      }

      // Shared implementation for the Action-based overloads; onError and onCompleted may be null.
      private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
      {
#if USE_ASYNC_ITERATOR
          return AsyncEnumerable.Create(Iterate);

          async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
          {
              var enumerator = source.GetConfiguredAsyncEnumerator(token, false);

              // TODO: Switch to `await using` in preview 3 (https://github.com/dotnet/roslyn/pull/32731)
              try
              {
                  for (;;)
                  {
                      TSource value;

                      try
                      {
                          if (!await enumerator.MoveNextAsync())
                          {
                              break;
                          }

                          value = enumerator.Current;
                          onNext(value);
                      }
                      catch (OperationCanceledException)
                      {
                          // Cancellation propagates untouched and is deliberately not reported to onError.
                          throw;
                      }
                      catch (Exception error) when (onError != null)
                      {
                          onError(error);
                          throw;
                      }

                      // Yielding happens outside the try because C# disallows yield return inside a try that has a catch clause.
                      yield return value;
                  }

                  onCompleted?.Invoke();
              }
              finally
              {
                  await enumerator.DisposeAsync();
              }
          }
#else
          return new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
#endif
      }

      // Shared implementation for the Task-returning overloads; onError and onCompleted may be null.
      private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
      {
#if USE_ASYNC_ITERATOR
          return AsyncEnumerable.Create(Iterate);

          async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
          {
              var enumerator = source.GetConfiguredAsyncEnumerator(token, false);

              // TODO: Switch to `await using` in preview 3 (https://github.com/dotnet/roslyn/pull/32731)
              try
              {
                  for (;;)
                  {
                      TSource value;

                      try
                      {
                          if (!await enumerator.MoveNextAsync())
                          {
                              break;
                          }

                          value = enumerator.Current;
                          await onNext(value).ConfigureAwait(false);
                      }
                      catch (OperationCanceledException)
                      {
                          // Cancellation propagates untouched and is deliberately not reported to onError.
                          throw;
                      }
                      catch (Exception error) when (onError != null)
                      {
                          await onError(error).ConfigureAwait(false);
                          throw;
                      }

                      // Yielding happens outside the try because C# disallows yield return inside a try that has a catch clause.
                      yield return value;
                  }

                  if (onCompleted != null)
                  {
                      await onCompleted().ConfigureAwait(false);
                  }
              }
              finally
              {
                  await enumerator.DisposeAsync();
              }
          }
#else
          return new DoAsyncIteratorWithTask<TSource>(source, onNext, onError, onCompleted);
#endif
      }

#if !NO_DEEP_CANCELLATION
      // Shared implementation for the cancellation-aware overloads; every callback receives the enumeration's token.
      // onError and onCompleted may be null.
      private static IAsyncEnumerable<TSource> DoCore<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
      {
#if USE_ASYNC_ITERATOR
          return AsyncEnumerable.Create(Iterate);

          async IAsyncEnumerator<TSource> Iterate(CancellationToken token)
          {
              var enumerator = source.GetConfiguredAsyncEnumerator(token, false);

              // TODO: Switch to `await using` in preview 3 (https://github.com/dotnet/roslyn/pull/32731)
              try
              {
                  for (;;)
                  {
                      TSource value;

                      try
                      {
                          if (!await enumerator.MoveNextAsync())
                          {
                              break;
                          }

                          value = enumerator.Current;
                          await onNext(value, token).ConfigureAwait(false);
                      }
                      catch (OperationCanceledException)
                      {
                          // Cancellation propagates untouched and is deliberately not reported to onError.
                          throw;
                      }
                      catch (Exception error) when (onError != null)
                      {
                          await onError(error, token).ConfigureAwait(false);
                          throw;
                      }

                      // Yielding happens outside the try because C# disallows yield return inside a try that has a catch clause.
                      yield return value;
                  }

                  if (onCompleted != null)
                  {
                      await onCompleted(token).ConfigureAwait(false);
                  }
              }
              finally
              {
                  await enumerator.DisposeAsync();
              }
          }
#else
          return new DoAsyncIteratorWithTaskAndCancellation<TSource>(source, onNext, onError, onCompleted);
#endif
      }
#endif

#if !USE_ASYNC_ITERATOR
      // Hand-written iterator used by the Action-based DoCore when async iterators are not compiled in.
      private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
      {
          private readonly IAsyncEnumerable<TSource> _source;
          private readonly Action<TSource> _onNext;
          private readonly Action<Exception> _onError;
          private readonly Action _onCompleted;

          // Obtained from _source when MoveNextCore runs in the Allocated state; released by DisposeAsync.
          private IAsyncEnumerator<TSource> _inner;

          public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
          {
              Debug.Assert(source != null);
              Debug.Assert(onNext != null);

              _source = source;
              _onNext = onNext;
              _onError = onError;
              _onCompleted = onCompleted;
          }

          public override AsyncIteratorBase<TSource> Clone() => new DoAsyncIterator<TSource>(_source, _onNext, _onError, _onCompleted);

          public override async ValueTask DisposeAsync()
          {
              if (_inner != null)
              {
                  await _inner.DisposeAsync().ConfigureAwait(false);
                  _inner = null;
              }

              await base.DisposeAsync().ConfigureAwait(false);
          }

          protected override async ValueTask<bool> MoveNextCore()
          {
              // First call: open the source and fall through into iteration.
              if (_state == AsyncIteratorState.Allocated)
              {
                  _inner = _source.GetAsyncEnumerator(_cancellationToken);
                  _state = AsyncIteratorState.Iterating;
              }

              // Any other state (e.g. after disposal) yields nothing more.
              if (_state != AsyncIteratorState.Iterating)
              {
                  return false;
              }

              try
              {
                  if (await _inner.MoveNextAsync().ConfigureAwait(false))
                  {
                      _current = _inner.Current;
                      _onNext(_current);
                      return true;
                  }
              }
              catch (OperationCanceledException)
              {
                  // Cancellation propagates untouched and is deliberately not reported to _onError.
                  throw;
              }
              catch (Exception ex) when (_onError != null)
              {
                  _onError(ex);
                  throw;
              }

              // Source exhausted: notify completion first, then release the source enumerator.
              _onCompleted?.Invoke();
              await DisposeAsync().ConfigureAwait(false);
              return false;
          }
      }

      // Hand-written iterator used by the Task-based DoCore when async iterators are not compiled in.
      private sealed class DoAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
      {
          private readonly IAsyncEnumerable<TSource> _source;
          private readonly Func<TSource, Task> _onNext;
          private readonly Func<Exception, Task> _onError;
          private readonly Func<Task> _onCompleted;

          // Obtained from _source when MoveNextCore runs in the Allocated state; released by DisposeAsync.
          private IAsyncEnumerator<TSource> _inner;

          public DoAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task> onNext, Func<Exception, Task> onError, Func<Task> onCompleted)
          {
              Debug.Assert(source != null);
              Debug.Assert(onNext != null);

              _source = source;
              _onNext = onNext;
              _onError = onError;
              _onCompleted = onCompleted;
          }

          public override AsyncIteratorBase<TSource> Clone() => new DoAsyncIteratorWithTask<TSource>(_source, _onNext, _onError, _onCompleted);

          public override async ValueTask DisposeAsync()
          {
              if (_inner != null)
              {
                  await _inner.DisposeAsync().ConfigureAwait(false);
                  _inner = null;
              }

              await base.DisposeAsync().ConfigureAwait(false);
          }

          protected override async ValueTask<bool> MoveNextCore()
          {
              // First call: open the source and fall through into iteration.
              if (_state == AsyncIteratorState.Allocated)
              {
                  _inner = _source.GetAsyncEnumerator(_cancellationToken);
                  _state = AsyncIteratorState.Iterating;
              }

              // Any other state (e.g. after disposal) yields nothing more.
              if (_state != AsyncIteratorState.Iterating)
              {
                  return false;
              }

              try
              {
                  if (await _inner.MoveNextAsync().ConfigureAwait(false))
                  {
                      _current = _inner.Current;
                      await _onNext(_current).ConfigureAwait(false);
                      return true;
                  }
              }
              catch (OperationCanceledException)
              {
                  // Cancellation propagates untouched and is deliberately not reported to _onError.
                  throw;
              }
              catch (Exception ex) when (_onError != null)
              {
                  await _onError(ex).ConfigureAwait(false);
                  throw;
              }

              // Source exhausted: notify completion first, then release the source enumerator.
              if (_onCompleted != null)
              {
                  await _onCompleted().ConfigureAwait(false);
              }

              await DisposeAsync().ConfigureAwait(false);
              return false;
          }
      }

#if !NO_DEEP_CANCELLATION
      // Hand-written iterator used by the cancellation-aware DoCore; every callback receives _cancellationToken.
      private sealed class DoAsyncIteratorWithTaskAndCancellation<TSource> : AsyncIterator<TSource>
      {
          private readonly IAsyncEnumerable<TSource> _source;
          private readonly Func<TSource, CancellationToken, Task> _onNext;
          private readonly Func<Exception, CancellationToken, Task> _onError;
          private readonly Func<CancellationToken, Task> _onCompleted;

          // Obtained from _source when MoveNextCore runs in the Allocated state; released by DisposeAsync.
          private IAsyncEnumerator<TSource> _inner;

          public DoAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, Task> onNext, Func<Exception, CancellationToken, Task> onError, Func<CancellationToken, Task> onCompleted)
          {
              Debug.Assert(source != null);
              Debug.Assert(onNext != null);

              _source = source;
              _onNext = onNext;
              _onError = onError;
              _onCompleted = onCompleted;
          }

          public override AsyncIteratorBase<TSource> Clone() => new DoAsyncIteratorWithTaskAndCancellation<TSource>(_source, _onNext, _onError, _onCompleted);

          public override async ValueTask DisposeAsync()
          {
              if (_inner != null)
              {
                  await _inner.DisposeAsync().ConfigureAwait(false);
                  _inner = null;
              }

              await base.DisposeAsync().ConfigureAwait(false);
          }

          protected override async ValueTask<bool> MoveNextCore()
          {
              // First call: open the source and fall through into iteration.
              if (_state == AsyncIteratorState.Allocated)
              {
                  _inner = _source.GetAsyncEnumerator(_cancellationToken);
                  _state = AsyncIteratorState.Iterating;
              }

              // Any other state (e.g. after disposal) yields nothing more.
              if (_state != AsyncIteratorState.Iterating)
              {
                  return false;
              }

              try
              {
                  if (await _inner.MoveNextAsync().ConfigureAwait(false))
                  {
                      _current = _inner.Current;
                      await _onNext(_current, _cancellationToken).ConfigureAwait(false);
                      return true;
                  }
              }
              catch (OperationCanceledException)
              {
                  // Cancellation propagates untouched and is deliberately not reported to _onError.
                  throw;
              }
              catch (Exception ex) when (_onError != null)
              {
                  await _onError(ex, _cancellationToken).ConfigureAwait(false);
                  throw;
              }

              // Source exhausted: notify completion first, then release the source enumerator.
              if (_onCompleted != null)
              {
                  await _onCompleted(_cancellationToken).ConfigureAwait(false);
              }

              await DisposeAsync().ConfigureAwait(false);
              return false;
          }
      }
#endif
#endif
  }
  480. }