ToAsyncEnumerable.cs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections;
  5. using System.Collections.Generic;
  6. using System.Diagnostics;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. namespace System.Linq
  10. {
  11. public static partial class AsyncEnumerable
  12. {
  /// <summary>
  /// Wraps a synchronous <see cref="IEnumerable{T}"/> so it can be consumed as an
  /// <see cref="IAsyncEnumerable{T}"/>.
  /// </summary>
  /// <remarks>
  /// Lists and collections get dedicated adapters so that list/collection-aware
  /// operators (count, copy, indexing) stay cheap. Anything else is wrapped in the
  /// general-purpose adapter.
  /// </remarks>
  /// <param name="source">The sequence to adapt. Must not be null.</param>
  /// <returns>An async sequence that yields the elements of <paramref name="source"/>.</returns>
  public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this IEnumerable<TSource> source)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      // IList<T> must be tested before ICollection<T>: every IList<T> is also an ICollection<T>.
      if (source is IList<TSource> list)
      {
          return new AsyncIListEnumerableAdapter<TSource>(list);
      }

      if (source is ICollection<TSource> collection)
      {
          return new AsyncICollectionEnumerableAdapter<TSource>(collection);
      }

      return new AsyncEnumerableAdapter<TSource>(source);
  }
  /// <summary>
  /// Converts a push-based <see cref="IObservable{T}"/> into a pull-based
  /// <see cref="IAsyncEnumerable{T}"/>.
  /// </summary>
  /// <remarks>
  /// Each enumeration subscribes a new <c>ToAsyncEnumerableObserver</c> to <paramref name="source"/>.
  /// Values that arrive while no MoveNext is waiting are buffered in the observer's queue.
  /// A MoveNext that finds nothing buffered and no terminal notification parks its
  /// completion source on the observer, so the next notification completes it.
  /// NOTE(review): when the enumeration token is canceled, only the subscription is disposed.
  /// A MoveNext that is already parked is not completed by this code, so it appears it would
  /// wait indefinitely. Confirm against AsyncEnumerator.Create's handling of its completion source.
  /// </remarks>
  /// <param name="source">The observable to adapt. Must not be null.</param>
  /// <returns>An async sequence over the notifications of <paramref name="source"/>.</returns>
  public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this IObservable<TSource> source)
  {
      if (source is null)
      {
          throw Error.ArgumentNull(nameof(source));
      }

      return CreateEnumerable(
          ct =>
          {
              var sink = new ToAsyncEnumerableObserver<TSource>();
              var subscription = source.Subscribe(sink);

              // REVIEW: Review possible concurrency issues with Dispose calls.
              // If ct is already canceled, Register runs the callback immediately.
              var registration = ct.Register(subscription.Dispose);

              return AsyncEnumerator.Create(
                  tcs =>
                  {
                      // Outcome decided under the lock, applied to tcs after releasing it:
                      // immediate == true  -> a buffered value was moved into Current
                      // immediate == false -> the source has completed
                      // failure != null    -> the source has faulted
                      // neither            -> tcs was parked on the observer
                      bool? immediate = null;
                      Exception failure = null;

                      lock (sink.SyncRoot)
                      {
                          if (sink.Values.Count != 0)
                          {
                              sink.Current = sink.Values.Dequeue();
                              immediate = true;
                          }
                          else if (sink.HasCompleted)
                          {
                              immediate = false;
                          }
                          else if (sink.Error is null)
                          {
                              // Nothing available yet: the observer completes tcs on the next notification.
                              sink.TaskCompletionSource = tcs;
                          }
                          else
                          {
                              failure = sink.Error;
                          }
                      }

                      if (immediate.HasValue)
                      {
                          tcs.TrySetResult(immediate.Value);
                      }
                      else if (failure != null)
                      {
                          tcs.TrySetException(failure);
                      }

                      return new ValueTask<bool>(tcs.Task);
                  },
                  () => sink.Current,
                  () =>
                  {
                      registration.Dispose();
                      subscription.Dispose();

                      // Should we cancel in-flight operations somehow?
                      return default;
                  });
          });
  }
  /// <summary>
  /// Async iterator over an arbitrary synchronous <see cref="IEnumerable{T}"/>.
  /// It pulls one element from the source enumerator per MoveNext. Its materializing
  /// members delegate to System.Linq so the synchronous optimizations apply.
  /// </summary>
  private sealed class AsyncEnumerableAdapter<T> : AsyncIterator<T>, IAsyncIListProvider<T>
  {
      private readonly IEnumerable<T> _source;

      // Obtained on the first MoveNext; disposed and cleared by DisposeAsync.
      private IEnumerator<T> _enumerator;

      public AsyncEnumerableAdapter(IEnumerable<T> source)
      {
          Debug.Assert(source != null);
          _source = source;
      }

      /// <summary>Creates a new, not-yet-started adapter over the same source.</summary>
      public override AsyncIteratorBase<T> Clone()
      {
          return new AsyncEnumerableAdapter<T>(_source);
      }

      /// <summary>Disposes the source enumerator (if one was created), then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          if (_enumerator != null)
          {
              _enumerator.Dispose();
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances the source enumerator. In the Allocated state it first obtains the enumerator.
      /// When the source is exhausted, the iterator disposes itself and returns false.
      /// </summary>
      protected override async ValueTask<bool> MoveNextCore()
      {
          switch (_state)
          {
              case AsyncIteratorState.Allocated:
                  _enumerator = _source.GetEnumerator();
                  _state = AsyncIteratorState.Iterating;
                  goto case AsyncIteratorState.Iterating;

              case AsyncIteratorState.Iterating:
                  if (_enumerator.MoveNext())
                  {
                      _current = _enumerator.Current;
                      return true;
                  }

                  await DisposeAsync().ConfigureAwait(false);
                  break;
          }

          return false;
      }

      // These optimizations rely on the System.Linq implementations for IEnumerable<T>
      // to optimize and short-circuit as appropriate.
      public Task<T[]> ToArrayAsync(CancellationToken cancellationToken)
      {
          return Task.FromResult(_source.ToArray());
      }

      public Task<List<T>> ToListAsync(CancellationToken cancellationToken)
      {
          return Task.FromResult(_source.ToList());
      }

      /// <summary>Returns the number of elements in the source sequence.</summary>
      /// <param name="onlyIfCheap">
      /// When true, the count is returned only if the source exposes it without being enumerated;
      /// otherwise -1 is returned. NOTE(review): -1 is assumed to be the IAsyncIListProvider
      /// "not cheap" convention — confirm against the interface documentation.
      /// </param>
      /// <param name="cancellationToken">Not observed; the work completes synchronously.</param>
      public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
      {
          if (onlyIfCheap)
          {
              // Never enumerate here: the source may be expensive, side-effecting, or one-shot.
              switch (_source)
              {
                  case ICollection<T> collection:
                      return Task.FromResult(collection.Count);
                  case IReadOnlyCollection<T> readOnlyCollection:
                      return Task.FromResult(readOnlyCollection.Count);
                  case ICollection nonGenericCollection:
                      return Task.FromResult(nonGenericCollection.Count);
                  default:
                      return Task.FromResult(-1);
              }
          }

          // Exact count requested: Enumerable.Count uses collection fast paths or enumerates.
          return Task.FromResult(_source.Count());
      }
  }
  /// <summary>
  /// Async iterator over an <see cref="IList{T}"/>. It also implements <see cref="IList{T}"/>
  /// itself by forwarding every member to the wrapped list. It specializes <c>Select</c>
  /// and the materializing operators.
  /// </summary>
  private sealed class AsyncIListEnumerableAdapter<T> : AsyncIterator<T>, IAsyncIListProvider<T>, IList<T>
  {
      private readonly IList<T> _source;

      // Created on the first MoveNext; disposed and cleared by DisposeAsync.
      private IEnumerator<T> _enumerator;

      public AsyncIListEnumerableAdapter(IList<T> source)
      {
          Debug.Assert(source != null);
          _source = source;
      }

      /// <summary>Creates a new, not-yet-started adapter over the same list.</summary>
      public override AsyncIteratorBase<T> Clone() => new AsyncIListEnumerableAdapter<T>(_source);

      /// <summary>Disposes the list enumerator (if one was created), then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          _enumerator?.Dispose();
          _enumerator = null;

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Obtains the list enumerator on the first call, then yields one element per call.
      /// Once the list is exhausted, the iterator disposes itself and returns false.
      /// </summary>
      protected override async ValueTask<bool> MoveNextCore()
      {
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetEnumerator();
              _state = AsyncIteratorState.Iterating;
          }

          if (_state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          if (_enumerator.MoveNext())
          {
              _current = _enumerator.Current;
              return true;
          }

          await DisposeAsync().ConfigureAwait(false);
          return false;
      }

      /// <summary>Projects through a list-aware iterator instead of the generic one.</summary>
      public override IAsyncEnumerable<TResult> Select<TResult>(Func<T, TResult> selector) =>
          new SelectIListIterator<T, TResult>(_source, selector);

      // Materialization defers to System.Linq, which already optimizes and
      // short-circuits for lists.
      public Task<T[]> ToArrayAsync(CancellationToken cancellationToken) => Task.FromResult(_source.ToArray());

      public Task<List<T>> ToListAsync(CancellationToken cancellationToken) => Task.FromResult(_source.ToList());

      // A list always knows its count, so onlyIfCheap makes no difference.
      public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken) => Task.FromResult(_source.Count);

      // IList<T> surface: pure pass-through to the wrapped list.
      IEnumerator<T> IEnumerable<T>.GetEnumerator() => _source.GetEnumerator();

      IEnumerator IEnumerable.GetEnumerator() => _source.GetEnumerator();

      void ICollection<T>.Add(T item) => _source.Add(item);

      void ICollection<T>.Clear() => _source.Clear();

      bool ICollection<T>.Contains(T item) => _source.Contains(item);

      void ICollection<T>.CopyTo(T[] array, int arrayIndex) => _source.CopyTo(array, arrayIndex);

      bool ICollection<T>.Remove(T item) => _source.Remove(item);

      int ICollection<T>.Count => _source.Count;

      bool ICollection<T>.IsReadOnly => _source.IsReadOnly;

      int IList<T>.IndexOf(T item) => _source.IndexOf(item);

      void IList<T>.Insert(int index, T item) => _source.Insert(index, item);

      void IList<T>.RemoveAt(int index) => _source.RemoveAt(index);

      T IList<T>.this[int index]
      {
          get => _source[index];
          set => _source[index] = value;
      }
  }
  /// <summary>
  /// Async iterator over an <see cref="ICollection{T}"/>. It also implements
  /// <see cref="ICollection{T}"/> itself by forwarding every member to the wrapped collection.
  /// </summary>
  private sealed class AsyncICollectionEnumerableAdapter<T> : AsyncIterator<T>, IAsyncIListProvider<T>, ICollection<T>
  {
      private readonly ICollection<T> _source;

      // Created on the first MoveNext; disposed and cleared by DisposeAsync.
      private IEnumerator<T> _enumerator;

      public AsyncICollectionEnumerableAdapter(ICollection<T> source)
      {
          Debug.Assert(source != null);
          _source = source;
      }

      /// <summary>Creates a new, not-yet-started adapter over the same collection.</summary>
      public override AsyncIteratorBase<T> Clone() => new AsyncICollectionEnumerableAdapter<T>(_source);

      /// <summary>Disposes the collection enumerator (if one was created), then the base iterator.</summary>
      public override async ValueTask DisposeAsync()
      {
          _enumerator?.Dispose();
          _enumerator = null;

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Obtains the collection enumerator on the first call, then yields one element per call.
      /// Once the collection is exhausted, the iterator disposes itself and returns false.
      /// </summary>
      protected override async ValueTask<bool> MoveNextCore()
      {
          if (_state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetEnumerator();
              _state = AsyncIteratorState.Iterating;
          }

          if (_state == AsyncIteratorState.Iterating)
          {
              if (_enumerator.MoveNext())
              {
                  _current = _enumerator.Current;
                  return true;
              }

              await DisposeAsync().ConfigureAwait(false);
          }

          return false;
      }

      // Materialization defers to System.Linq, which already optimizes and
      // short-circuits for collections.
      public Task<T[]> ToArrayAsync(CancellationToken cancellationToken) => Task.FromResult(_source.ToArray());

      public Task<List<T>> ToListAsync(CancellationToken cancellationToken) => Task.FromResult(_source.ToList());

      // A collection always knows its count, so onlyIfCheap makes no difference.
      public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken) => Task.FromResult(_source.Count);

      // ICollection<T> surface: pure pass-through to the wrapped collection.
      IEnumerator<T> IEnumerable<T>.GetEnumerator() => _source.GetEnumerator();

      IEnumerator IEnumerable.GetEnumerator() => _source.GetEnumerator();

      void ICollection<T>.Add(T item) => _source.Add(item);

      void ICollection<T>.Clear() => _source.Clear();

      bool ICollection<T>.Contains(T item) => _source.Contains(item);

      void ICollection<T>.CopyTo(T[] array, int arrayIndex) => _source.CopyTo(array, arrayIndex);

      bool ICollection<T>.Remove(T item) => _source.Remove(item);

      int ICollection<T>.Count => _source.Count;

      bool ICollection<T>.IsReadOnly => _source.IsReadOnly;
  }
  /// <summary>
  /// Observer that buffers notifications for the observable-to-async-enumerable bridge.
  /// All state is guarded by <see cref="SyncRoot"/>. Any parked completion source is
  /// completed only after the lock has been released.
  /// </summary>
  private sealed class ToAsyncEnumerableObserver<T> : IObserver<T>
  {
      // Values received while no consumer was waiting, in arrival order.
      public readonly Queue<T> Values = new Queue<T>();

      // The element most recently handed to the consumer.
      public T Current;

      // The error delivered by OnError, if any.
      public Exception Error;

      // Set once OnCompleted has been received.
      public bool HasCompleted;

      // A waiting consumer's completion source, parked until the next notification arrives.
      public TaskCompletionSource<bool> TaskCompletionSource;

      /// <summary>The lock object guarding this observer's state (the value queue itself).</summary>
      public object SyncRoot => Values;

      /// <summary>Marks the source as completed and completes any parked consumer with false.</summary>
      public void OnCompleted()
      {
          TaskCompletionSource<bool> waiter;

          lock (SyncRoot)
          {
              HasCompleted = true;
              waiter = TakeWaiter();
          }

          waiter?.TrySetResult(false);
      }

      /// <summary>Records <paramref name="error"/> and faults any parked consumer with it.</summary>
      public void OnError(Exception error)
      {
          TaskCompletionSource<bool> waiter;

          lock (SyncRoot)
          {
              Error = error;
              waiter = TakeWaiter();
          }

          waiter?.TrySetException(error);
      }

      /// <summary>
      /// Hands <paramref name="value"/> directly to a parked consumer (via <see cref="Current"/>)
      /// or, when nobody is waiting, buffers it in <see cref="Values"/>.
      /// </summary>
      public void OnNext(T value)
      {
          TaskCompletionSource<bool> waiter = null;

          lock (SyncRoot)
          {
              if (TaskCompletionSource is null)
              {
                  Values.Enqueue(value);
              }
              else
              {
                  Current = value;
                  waiter = TakeWaiter();
              }
          }

          waiter?.TrySetResult(true);
      }

      // Detaches and returns the parked completion source (null if none). Caller must hold SyncRoot.
      private TaskCompletionSource<bool> TakeWaiter()
      {
          var waiter = TaskCompletionSource;
          TaskCompletionSource = null;
          return waiter;
      }
  }
  348. }
  349. }