Catch.cs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Runtime.ExceptionServices;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. namespace System.Linq
  10. {
  11. public static partial class AsyncEnumerableEx
  12. {
  /// <summary>
  /// Returns a sequence that yields the elements of <paramref name="source"/> and, if enumerating it
  /// throws an exception of type <typeparamref name="TException"/>, continues with the sequence that
  /// <paramref name="handler"/> produces for that exception.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequences.</typeparam>
  /// <typeparam name="TException">Exception type to recover from; other exceptions propagate.</typeparam>
  /// <param name="source">The sequence to enumerate first.</param>
  /// <param name="handler">Maps the caught exception to the sequence to continue with.</param>
  /// <returns>An iterator wrapping both arguments; this method itself enumerates nothing.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="handler"/> is null.</exception>
  public static IAsyncEnumerable<TSource> Catch<TSource, TException>(this IAsyncEnumerable<TSource> source, Func<TException, IAsyncEnumerable<TSource>> handler)
      where TException : Exception
  {
      var checkedSource = source ?? throw new ArgumentNullException(nameof(source));
      var checkedHandler = handler ?? throw new ArgumentNullException(nameof(handler));

      return new CatchAsyncIterator<TSource, TException>(checkedSource, checkedHandler);
  }
  /// <summary>
  /// Asynchronous-handler variant of <c>Catch</c>: yields the elements of <paramref name="source"/>
  /// and, on an exception of type <typeparamref name="TException"/>, awaits <paramref name="handler"/>
  /// to obtain the sequence to continue with.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequences.</typeparam>
  /// <typeparam name="TException">Exception type to recover from; other exceptions propagate.</typeparam>
  /// <param name="source">The sequence to enumerate first.</param>
  /// <param name="handler">Asynchronously maps the caught exception to the continuation sequence.</param>
  /// <returns>An iterator wrapping both arguments; this method itself enumerates nothing.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="handler"/> is null.</exception>
  public static IAsyncEnumerable<TSource> Catch<TSource, TException>(this IAsyncEnumerable<TSource> source, Func<TException, Task<IAsyncEnumerable<TSource>>> handler)
      where TException : Exception
  {
      if (source is null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (handler is null)
      {
          throw new ArgumentNullException(nameof(handler));
      }

      var iterator = new CatchAsyncIteratorWithTask<TSource, TException>(source, handler);
      return iterator;
  }
  /// <summary>
  /// Returns a sequence that tries each of <paramref name="sources"/> in order. The next source is
  /// only attempted when the current one throws; the first source that completes normally ends the
  /// result. If the last attempted source throws, that exception is rethrown.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequences.</typeparam>
  /// <param name="sources">The candidate sequences, tried in order.</param>
  /// <returns>An iterator over the sources; this method itself enumerates nothing.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
  public static IAsyncEnumerable<TSource> Catch<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
  {
      return CatchCore(sources ?? throw new ArgumentNullException(nameof(sources)));
  }
  /// <summary>
  /// <c>params</c> convenience overload: tries each sequence in <paramref name="sources"/> in order,
  /// moving to the next one only when the current one throws.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequences.</typeparam>
  /// <param name="sources">The candidate sequences, tried in order.</param>
  /// <returns>An iterator over the sources; this method itself enumerates nothing.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
  public static IAsyncEnumerable<TSource> Catch<TSource>(params IAsyncEnumerable<TSource>[] sources)
  {
      if (sources is null)
      {
          throw new ArgumentNullException(nameof(sources));
      }

      return CatchCore(sources);
  }
  /// <summary>
  /// Returns a sequence that yields <paramref name="first"/> and, only if enumerating it throws,
  /// continues with <paramref name="second"/>.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequences.</typeparam>
  /// <param name="first">The sequence to try first.</param>
  /// <param name="second">The fallback sequence used when <paramref name="first"/> fails.</param>
  /// <returns>An iterator over the two sequences; this method itself enumerates nothing.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
  public static IAsyncEnumerable<TSource> Catch<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
  {
      if (first is null)
      {
          throw new ArgumentNullException(nameof(first));
      }

      if (second is null)
      {
          throw new ArgumentNullException(nameof(second));
      }

      var pair = new[] { first, second };
      return CatchCore(pair);
  }
  // Shared implementation for the multi-source overloads; callers have already null-checked sources.
  private static IAsyncEnumerable<TSource> CatchCore<TSource>(IEnumerable<IAsyncEnumerable<TSource>> sources) =>
      new CatchAsyncIterator<TSource>(sources);
  /// <summary>
  /// Iterator behind <c>Catch(source, handler)</c>. It yields the elements of <c>source</c> until the
  /// source completes or throws a <typeparamref name="TException"/>. On such an exception the handler
  /// is invoked and enumeration continues with the sequence it returns.
  /// </summary>
  /// <remarks>
  /// Only a failure of the original source is handled. Exceptions thrown while draining the handler's
  /// sequence propagate to the consumer unchanged.
  /// </remarks>
  private sealed class CatchAsyncIterator<TSource, TException> : AsyncIterator<TSource> where TException : Exception
  {
      private readonly Func<TException, IAsyncEnumerable<TSource>> handler;
      private readonly IAsyncEnumerable<TSource> source;

      // The enumerator being drained: first the source's, then (after a handled failure) the handler's.
      private IAsyncEnumerator<TSource> enumerator;

      // Set once the handler has been invoked; from then on `enumerator` belongs to the handler's sequence.
      private bool isDone;

      public CatchAsyncIterator(IAsyncEnumerable<TSource> source, Func<TException, IAsyncEnumerable<TSource>> handler)
      {
          Debug.Assert(source != null);
          Debug.Assert(handler != null);
          this.source = source;
          this.handler = handler;
      }

      public override AsyncIterator<TSource> Clone()
      {
          return new CatchAsyncIterator<TSource, TException>(source, handler);
      }

      public override async ValueTask DisposeAsync()
      {
          if (enumerator != null)
          {
              await enumerator.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }
          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
      {
          switch (state)
          {
              case AsyncIteratorState.Allocated:
                  enumerator = source.GetAsyncEnumerator(cancellationToken);
                  isDone = false;
                  state = AsyncIteratorState.Iterating;
                  goto case AsyncIteratorState.Iterating;

              case AsyncIteratorState.Iterating:
                  if (!isDone)
                  {
                      try
                      {
                          if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                          {
                              current = enumerator.Current;
                              return true;
                          }

                          // The source completed normally, so there is nothing to recover from.
                          // Finish now rather than polling the exhausted enumerator a second time
                          // outside the exception handler.
                          break; // case
                      }
                      catch (TException ex)
                      {
                          // Note: Ideally we'd dispose of the previous enumerator before
                          // invoking the handler, but we use this order to preserve
                          // current behavior.
                          var inner = handler(ex);
                          var err = inner.GetAsyncEnumerator(cancellationToken);

                          // Track the handler's enumerator before disposing the failed one. If that
                          // DisposeAsync throws, `err` is still released by DisposeAsync later instead
                          // of leaking, and the failed enumerator is not disposed a second time.
                          var failed = enumerator;
                          enumerator = err;
                          isDone = true;

                          if (failed != null)
                          {
                              await failed.DisposeAsync().ConfigureAwait(false);
                          }
                      }
                  }

                  // Draining the handler's sequence: its exceptions are intentionally not caught.
                  if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                  {
                      current = enumerator.Current;
                      return true;
                  }

                  break; // case
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Iterator behind <c>Catch(source, Task handler)</c>. It yields the elements of <c>source</c> until
  /// the source completes or throws a <typeparamref name="TException"/>. On such an exception the handler
  /// is awaited and enumeration continues with the sequence it produces.
  /// </summary>
  /// <remarks>
  /// Only a failure of the original source is handled. Exceptions thrown while draining the handler's
  /// sequence propagate to the consumer unchanged.
  /// </remarks>
  private sealed class CatchAsyncIteratorWithTask<TSource, TException> : AsyncIterator<TSource> where TException : Exception
  {
      private readonly Func<TException, Task<IAsyncEnumerable<TSource>>> handler;
      private readonly IAsyncEnumerable<TSource> source;

      // The enumerator being drained: first the source's, then (after a handled failure) the handler's.
      private IAsyncEnumerator<TSource> enumerator;

      // Set once the handler has been invoked; from then on `enumerator` belongs to the handler's sequence.
      private bool isDone;

      public CatchAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TException, Task<IAsyncEnumerable<TSource>>> handler)
      {
          Debug.Assert(source != null);
          Debug.Assert(handler != null);
          this.source = source;
          this.handler = handler;
      }

      public override AsyncIterator<TSource> Clone()
      {
          return new CatchAsyncIteratorWithTask<TSource, TException>(source, handler);
      }

      public override async ValueTask DisposeAsync()
      {
          if (enumerator != null)
          {
              await enumerator.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }
          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
      {
          switch (state)
          {
              case AsyncIteratorState.Allocated:
                  enumerator = source.GetAsyncEnumerator(cancellationToken);
                  isDone = false;
                  state = AsyncIteratorState.Iterating;
                  goto case AsyncIteratorState.Iterating;

              case AsyncIteratorState.Iterating:
                  if (!isDone)
                  {
                      try
                      {
                          if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                          {
                              current = enumerator.Current;
                              return true;
                          }

                          // The source completed normally, so there is nothing to recover from.
                          // Finish now rather than polling the exhausted enumerator a second time
                          // outside the exception handler.
                          break; // case
                      }
                      catch (TException ex)
                      {
                          // Note: Ideally we'd dispose of the previous enumerator before
                          // invoking the handler, but we use this order to preserve
                          // current behavior.
                          var inner = await handler(ex).ConfigureAwait(false);
                          var err = inner.GetAsyncEnumerator(cancellationToken);

                          // Track the handler's enumerator before disposing the failed one. If that
                          // DisposeAsync throws, `err` is still released by DisposeAsync later instead
                          // of leaking, and the failed enumerator is not disposed a second time.
                          var failed = enumerator;
                          enumerator = err;
                          isDone = true;

                          if (failed != null)
                          {
                              await failed.DisposeAsync().ConfigureAwait(false);
                          }
                      }
                  }

                  // Draining the handler's sequence: its exceptions are intentionally not caught.
                  if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                  {
                      current = enumerator.Current;
                      return true;
                  }

                  break; // case
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  /// <summary>
  /// Iterator behind the multi-source <c>Catch</c> overloads. It drains the sources one after another
  /// and moves on to the next source only when the current one throws.
  /// </summary>
  /// <remarks>
  /// The first source that completes without throwing ends the sequence. If the list of sources runs
  /// out after the last attempted source failed, that failure is rethrown through
  /// <see cref="ExceptionDispatchInfo"/>. An empty list of sources yields no elements.
  /// </remarks>
  private sealed class CatchAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly IEnumerable<IAsyncEnumerable<TSource>> sources;

      // Enumerator of the source currently being drained; null while choosing the next source.
      private IAsyncEnumerator<TSource> enumerator;

      // Failure of the most recently abandoned source; rethrown when no source is left to try.
      private ExceptionDispatchInfo error;

      // Cursor over the list of sources, created on the first MoveNextCore call.
      private IEnumerator<IAsyncEnumerable<TSource>> sourcesEnumerator;

      public CatchAsyncIterator(IEnumerable<IAsyncEnumerable<TSource>> sources)
      {
          Debug.Assert(sources != null);
          this.sources = sources;
      }

      public override AsyncIterator<TSource> Clone() => new CatchAsyncIterator<TSource>(sources);

      public override async ValueTask DisposeAsync()
      {
          // Release order: cursor over the sources, then the active source, then base state.
          sourcesEnumerator?.Dispose();
          sourcesEnumerator = null;

          var active = enumerator;
          if (active != null)
          {
              await active.DisposeAsync().ConfigureAwait(false);
              enumerator = null;
          }

          error = null;
          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
      {
          if (state == AsyncIteratorState.Allocated)
          {
              sourcesEnumerator = sources.GetEnumerator();
              state = AsyncIteratorState.Iterating;
          }

          if (state == AsyncIteratorState.Iterating)
          {
              while (true)
              {
                  if (enumerator == null)
                  {
                      if (!sourcesEnumerator.MoveNext())
                      {
                          // Out of sources: surface the last failure, if the last source failed.
                          error?.Throw();
                          break;
                      }

                      error = null;
                      enumerator = sourcesEnumerator.Current.GetAsyncEnumerator(cancellationToken);
                  }

                  bool produced;
                  try
                  {
                      produced = await enumerator.MoveNextAsync().ConfigureAwait(false);
                      if (produced)
                      {
                          // Read inside the try so a throwing Current also counts as a source failure.
                          current = enumerator.Current;
                      }
                  }
                  catch (Exception ex)
                  {
                      // This source failed: drop it, remember why, and try the next one.
                      await enumerator.DisposeAsync().ConfigureAwait(false);
                      enumerator = null;
                      error = ExceptionDispatchInfo.Capture(ex);
                      continue;
                  }

                  if (produced)
                  {
                      return true;
                  }

                  // The current source completed without throwing, which ends the whole sequence.
                  break;
              }
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  285. }
  286. }