AsyncIterator.cs 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Threading;
  6. using System.Threading.Tasks;
  7. namespace System.Linq
  8. {
  /// <summary>
  /// Base class for hand-written async iterators. One object serves as both the
  /// <see cref="IAsyncEnumerable{T}"/> and, for the first enumeration performed on the thread
  /// that created it, its own <see cref="IAsyncEnumerator{T}"/>; any other enumeration gets a
  /// new instance from <see cref="Clone"/>.
  /// </summary>
  /// <typeparam name="TSource">The type of the elements produced by the iterator.</typeparam>
  internal abstract class AsyncIterator<TSource> : IAsyncEnumerable<TSource>, IAsyncEnumerator<TSource>
  {
      // Managed thread id of the constructing thread; only that thread may reuse this instance
      // as its own enumerator.
      private readonly int threadId;

      // Created for each enumeration in GetAsyncEnumerator; cancelled, disposed and cleared by DisposeAsync.
      private CancellationTokenSource cancellationTokenSource;

      // True whenever Current must not be read: before the first successful MoveNextAsync,
      // after MoveNextAsync returned false or threw, and after DisposeAsync.
      private bool currentIsInvalid = true;

      // Backing value for Current. Cleared by DisposeAsync; presumably assigned by derived
      // iterators in MoveNextCore (not visible here).
      internal TSource current;

      // Lifecycle state of this instance (see AsyncIteratorState).
      internal AsyncIteratorState state = AsyncIteratorState.New;

      /// <summary>Records the creating thread so GetAsyncEnumerator can decide whether to reuse this instance.</summary>
      protected AsyncIterator()
      {
          threadId = Environment.CurrentManagedThreadId;
      }

      /// <summary>
      /// Returns an enumerator for this sequence. This instance is returned itself when it is still
      /// in <see cref="AsyncIteratorState.New"/> and the caller is on the creating thread; otherwise
      /// a copy obtained from <see cref="Clone"/> is returned.
      /// </summary>
      /// <returns>An enumerator in the <see cref="AsyncIteratorState.Allocated"/> state.</returns>
      /// <remarks>
      /// If <see cref="OnGetEnumerator"/> throws, the enumerator is disposed and the exception is rethrown.
      /// </remarks>
      public IAsyncEnumerator<TSource> GetAsyncEnumerator()
      {
          var enumerator = state == AsyncIteratorState.New && threadId == Environment.CurrentManagedThreadId
              ? this
              : Clone();

          enumerator.state = AsyncIteratorState.Allocated;
          enumerator.cancellationTokenSource = new CancellationTokenSource();

          try
          {
              enumerator.OnGetEnumerator();
          }
          catch
          {
              // GetAsyncEnumerator is synchronous, so this cleanup cannot be awaited. The returned
              // task is deliberately not observed (fire-and-forget); the base DisposeAsync below
              // completes synchronously. NOTE(review): an asynchronous override's faults go unobserved.
              enumerator.DisposeAsync();
              throw;
          }

          return enumerator;
      }

      /// <summary>
      /// Cancels and disposes the per-enumeration <see cref="CancellationTokenSource"/>, clears the
      /// current element, marks <see cref="Current"/> as unreadable and moves the iterator to
      /// <see cref="AsyncIteratorState.Disposed"/>. Repeated calls to this base implementation are harmless.
      /// </summary>
      /// <returns>An already-completed task.</returns>
      public virtual Task DisposeAsync()
      {
          var cts = cancellationTokenSource;
          if (cts != null)
          {
              // Detach first so a repeated DisposeAsync never touches an already-disposed source.
              cancellationTokenSource = null;

              try
              {
                  if (!cts.IsCancellationRequested)
                  {
                      cts.Cancel();
                  }
              }
              finally
              {
                  // Release the source even if a cancellation callback throws.
                  cts.Dispose();
              }
          }

          current = default(TSource);

          // After disposal Current must throw rather than silently return default(TSource).
          currentIsInvalid = true;
          state = AsyncIteratorState.Disposed;

          return TaskExt.CompletedTask;
      }

      /// <summary>Gets the element at the current position of the enumerator.</summary>
      /// <exception cref="InvalidOperationException">
      /// Thrown before the first successful <see cref="MoveNextAsync"/>, after it returned false or
      /// threw, and after <see cref="DisposeAsync"/>.
      /// </exception>
      public TSource Current
      {
          get
          {
              if (currentIsInvalid)
              {
                  throw new InvalidOperationException("Enumerator is in an invalid state");
              }

              return current;
          }
      }

      /// <summary>
      /// Advances the enumerator by delegating to <see cref="MoveNextCore"/>. A disposed iterator
      /// returns false. If <see cref="MoveNextCore"/> throws, the iterator is disposed and the
      /// exception is rethrown.
      /// </summary>
      /// <returns>true if <see cref="Current"/> now holds an element; otherwise false.</returns>
      public async Task<bool> MoveNextAsync()
      {
          // Note: MoveNext *must* be implemented as an async method to ensure
          // that any exceptions thrown from the MoveNextCore call are handled
          // by the try/catch, whether they're sync or async
          if (state == AsyncIteratorState.Disposed)
          {
              // No element is available once disposed.
              currentIsInvalid = true;
              return false;
          }

          try
          {
              var result = await MoveNextCore().ConfigureAwait(false);

              // A false result means there is no current element.
              currentIsInvalid = !result;
              return result;
          }
          catch
          {
              currentIsInvalid = true;
              await DisposeAsync().ConfigureAwait(false);
              throw;
          }
      }

      /// <summary>
      /// Produces a new iterator over the same sequence. Used by <see cref="GetAsyncEnumerator"/>
      /// when this instance cannot be reused as the enumerator.
      /// </summary>
      public abstract AsyncIterator<TSource> Clone();

      /// <summary>Advances the iteration; implemented by derived iterators.</summary>
      /// <returns>true when a new element is available through <see cref="Current"/>; false at the end.</returns>
      protected abstract Task<bool> MoveNextCore();

      /// <summary>
      /// Hook invoked by <see cref="GetAsyncEnumerator"/> on the enumerator being returned, after its
      /// state and cancellation source have been set up. The base implementation does nothing.
      /// </summary>
      protected virtual void OnGetEnumerator()
      {
      }
  }
  /// <summary>
  /// Lifecycle states of the async iterator base class. Values are explicit; the default value is <c>New</c>.
  /// </summary>
  internal enum AsyncIteratorState : int
  {
      /// <summary>Terminal state set by disposal; MoveNextAsync returns false once here.</summary>
      Disposed = -1,

      /// <summary>Freshly constructed and never handed out; only this state allows the instance to be reused as its own enumerator.</summary>
      New = 0,

      /// <summary>Handed out as an enumerator by GetAsyncEnumerator.</summary>
      Allocated = 1,

      /// <summary>Not assigned by the base class; presumably set by derived iterators once iteration has begun.</summary>
      Iterating = 2,
  }
  96. }