Select.cs 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Diagnostics;
  7. using System.Linq;
  8. using System.Threading;
  9. using System.Threading.Tasks;
  namespace System.Linq
  {
      public static partial class AsyncEnumerable
      {
          /// <summary>
          /// Projects each element of an asynchronous sequence into a new form.
          /// </summary>
          /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
          /// <typeparam name="TResult">The type of the value returned by <paramref name="selector"/>.</typeparam>
          /// <param name="source">The sequence whose elements are projected.</param>
          /// <param name="selector">The projection applied to each element.</param>
          /// <returns>A sequence whose elements are the results of <paramref name="selector"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
          public static IAsyncEnumerable<TResult> Select<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, TResult> selector)
          {
              if (source == null)
                  throw new ArgumentNullException(nameof(source));
              if (selector == null)
                  throw new ArgumentNullException(nameof(selector));

              // An existing iterator gets to decide how to apply the projection; the
              // Select iterators below use this hook to fuse consecutive projections.
              var iterator = source as AsyncIterator<TSource>;
              if (iterator != null)
              {
                  return iterator.Select(selector);
              }

              // TODO: Can we add optimizations for IList or anything else here?
              return new SelectEnumerableAsyncIterator<TSource, TResult>(source, selector);
          }

          /// <summary>
          /// Projects each element of an asynchronous sequence into a new form, also
          /// passing the element's position to the projection.
          /// </summary>
          /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
          /// <typeparam name="TResult">The type of the value returned by <paramref name="selector"/>.</typeparam>
          /// <param name="source">The sequence whose elements are projected.</param>
          /// <param name="selector">
          /// The projection applied to each element; its second argument is the element's
          /// index, starting at 0 and increasing by one per element.
          /// </param>
          /// <returns>A sequence whose elements are the results of <paramref name="selector"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
          /// <remarks>
          /// The index is incremented in a checked context, so enumerating more than
          /// <see cref="int.MaxValue"/> elements throws <see cref="OverflowException"/>.
          /// </remarks>
          public static IAsyncEnumerable<TResult> Select<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, TResult> selector)
          {
              if (source == null)
                  throw new ArgumentNullException(nameof(source));
              if (selector == null)
                  throw new ArgumentNullException(nameof(selector));

              // Built on AsyncIterator, like the non-indexed overload, so that the
              // CancellationToken handed to MoveNext is forwarded to the source enumerator.
              return new SelectEnumerableWithIndexAsyncIterator<TSource, TResult>(source, selector);
          }

          /// <summary>
          /// Composes two projections into one that applies <paramref name="selector1"/>
          /// and then applies <paramref name="selector2"/> to its result.
          /// </summary>
          private static Func<TSource, TResult> CombineSelectors<TSource, TMiddle, TResult>(Func<TSource, TMiddle> selector1, Func<TMiddle, TResult> selector2)
          {
              return x => selector2(selector1(x));
          }

          /// <summary>
          /// Iterator behind <c>Select(source, selector)</c>: pulls elements from the source
          /// enumerator and exposes <c>selector(element)</c> as the current value.
          /// </summary>
          internal sealed class SelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
          {
              private readonly IAsyncEnumerable<TSource> source;
              private readonly Func<TSource, TResult> selector;

              // Obtained on the first MoveNextCore call; released (and nulled) by Dispose.
              private IAsyncEnumerator<TSource> enumerator;

              public SelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, TResult> selector)
              {
                  Debug.Assert(source != null);
                  Debug.Assert(selector != null);

                  this.source = source;
                  this.selector = selector;
              }

              /// <summary>Creates a fresh iterator over the same source and projection.</summary>
              public override AsyncIterator<TResult> Clone()
              {
                  return new SelectEnumerableAsyncIterator<TSource, TResult>(source, selector);
              }

              /// <summary>Disposes the source enumerator, if one was obtained, then defers to the base implementation.</summary>
              public override void Dispose()
              {
                  if (enumerator != null)
                  {
                      enumerator.Dispose();
                      enumerator = null;
                  }

                  base.Dispose();
              }

              /// <summary>
              /// Advances the source and projects its current element.
              /// Returns <c>false</c> once the source is exhausted.
              /// </summary>
              public override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
              {
                  switch (state)
                  {
                      case State.Allocated:
                          enumerator = source.GetEnumerator();
                          state = State.Iterating;
                          goto case State.Iterating;

                      case State.Iterating:
                          if (await enumerator.MoveNext(cancellationToken)
                                              .ConfigureAwait(false))
                          {
                              current = selector(enumerator.Current);
                              return true;
                          }

                          // Source exhausted: release it right away rather than waiting for the consumer.
                          Dispose();
                          break;
                  }

                  return false;
              }

              /// <summary>
              /// Fuses <c>source.Select(f).Select(g)</c> into a single iterator that applies <c>g(f(x))</c>.
              /// </summary>
              public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
              {
                  return new SelectEnumerableAsyncIterator<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
              }
          }

          /// <summary>
          /// Iterator behind <c>Select(source, (element, index) =&gt; ...)</c>: pulls elements from the
          /// source enumerator and exposes <c>selector(element, index)</c> as the current value.
          /// </summary>
          internal sealed class SelectEnumerableWithIndexAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
          {
              private readonly IAsyncEnumerable<TSource> source;
              private readonly Func<TSource, int, TResult> selector;

              // Obtained on the first MoveNextCore call; released (and nulled) by Dispose.
              private IAsyncEnumerator<TSource> enumerator;

              // Index of the next element to be projected.
              private int index;

              public SelectEnumerableWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, TResult> selector)
              {
                  Debug.Assert(source != null);
                  Debug.Assert(selector != null);

                  this.source = source;
                  this.selector = selector;
              }

              /// <summary>Creates a fresh iterator (index restarting at 0) over the same source and projection.</summary>
              public override AsyncIterator<TResult> Clone()
              {
                  return new SelectEnumerableWithIndexAsyncIterator<TSource, TResult>(source, selector);
              }

              /// <summary>Disposes the source enumerator, if one was obtained, then defers to the base implementation.</summary>
              public override void Dispose()
              {
                  if (enumerator != null)
                  {
                      enumerator.Dispose();
                      enumerator = null;
                  }

                  base.Dispose();
              }

              /// <summary>
              /// Advances the source and projects its current element together with its index.
              /// Returns <c>false</c> once the source is exhausted.
              /// </summary>
              /// <exception cref="OverflowException">More than <see cref="int.MaxValue"/> elements were enumerated.</exception>
              public override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
              {
                  switch (state)
                  {
                      case State.Allocated:
                          enumerator = source.GetEnumerator();
                          index = 0;
                          state = State.Iterating;
                          goto case State.Iterating;

                      case State.Iterating:
                          if (await enumerator.MoveNext(cancellationToken)
                                              .ConfigureAwait(false))
                          {
                              // checked: fail loudly instead of wrapping to a negative index.
                              current = selector(enumerator.Current, checked(index++));
                              return true;
                          }

                          // Source exhausted: release it right away rather than waiting for the consumer.
                          Dispose();
                          break;
                  }

                  return false;
              }

              /// <summary>
              /// Fuses <c>source.Select((x, i) =&gt; f).Select(g)</c> into a single iterator that
              /// applies <c>g(f(x, i))</c>. The indices are unchanged because no elements are skipped.
              /// </summary>
              public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
              {
                  // Copy the field so the composed delegate does not keep this iterator alive.
                  var first = this.selector;
                  return new SelectEnumerableWithIndexAsyncIterator<TSource, TResult1>(source, (item, position) => selector(first(item, position)));
              }
          }
      }
  }