Scan.cs 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. namespace System.Linq
  9. {
  10. public static partial class AsyncEnumerable
  11. {
  /// <summary>
  /// Applies an accumulator function over an async-enumerable sequence, starting from
  /// <paramref name="seed"/>, and produces every intermediate accumulated value.
  /// </summary>
  /// <typeparam name="TSource">Type of the elements in the source sequence.</typeparam>
  /// <typeparam name="TAccumulate">Type of the accumulated values that are produced.</typeparam>
  /// <param name="source">Sequence to accumulate over.</param>
  /// <param name="seed">Initial accumulator value; it is fed to the first accumulator call and is not emitted itself.</param>
  /// <param name="accumulator">Function combining the running accumulator with the next source element.</param>
  /// <returns>A sequence with one accumulated value per source element.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="accumulator"/> is null.</exception>
  public static IAsyncEnumerable<TAccumulate> Scan<TSource, TAccumulate>(this IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> accumulator)
  {
      // Constructor arguments are evaluated left to right, so a null source is
      // reported before a null accumulator.
      return new ScanAsyncEnumerable<TSource, TAccumulate>(
          source ?? throw new ArgumentNullException(nameof(source)),
          seed,
          accumulator ?? throw new ArgumentNullException(nameof(accumulator)));
  }
  /// <summary>
  /// Applies an accumulator function over an async-enumerable sequence, using the first
  /// element as the initial accumulator, and produces every subsequent accumulated value.
  /// </summary>
  /// <typeparam name="TSource">Type of the source elements and of the accumulated values.</typeparam>
  /// <param name="source">Sequence to accumulate over.</param>
  /// <param name="accumulator">Function combining the running accumulator with the next source element.</param>
  /// <returns>
  /// A sequence with one accumulated value per source element after the first; the first
  /// element only seeds the accumulator and is not emitted.
  /// </returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="accumulator"/> is null.</exception>
  public static IAsyncEnumerable<TSource> Scan<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, TSource, TSource> accumulator)
  {
      // Constructor arguments are evaluated left to right, so a null source is
      // reported before a null accumulator.
      return new ScanAsyncEnumerable<TSource>(
          source ?? throw new ArgumentNullException(nameof(source)),
          accumulator ?? throw new ArgumentNullException(nameof(accumulator)));
  }
  /// <summary>
  /// Iterator behind the seeded <c>Scan</c> overload: folds each source element into a
  /// running accumulator and yields the accumulator after every element.
  /// </summary>
  private sealed class ScanAsyncEnumerable<TSource, TAccumulate> : AsyncIterator<TAccumulate>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly TAccumulate _seed;
      private readonly Func<TAccumulate, TSource, TAccumulate> _accumulator;

      // Per-enumeration state; set up on the first MoveNextCore call and cleared in Dispose.
      private IAsyncEnumerator<TSource> _enumerator;
      private TAccumulate _accumulated;

      /// <summary>Captures the source, the seed and the accumulator; nothing is enumerated here.</summary>
      public ScanAsyncEnumerable(IAsyncEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> accumulator)
      {
          Debug.Assert(source != null);
          Debug.Assert(accumulator != null);

          _source = source;
          _seed = seed;
          _accumulator = accumulator;
      }

      /// <summary>Creates a fresh iterator over the same source, seed and accumulator.</summary>
      public override AsyncIterator<TAccumulate> Clone() =>
          new ScanAsyncEnumerable<TSource, TAccumulate>(_source, _seed, _accumulator);

      /// <summary>Disposes the source enumerator (if one was obtained), drops the accumulator, then disposes the base.</summary>
      public override void Dispose()
      {
          if (_enumerator != null)
          {
              _enumerator.Dispose();
              _enumerator = null;
              _accumulated = default;
          }

          base.Dispose();
      }

      /// <summary>
      /// Advances the source by one element and publishes the updated accumulator as the current value.
      /// </summary>
      /// <param name="cancellationToken">Token forwarded to the source enumerator's MoveNext.</param>
      /// <returns>true when a new accumulated value is available; false once the source is exhausted.</returns>
      protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
      {
          // First call: obtain the source enumerator and start from the seed.
          if (state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetEnumerator();
              _accumulated = _seed;
              state = AsyncIteratorState.Iterating;
          }

          if (state == AsyncIteratorState.Iterating
              && await _enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
          {
              var element = _enumerator.Current;
              _accumulated = _accumulator(_accumulated, element);
              current = _accumulated;
              return true;
          }

          // Source exhausted, or the iterator is in neither of the two handled states.
          Dispose();
          return false;
      }
  }
  /// <summary>
  /// Iterator behind the seedless <c>Scan</c> overload: the first source element becomes the
  /// initial accumulator, and the accumulator is yielded after every later element.
  /// </summary>
  private sealed class ScanAsyncEnumerable<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly Func<TSource, TSource, TSource> _accumulator;

      // Per-enumeration state; set up on the first MoveNextCore call.
      private IAsyncEnumerator<TSource> _enumerator;
      private TSource _accumulated;
      private bool _hasSeed; // true once the first source element has been taken as the seed

      /// <summary>Captures the source and the accumulator; nothing is enumerated here.</summary>
      public ScanAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TSource, TSource> accumulator)
      {
          Debug.Assert(source != null);
          Debug.Assert(accumulator != null);

          _source = source;
          _accumulator = accumulator;
      }

      /// <summary>Creates a fresh iterator over the same source and accumulator.</summary>
      public override AsyncIterator<TSource> Clone() =>
          new ScanAsyncEnumerable<TSource>(_source, _accumulator);

      /// <summary>Disposes the source enumerator (if one was obtained), drops the accumulator, then disposes the base.</summary>
      public override void Dispose()
      {
          if (_enumerator != null)
          {
              _enumerator.Dispose();
              _enumerator = null;
              _accumulated = default;
          }

          base.Dispose();
      }

      /// <summary>
      /// Advances the source until an accumulated value can be produced and publishes it as the current value.
      /// </summary>
      /// <param name="cancellationToken">Token forwarded to the source enumerator's MoveNext.</param>
      /// <returns>true when a new accumulated value is available; false once the source is exhausted.</returns>
      protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
      {
          // First call: obtain the source enumerator and reset the seed tracking.
          if (state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetEnumerator();
              _hasSeed = false;
              _accumulated = default;
              state = AsyncIteratorState.Iterating;
          }

          if (state == AsyncIteratorState.Iterating)
          {
              while (await _enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
              {
                  var element = _enumerator.Current;

                  if (_hasSeed)
                  {
                      _accumulated = _accumulator(_accumulated, element);
                      current = _accumulated;
                      return true;
                  }

                  // The very first element only seeds the accumulator; keep pulling
                  // so the first emitted value combines the first two elements.
                  _hasSeed = true;
                  _accumulated = element;
              }
          }

          // Source exhausted, or the iterator is in neither of the two handled states.
          Dispose();
          return false;
      }
  }
  148. }
  149. }