Create.cs 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  namespace System.Linq
  {
      public static partial class AsyncEnumerable
      {
          /// <summary>
          /// Creates an async sequence whose <c>GetAsyncEnumerator</c> simply invokes
          /// <paramref name="getEnumerator"/>, so each enumeration receives whatever the factory returns.
          /// </summary>
          /// <typeparam name="T">Element type of the sequence.</typeparam>
          /// <param name="getEnumerator">Factory called once per <c>GetAsyncEnumerator</c> call.</param>
          /// <returns>A sequence backed by the factory.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="getEnumerator"/> is null.</exception>
          public static IAsyncEnumerable<T> CreateEnumerable<T>(Func<IAsyncEnumerator<T>> getEnumerator)
          {
              if (getEnumerator == null)
              {
                  throw new ArgumentNullException(nameof(getEnumerator));
              }

              return new AnonymousAsyncEnumerable<T>(getEnumerator);
          }

          /// <summary>
          /// Creates an async sequence whose enumerators are obtained asynchronously. Each enumerator
          /// handed out by the sequence defers calling <paramref name="getEnumerator"/> until its first
          /// <c>MoveNextAsync</c>, then forwards all calls to the enumerator the factory produced.
          /// </summary>
          /// <typeparam name="T">Element type of the sequence.</typeparam>
          /// <param name="getEnumerator">Asynchronous factory, invoked lazily at most once per enumerator.</param>
          /// <returns>A sequence backed by the asynchronous factory.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="getEnumerator"/> is null.</exception>
          public static IAsyncEnumerable<T> CreateEnumerable<T>(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
          {
              if (getEnumerator == null)
              {
                  throw new ArgumentNullException(nameof(getEnumerator));
              }

              return new AnonymousAsyncEnumerableWithTask<T>(getEnumerator);
          }

          /// <summary>
          /// Creates an enumerator from delegates.
          /// </summary>
          /// <typeparam name="T">Element type.</typeparam>
          /// <param name="moveNext">Advances the enumerator; a <c>true</c> result makes the step succeed.</param>
          /// <param name="current">
          /// Invoked after each successful <paramref name="moveNext"/> to obtain the current element.
          /// Not null-checked.
          /// </param>
          /// <param name="dispose">
          /// Optional (may be null). Awaited when the enumerator is disposed, which also happens
          /// automatically once <paramref name="moveNext"/> returns <c>false</c>.
          /// </param>
          /// <returns>The enumerator.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="moveNext"/> is null.</exception>
          public static IAsyncEnumerator<T> CreateEnumerator<T>(Func<Task<bool>> moveNext, Func<T> current, Func<Task> dispose)
          {
              if (moveNext == null)
              {
                  throw new ArgumentNullException(nameof(moveNext));
              }

              // Note: Many methods pass null in for the second two params. We're assuming
              // that the caller is responsible and knows what they're doing.
              return new AnonymousAsyncIterator<T>(moveNext, current, dispose);
          }

          // Variant whose moveNext callback receives a fresh TaskCompletionSource<bool> on every step;
          // the task returned by the callback decides the step's result. This method never completes
          // the TCS itself. NOTE(review): it presumably used to be cancelled through a cancellation
          // registration that no longer exists — confirm callers still need it.
          private static IAsyncEnumerator<T> CreateEnumerator<T>(Func<TaskCompletionSource<bool>, Task<bool>> moveNext, Func<T> current, Func<Task> dispose)
          {
              Debug.Assert(moveNext != null);

              return new AnonymousAsyncIterator<T>(
                  async () =>
                  {
                      var tcs = new TaskCompletionSource<bool>();
                      return await moveNext(tcs).ConfigureAwait(false);
                  },
                  current,
                  dispose);
          }

          // IAsyncEnumerable<T> whose GetAsyncEnumerator delegates straight to a user-supplied factory.
          private sealed class AnonymousAsyncEnumerable<T> : IAsyncEnumerable<T>
          {
              private readonly Func<IAsyncEnumerator<T>> getEnumerator;

              public AnonymousAsyncEnumerable(Func<IAsyncEnumerator<T>> getEnumerator)
              {
                  Debug.Assert(getEnumerator != null);

                  this.getEnumerator = getEnumerator;
              }

              public IAsyncEnumerator<T> GetAsyncEnumerator() => getEnumerator();
          }

          // IAsyncEnumerable<T> whose enumerators obtain their inner enumerator from an asynchronous factory.
          private sealed class AnonymousAsyncEnumerableWithTask<T> : IAsyncEnumerable<T>
          {
              private readonly Func<Task<IAsyncEnumerator<T>>> getEnumerator;

              public AnonymousAsyncEnumerableWithTask(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
              {
                  Debug.Assert(getEnumerator != null);

                  this.getEnumerator = getEnumerator;
              }

              public IAsyncEnumerator<T> GetAsyncEnumerator() => new Enumerator(getEnumerator);

              // Lazily initialised wrapper. `enumerator` is null until the first MoveNextAsync, then holds
              // the factory's enumerator (or a Throw<T> enumerator if the factory failed), and finally
              // DisposedEnumerator.Instance once disposed.
              private sealed class Enumerator : IAsyncEnumerator<T>
              {
                  private Func<Task<IAsyncEnumerator<T>>> getEnumerator;
                  private IAsyncEnumerator<T> enumerator;

                  public Enumerator(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
                  {
                      Debug.Assert(getEnumerator != null);

                      this.getEnumerator = getEnumerator;
                  }

                  // Throws InvalidOperationException before the first MoveNextAsync, and
                  // ObjectDisposedException (via the sentinel) after disposal.
                  public T Current
                  {
                      get
                      {
                          if (enumerator == null)
                          {
                              throw new InvalidOperationException();
                          }

                          return enumerator.Current;
                      }
                  }

                  public async Task DisposeAsync()
                  {
                      // Install the disposed sentinel first, then dispose the enumerator that was swapped
                      // OUT. The field itself now holds the sentinel, so it must not be re-read here.
                      // Disposing twice hands back the sentinel, whose DisposeAsync is a no-op.
                      var old = Interlocked.Exchange(ref enumerator, DisposedEnumerator.Instance);
                      if (old != null)
                      {
                          await old.DisposeAsync().ConfigureAwait(false);
                      }
                  }

                  public Task<bool> MoveNextAsync()
                  {
                      if (enumerator == null)
                      {
                          return InitAndMoveNextAsync();
                      }

                      return enumerator.MoveNextAsync();
                  }

                  // First MoveNextAsync: run the factory, then advance the enumerator it produced.
                  private async Task<bool> InitAndMoveNextAsync()
                  {
                      try
                      {
                          enumerator = await getEnumerator().ConfigureAwait(false);
                      }
                      catch (Exception ex)
                      {
                          // Route later calls to a Throw<T> enumerator (presumably re-raising `ex`)
                          // instead of re-running the factory, then surface the failure now.
                          enumerator = Throw<T>(ex).GetAsyncEnumerator();
                          throw;
                      }
                      finally
                      {
                          // The factory is used at most once; drop the reference either way.
                          getEnumerator = null;
                      }

                      return await enumerator.MoveNextAsync().ConfigureAwait(false);
                  }

                  // Sentinel installed by DisposeAsync: any further use reports ObjectDisposedException.
                  private sealed class DisposedEnumerator : IAsyncEnumerator<T>
                  {
                      public static readonly DisposedEnumerator Instance = new DisposedEnumerator();

                      public T Current => throw new ObjectDisposedException("this");

                      public Task DisposeAsync() => TaskExt.CompletedTask;

                      public Task<bool> MoveNextAsync() => throw new ObjectDisposedException("this");
                  }
              }
          }

          // AsyncIterator<T> driven entirely by delegates; instances are used directly as enumerators.
          private sealed class AnonymousAsyncIterator<T> : AsyncIterator<T>
          {
              private readonly Func<T> currentFunc;
              private readonly Func<Task> dispose;
              private readonly Func<Task<bool>> moveNext;

              public AnonymousAsyncIterator(Func<Task<bool>> moveNext, Func<T> currentFunc, Func<Task> dispose)
              {
                  Debug.Assert(moveNext != null);

                  this.moveNext = moveNext;
                  this.currentFunc = currentFunc;
                  this.dispose = dispose;

                  // Explicit call to initialize enumerator mode.
                  // NOTE(review): presumably moves the base iterator into the Allocated state handled
                  // by MoveNextCore — confirm against AsyncIterator<T>.GetAsyncEnumerator.
                  GetAsyncEnumerator();
              }

              public override AsyncIterator<T> Clone()
              {
                  throw new NotSupportedException("AnonymousAsyncIterator cannot be cloned. It is only intended for use as an iterator.");
              }

              public override async Task DisposeAsync()
              {
                  try
                  {
                      if (dispose != null)
                      {
                          await dispose().ConfigureAwait(false);
                      }
                  }
                  finally
                  {
                      // Run the base cleanup even if the user callback throws, so the base iterator's
                      // own disposal logic is never skipped; the callback's exception still propagates.
                      await base.DisposeAsync().ConfigureAwait(false);
                  }
              }

              protected override async Task<bool> MoveNextCore()
              {
                  switch (state)
                  {
                      case AsyncIteratorState.Allocated:
                          state = AsyncIteratorState.Iterating;
                          goto case AsyncIteratorState.Iterating;

                      case AsyncIteratorState.Iterating:
                          if (await moveNext().ConfigureAwait(false))
                          {
                              current = currentFunc();
                              return true;
                          }

                          // Source exhausted: release resources eagerly.
                          await DisposeAsync().ConfigureAwait(false);
                          break;
                  }

                  return false;
              }
          }
      }
  }