// Do.cs
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading.Tasks;
  7. namespace System.Linq
  8. {
  9. public static partial class AsyncEnumerable
  10. {
  /// <summary>
  /// Wraps <paramref name="source"/> in a sequence that calls <paramref name="onNext"/> with
  /// each element before that element is handed to the consumer.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">The async sequence to observe.</param>
  /// <param name="onNext">Callback invoked with every element produced by <paramref name="source"/>.</param>
  /// <returns>An async sequence yielding the same elements as <paramref name="source"/>.</returns>
  /// <exception cref="ArgumentNullException">
  /// <paramref name="source"/> or <paramref name="onNext"/> is <c>null</c>.
  /// </exception>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (onNext == null)
      {
          throw new ArgumentNullException(nameof(onNext));
      }

      // This overload has neither an error nor a completion callback.
      return DoHelper(source, onNext, onError: null, onCompleted: null);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> in a sequence that calls <paramref name="onNext"/> with
  /// each element and <paramref name="onCompleted"/> once the source has no more elements.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">The async sequence to observe.</param>
  /// <param name="onNext">Callback invoked with every element produced by <paramref name="source"/>.</param>
  /// <param name="onCompleted">Callback invoked after the source reports that it is exhausted.</param>
  /// <returns>An async sequence yielding the same elements as <paramref name="source"/>.</returns>
  /// <exception cref="ArgumentNullException">
  /// <paramref name="source"/>, <paramref name="onNext"/> or <paramref name="onCompleted"/> is <c>null</c>.
  /// </exception>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action onCompleted)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (onNext == null)
      {
          throw new ArgumentNullException(nameof(onNext));
      }

      if (onCompleted == null)
      {
          throw new ArgumentNullException(nameof(onCompleted));
      }

      // No error callback for this overload.
      return DoHelper(source, onNext, onError: null, onCompleted: onCompleted);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> in a sequence that calls <paramref name="onNext"/> with
  /// each element and <paramref name="onError"/> when advancing the sequence fails.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">The async sequence to observe.</param>
  /// <param name="onNext">Callback invoked with every element produced by <paramref name="source"/>.</param>
  /// <param name="onError">
  /// Callback invoked with any exception other than <see cref="OperationCanceledException"/> raised
  /// while moving to the next element or while running <paramref name="onNext"/>; the exception
  /// is rethrown afterwards.
  /// </param>
  /// <returns>An async sequence yielding the same elements as <paramref name="source"/>.</returns>
  /// <exception cref="ArgumentNullException">
  /// <paramref name="source"/>, <paramref name="onNext"/> or <paramref name="onError"/> is <c>null</c>.
  /// </exception>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (onNext == null)
      {
          throw new ArgumentNullException(nameof(onNext));
      }

      if (onError == null)
      {
          throw new ArgumentNullException(nameof(onError));
      }

      // No completion callback for this overload.
      return DoHelper(source, onNext, onError, onCompleted: null);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> in a sequence that reports each element, any failure and
  /// normal completion to the supplied callbacks.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">The async sequence to observe.</param>
  /// <param name="onNext">Callback invoked with every element produced by <paramref name="source"/>.</param>
  /// <param name="onError">
  /// Callback invoked with any exception other than <see cref="OperationCanceledException"/> raised
  /// while moving to the next element or while running <paramref name="onNext"/>; the exception
  /// is rethrown afterwards.
  /// </param>
  /// <param name="onCompleted">Callback invoked after the source reports that it is exhausted.</param>
  /// <returns>An async sequence yielding the same elements as <paramref name="source"/>.</returns>
  /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (onNext == null)
      {
          throw new ArgumentNullException(nameof(onNext));
      }

      if (onError == null)
      {
          throw new ArgumentNullException(nameof(onError));
      }

      if (onCompleted == null)
      {
          throw new ArgumentNullException(nameof(onCompleted));
      }

      return DoHelper(source, onNext, onError, onCompleted);
  }
  /// <summary>
  /// Wraps <paramref name="source"/> in a sequence that forwards each element, any failure and
  /// normal completion to the corresponding methods of <paramref name="observer"/>.
  /// </summary>
  /// <typeparam name="TSource">Element type of the sequence.</typeparam>
  /// <param name="source">The async sequence to observe.</param>
  /// <param name="observer">
  /// Observer whose <c>OnNext</c>, <c>OnError</c> and <c>OnCompleted</c> methods are used as the callbacks.
  /// </param>
  /// <returns>An async sequence yielding the same elements as <paramref name="source"/>.</returns>
  /// <exception cref="ArgumentNullException">
  /// <paramref name="source"/> or <paramref name="observer"/> is <c>null</c>.
  /// </exception>
  public static IAsyncEnumerable<TSource> Do<TSource>(this IAsyncEnumerable<TSource> source, IObserver<TSource> observer)
  {
      if (source == null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (observer == null)
      {
          throw new ArgumentNullException(nameof(observer));
      }

      // Bind the observer's three methods to delegates and reuse the callback-based path.
      Action<TSource> next = observer.OnNext;
      Action<Exception> error = observer.OnError;
      Action completed = observer.OnCompleted;

      return DoHelper(source, next, error, completed);
  }
  /// <summary>
  /// Shared construction point for every <c>Do</c> overload: builds the iterator that carries
  /// the callbacks. <paramref name="onError"/> and <paramref name="onCompleted"/> may be
  /// <c>null</c> when the calling overload does not supply them.
  /// </summary>
  private static IAsyncEnumerable<TSource> DoHelper<TSource>(this IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted) =>
      new DoAsyncIterator<TSource>(source, onNext, onError, onCompleted);
  /// <summary>
  /// Iterator backing the <c>Do</c> operators. It passes the source's elements through unchanged
  /// while invoking the configured callbacks for each element, on failure and on completion.
  /// </summary>
  private sealed class DoAsyncIterator<TSource> : AsyncIterator<TSource>
  {
      private readonly IAsyncEnumerable<TSource> _source;
      private readonly Action<TSource> _onNext;

      // Optional callbacks: null when the chosen Do overload did not provide them.
      private readonly Action<Exception> _onError;
      private readonly Action _onCompleted;

      // Enumerator over _source; created on the first MoveNextCore call, cleared on dispose.
      private IAsyncEnumerator<TSource> _enumerator;

      /// <summary>
      /// Stores the source and callbacks. <paramref name="source"/> and <paramref name="onNext"/>
      /// are asserted to be non-null; the other two callbacks may be <c>null</c>.
      /// </summary>
      public DoAsyncIterator(IAsyncEnumerable<TSource> source, Action<TSource> onNext, Action<Exception> onError, Action onCompleted)
      {
          Debug.Assert(source != null);
          Debug.Assert(onNext != null);

          _source = source;
          _onNext = onNext;
          _onError = onError;
          _onCompleted = onCompleted;
      }

      /// <summary>Creates a fresh iterator over the same source with the same callbacks.</summary>
      public override AsyncIterator<TSource> Clone() =>
          new DoAsyncIterator<TSource>(_source, _onNext, _onError, _onCompleted);

      /// <summary>
      /// Disposes the inner enumerator if one was created, then runs the base disposal.
      /// </summary>
      public override async Task DisposeAsync()
      {
          if (_enumerator != null)
          {
              await _enumerator.DisposeAsync().ConfigureAwait(false);
              _enumerator = null;
          }

          await base.DisposeAsync().ConfigureAwait(false);
      }

      /// <summary>
      /// Advances the source by one element. Returns <c>true</c> after publishing the element to
      /// <c>current</c> and invoking the element callback; returns <c>false</c> once the source is
      /// exhausted or when the iterator is in neither the allocated nor the iterating state.
      /// </summary>
      protected override async Task<bool> MoveNextCore()
      {
          // First call: obtain the enumerator and switch to the iterating state.
          if (state == AsyncIteratorState.Allocated)
          {
              _enumerator = _source.GetAsyncEnumerator();
              state = AsyncIteratorState.Iterating;
          }

          if (state != AsyncIteratorState.Iterating)
          {
              return false;
          }

          try
          {
              if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
              {
                  current = _enumerator.Current;
                  _onNext(current);
                  return true;
              }
          }
          catch (Exception ex) when (!(ex is OperationCanceledException))
          {
              // Cancellation bypasses the error callback; every other failure is reported
              // to it (when present) and then propagated unchanged.
              _onError?.Invoke(ex);
              throw;
          }

          // Source exhausted: notify completion, then release the enumerator.
          _onCompleted?.Invoke();
          await DisposeAsync().ConfigureAwait(false);
          return false;
      }
  }
  126. }
  127. }