Cast.cs 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Threading;
  6. using System.Threading.Tasks;
  namespace System.Linq
  {
      public static partial class AsyncEnumerable
      {
          /// <summary>
          /// Converts every element of an async sequence of <see cref="object"/> to
          /// <typeparamref name="TResult"/> with an explicit cast.
          /// </summary>
          /// <remarks>
          /// NB: this is not a standard LINQ operator. There is no non-generic IAsyncEnumerable to hang it
          /// off, so it takes <c>IAsyncEnumerable&lt;object&gt;</c> instead. It is kept so that C# query
          /// syntax with an explicitly typed range variable (<c>from T x in xs</c>) binds.
          /// If <paramref name="source"/> already is an <see cref="IAsyncEnumerable{T}"/> of
          /// <typeparamref name="TResult"/>, that same instance is returned and no per-element cast runs.
          /// Otherwise the cast is deferred: each element is cast as the result is enumerated, so an element
          /// that cannot be converted fails during enumeration, not during this call.
          /// A null <paramref name="source"/> is rejected up front via <c>Error.ArgumentNull</c>.
          /// </remarks>
          /// <typeparam name="TResult">Element type to cast each item to.</typeparam>
          /// <param name="source">Object-typed async sequence to convert.</param>
          /// <returns>An async sequence that yields the elements of <paramref name="source"/> cast to <typeparamref name="TResult"/>.</returns>
          public static IAsyncEnumerable<TResult> Cast<TResult>(this IAsyncEnumerable<object> source)
          {
              if (source is null)
              {
                  throw Error.ArgumentNull(nameof(source));
              }

              // Nothing to convert: the source already produces TResult, so hand it back unchanged.
              if (source is IAsyncEnumerable<TResult> alreadyTyped)
              {
                  return alreadyTyped;
              }

  #if USE_ASYNC_ITERATOR
              return Create(CastCore);

              // Async-iterator implementation. The token passed in is forwarded to the source's
              // enumerator through WithCancellation.
              async IAsyncEnumerator<TResult> CastCore(CancellationToken token)
              {
                  await foreach (var item in AsyncEnumerableExtensions.WithCancellation(source, token).ConfigureAwait(false))
                  {
                      yield return (TResult)item;
                  }
              }
  #else
              return new CastAsyncIterator<TResult>(source);
  #endif
          }

  #if !USE_ASYNC_ITERATOR
          /// <summary>
          /// Hand-written iterator behind <c>Cast</c> for builds without <c>USE_ASYNC_ITERATOR</c>.
          /// It pulls items from the object-typed source and casts each one to <typeparamref name="TResult"/>.
          /// </summary>
          /// <remarks>
          /// <c>_state</c>, <c>_current</c> and <c>_cancellationToken</c> are not declared in this class.
          /// Presumably they are inherited from <c>AsyncIterator&lt;TResult&gt;</c>; confirm against the base class.
          /// </remarks>
          private sealed class CastAsyncIterator<TResult> : AsyncIterator<TResult>
          {
              private readonly IAsyncEnumerable<object> _source;

              // Opened on the first MoveNextCore call; set back to null once it has been disposed.
              private IAsyncEnumerator<object> _enumerator;

              public CastAsyncIterator(IAsyncEnumerable<object> source) => _source = source;

              // New iterator over the same source; the current _enumerator is not shared with the copy.
              public override AsyncIteratorBase<TResult> Clone() => new CastAsyncIterator<TResult>(_source);

              public override async ValueTask DisposeAsync()
              {
                  // Dispose the inner enumerator, if one was ever opened, before the base class cleans up.
                  if (_enumerator != null)
                  {
                      await _enumerator.DisposeAsync().ConfigureAwait(false);
                      _enumerator = null;
                  }

                  await base.DisposeAsync().ConfigureAwait(false);
              }

              protected override async ValueTask<bool> MoveNextCore()
              {
                  // First call: open the inner enumerator with the iterator's cancellation token and
                  // move into the Iterating state. Iteration then continues below in the same call.
                  if (_state == AsyncIteratorState.Allocated)
                  {
                      _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                      _state = AsyncIteratorState.Iterating;
                  }

                  // Any state other than Iterating produces no more items.
                  if (_state != AsyncIteratorState.Iterating)
                  {
                      return false;
                  }

                  if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                  {
                      _current = (TResult)_enumerator.Current;
                      return true;
                  }

                  // Source exhausted: dispose now rather than waiting for the consumer to do it.
                  await DisposeAsync().ConfigureAwait(false);
                  return false;
              }
          }
  #endif
      }
  }