Defer.cs 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. using static System.Linq.AsyncEnumerable;
  9. namespace System.Linq
  10. {
  11. public static partial class AsyncEnumerableEx
  12. {
  13. public static IAsyncEnumerable<TSource> Defer<TSource>(Func<IAsyncEnumerable<TSource>> factory)
  14. {
  15. if (factory == null)
  16. throw Error.ArgumentNull(nameof(factory));
  17. return new DeferIterator<TSource>(factory);
  18. }
  19. public static IAsyncEnumerable<TSource> Defer<TSource>(Func<Task<IAsyncEnumerable<TSource>>> factory)
  20. {
  21. if (factory == null)
  22. throw Error.ArgumentNull(nameof(factory));
  23. return new AsyncDeferIterator<TSource>(factory);
  24. }
  25. private sealed class DeferIterator<T> : AsyncIteratorBase<T>
  26. {
  27. private readonly Func<IAsyncEnumerable<T>> _factory;
  28. private IAsyncEnumerator<T> _enumerator;
  29. public DeferIterator(Func<IAsyncEnumerable<T>> factory)
  30. {
  31. Debug.Assert(factory != null);
  32. _factory = factory;
  33. }
  34. public override T Current => _enumerator == null ? default : _enumerator.Current;
  35. public override AsyncIteratorBase<T> Clone()
  36. {
  37. return new DeferIterator<T>(_factory);
  38. }
  39. public override async ValueTask DisposeAsync()
  40. {
  41. if (_enumerator != null)
  42. {
  43. await _enumerator.DisposeAsync().ConfigureAwait(false);
  44. _enumerator = null;
  45. }
  46. await base.DisposeAsync().ConfigureAwait(false);
  47. }
  48. protected override ValueTask<bool> MoveNextCore()
  49. {
  50. if (_enumerator == null)
  51. {
  52. return InitializeAndMoveNextAsync();
  53. }
  54. return _enumerator.MoveNextAsync();
  55. }
  56. private async ValueTask<bool> InitializeAndMoveNextAsync()
  57. {
  58. // NB: Using an async method to ensure any exception is reported via the task.
  59. try
  60. {
  61. _enumerator = _factory().GetAsyncEnumerator(_cancellationToken);
  62. }
  63. catch (Exception ex)
  64. {
  65. _enumerator = Throw<T>(ex).GetAsyncEnumerator(_cancellationToken);
  66. throw;
  67. }
  68. return await _enumerator.MoveNextAsync().ConfigureAwait(false);
  69. }
  70. }
  71. private sealed class AsyncDeferIterator<T> : AsyncIteratorBase<T>
  72. {
  73. private readonly Func<Task<IAsyncEnumerable<T>>> _factory;
  74. private IAsyncEnumerator<T> _enumerator;
  75. public AsyncDeferIterator(Func< Task<IAsyncEnumerable<T>>> factory)
  76. {
  77. Debug.Assert(factory != null);
  78. _factory = factory;
  79. }
  80. public override T Current => _enumerator == null ? default : _enumerator.Current;
  81. public override AsyncIteratorBase<T> Clone()
  82. {
  83. return new AsyncDeferIterator<T>(_factory);
  84. }
  85. public override async ValueTask DisposeAsync()
  86. {
  87. if (_enumerator != null)
  88. {
  89. await _enumerator.DisposeAsync().ConfigureAwait(false);
  90. _enumerator = null;
  91. }
  92. await base.DisposeAsync().ConfigureAwait(false);
  93. }
  94. protected override ValueTask<bool> MoveNextCore()
  95. {
  96. if (_enumerator == null)
  97. {
  98. return InitializeAndMoveNextAsync();
  99. }
  100. return _enumerator.MoveNextAsync();
  101. }
  102. private async ValueTask<bool> InitializeAndMoveNextAsync()
  103. {
  104. try
  105. {
  106. _enumerator = (await _factory().ConfigureAwait(false)).GetAsyncEnumerator(_cancellationToken);
  107. }
  108. catch (Exception ex)
  109. {
  110. _enumerator = Throw<T>(ex).GetAsyncEnumerator(_cancellationToken);
  111. throw;
  112. }
  113. return await _enumerator.MoveNextAsync().ConfigureAwait(false);
  114. }
  115. }
  116. }
  117. }