Defer.cs 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. namespace System.Linq
  9. {
  10. public static partial class AsyncEnumerableEx
  11. {
  /// <summary>
  /// Returns a sequence whose source is produced on demand by <paramref name="factory"/>.
  /// The factory is not invoked here; the returned iterator calls it on its first
  /// <c>MoveNextAsync</c> and then forwards to the enumerator of the sequence it produced.
  /// If <paramref name="factory"/> is null, the exception built by <c>Error.ArgumentNull</c>
  /// is thrown immediately (presumably an ArgumentNullException — confirm).
  /// </summary>
  /// <typeparam name="TSource">Element type of the deferred sequence.</typeparam>
  /// <param name="factory">Synchronous factory for the underlying sequence.</param>
  /// <returns>A lazily-initialized async sequence.</returns>
  public static IAsyncEnumerable<TSource> Defer<TSource>(Func<IAsyncEnumerable<TSource>> factory)
      => factory is null
          ? throw Error.ArgumentNull(nameof(factory))
          : new DeferIterator<TSource>(factory);
  /// <summary>
  /// Returns a sequence whose source is produced on demand by an asynchronous
  /// <paramref name="factory"/>. The factory is not invoked here; the returned iterator awaits
  /// it on its first <c>MoveNextAsync</c> and then forwards to the resulting enumerator.
  /// If <paramref name="factory"/> is null, the exception built by <c>Error.ArgumentNull</c>
  /// is thrown immediately (presumably an ArgumentNullException — confirm).
  /// </summary>
  /// <typeparam name="TSource">Element type of the deferred sequence.</typeparam>
  /// <param name="factory">Asynchronous factory for the underlying sequence.</param>
  /// <returns>A lazily-initialized async sequence.</returns>
  public static IAsyncEnumerable<TSource> Defer<TSource>(Func<Task<IAsyncEnumerable<TSource>>> factory)
      => factory is null
          ? throw Error.ArgumentNull(nameof(factory))
          : new AsyncDeferIterator<TSource>(factory);
  #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Returns a sequence whose source is produced on demand by an asynchronous,
  /// cancellation-aware <paramref name="factory"/>. The factory is not invoked here; the
  /// returned iterator awaits it on its first <c>MoveNextAsync</c>, passing the token the
  /// enumeration was started with, and then forwards to the resulting enumerator.
  /// If <paramref name="factory"/> is null, the exception built by <c>Error.ArgumentNull</c>
  /// is thrown immediately (presumably an ArgumentNullException — confirm).
  /// </summary>
  /// <typeparam name="TSource">Element type of the deferred sequence.</typeparam>
  /// <param name="factory">Asynchronous factory that receives the enumeration's cancellation token.</param>
  /// <returns>A lazily-initialized async sequence.</returns>
  public static IAsyncEnumerable<TSource> Defer<TSource>(Func<CancellationToken, Task<IAsyncEnumerable<TSource>>> factory)
      => factory is null
          ? throw Error.ArgumentNull(nameof(factory))
          : new AsyncDeferIteratorWithCancellation<TSource>(factory);
  #endif
  /// <summary>
  /// Iterator backing the synchronous-factory <c>Defer</c> overload. The factory is called
  /// on the first <c>MoveNextAsync</c>; every later call is forwarded to the enumerator of
  /// the sequence the factory returned.
  /// </summary>
  private sealed class DeferIterator<T> : AsyncIteratorBase<T>
  {
      private readonly Func<IAsyncEnumerable<T>> _factory;

      // Inner enumerator: null before the first MoveNextAsync and after DisposeAsync.
      private IAsyncEnumerator<T> _enumerator;

      public DeferIterator(Func<IAsyncEnumerable<T>> factory)
      {
          Debug.Assert(factory != null);
          _factory = factory;
      }

      // Yields default(T) while no inner enumerator exists.
      public override T Current => _enumerator != null ? _enumerator.Current : default;

      public override AsyncIteratorBase<T> Clone() => new DeferIterator<T>(_factory);

      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;

          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override ValueTask<bool> MoveNextCore()
          => _enumerator is null
              ? InitializeAndMoveNextAsync()
              : _enumerator.MoveNextAsync();

      private async ValueTask<bool> InitializeAndMoveNextAsync()
      {
          // NB: This is an async method so that a failure from the factory (or from
          //     GetAsyncEnumerator) is reported through the returned task, not thrown
          //     synchronously out of MoveNextCore.
          try
          {
              var source = _factory();
              _enumerator = source.GetAsyncEnumerator(_cancellationToken);
          }
          catch (Exception ex)
          {
              // Occupy the slot with an enumerator over Throw<T>(ex) (presumably a sequence
              // that fails with ex — confirm) so a later MoveNextCore does not call the
              // factory again; then surface the original error.
              _enumerator = Throw<T>(ex).GetAsyncEnumerator(_cancellationToken);
              throw;
          }

          return await _enumerator.MoveNextAsync().ConfigureAwait(false);
      }
  }
  /// <summary>
  /// Iterator backing the asynchronous-factory <c>Defer</c> overload. The factory's task is
  /// awaited on the first <c>MoveNextAsync</c>; every later call is forwarded to the
  /// enumerator of the sequence that task produced.
  /// </summary>
  private sealed class AsyncDeferIterator<T> : AsyncIteratorBase<T>
  {
      private readonly Func<Task<IAsyncEnumerable<T>>> _factory;

      // Inner enumerator: null before the first MoveNextAsync and after DisposeAsync.
      private IAsyncEnumerator<T> _enumerator;

      public AsyncDeferIterator(Func<Task<IAsyncEnumerable<T>>> factory)
      {
          Debug.Assert(factory != null);
          _factory = factory;
      }

      // Yields default(T) while no inner enumerator exists.
      public override T Current => _enumerator != null ? _enumerator.Current : default;

      public override AsyncIteratorBase<T> Clone() => new AsyncDeferIterator<T>(_factory);

      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;

          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override ValueTask<bool> MoveNextCore()
          => _enumerator is null
              ? InitializeAndMoveNextAsync()
              : _enumerator.MoveNextAsync();

      private async ValueTask<bool> InitializeAndMoveNextAsync()
      {
          try
          {
              var source = await _factory().ConfigureAwait(false);
              _enumerator = source.GetAsyncEnumerator(_cancellationToken);
          }
          catch (Exception ex)
          {
              // Occupy the slot with an enumerator over Throw<T>(ex) (presumably a sequence
              // that fails with ex — confirm) so a later MoveNextCore does not call the
              // factory again; then surface the original error.
              _enumerator = Throw<T>(ex).GetAsyncEnumerator(_cancellationToken);
              throw;
          }

          return await _enumerator.MoveNextAsync().ConfigureAwait(false);
      }
  }
  #if !NO_DEEP_CANCELLATION
  /// <summary>
  /// Iterator backing the cancellation-aware asynchronous-factory <c>Defer</c> overload.
  /// On the first <c>MoveNextAsync</c> the factory is invoked with the enumeration's
  /// cancellation token and its task awaited; every later call is forwarded to the
  /// enumerator of the sequence that task produced.
  /// </summary>
  private sealed class AsyncDeferIteratorWithCancellation<T> : AsyncIteratorBase<T>
  {
      private readonly Func<CancellationToken, Task<IAsyncEnumerable<T>>> _factory;

      // Inner enumerator: null before the first MoveNextAsync and after DisposeAsync.
      private IAsyncEnumerator<T> _enumerator;

      public AsyncDeferIteratorWithCancellation(Func<CancellationToken, Task<IAsyncEnumerable<T>>> factory)
      {
          Debug.Assert(factory != null);
          _factory = factory;
      }

      // Yields default(T) while no inner enumerator exists.
      public override T Current => _enumerator != null ? _enumerator.Current : default;

      public override AsyncIteratorBase<T> Clone() => new AsyncDeferIteratorWithCancellation<T>(_factory);

      public override async ValueTask DisposeAsync()
      {
          var inner = _enumerator;

          if (inner != null)
          {
              await inner.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      protected override ValueTask<bool> MoveNextCore()
          => _enumerator is null
              ? InitializeAndMoveNextAsync()
              : _enumerator.MoveNextAsync();

      private async ValueTask<bool> InitializeAndMoveNextAsync()
      {
          try
          {
              // The same token is handed both to the factory and to the produced sequence.
              var source = await _factory(_cancellationToken).ConfigureAwait(false);
              _enumerator = source.GetAsyncEnumerator(_cancellationToken);
          }
          catch (Exception ex)
          {
              // Occupy the slot with an enumerator over Throw<T>(ex) (presumably a sequence
              // that fails with ex — confirm) so a later MoveNextCore does not call the
              // factory again; then surface the original error.
              _enumerator = Throw<T>(ex).GetAsyncEnumerator(_cancellationToken);
              throw;
          }

          return await _enumerator.MoveNextAsync().ConfigureAwait(false);
      }
  }
  #endif
  170. }
  171. }