using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Base class for async iterators. An instance is both the <see cref="IAsyncEnumerable{T}"/> and,
        /// for the first enumeration on the thread that constructed it, the <see cref="IAsyncEnumerator{T}"/>.
        /// Any other enumeration receives an independent copy produced by <see cref="Clone"/>.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements produced by the iterator.</typeparam>
        internal abstract class AsyncIterator<TSource> : IAsyncEnumerable<TSource>, IAsyncEnumerator<TSource>
        {
            /// <summary>Lifecycle of an iterator instance.</summary>
            public enum State
            {
                /// <summary>Constructed; <see cref="GetEnumerator"/> has not handed this instance out yet.</summary>
                New = 0,

                /// <summary>Handed out by <see cref="GetEnumerator"/>.</summary>
                Allocated = 1,

                /// <summary>
                /// Not assigned in this class; presumably set by derived <see cref="MoveNextCore"/>
                /// implementations once iteration has started (verify in subclasses).
                /// </summary>
                Iterating = 2,

                /// <summary>Set by <see cref="Dispose"/>; <see cref="MoveNext"/> then returns false.</summary>
                Disposed = -1,
            }

            // Managed thread id of the constructing thread; only that thread may reuse this instance
            // as its own enumerator (see GetEnumerator).
            private readonly int threadId;

            // Internal so derived iterators can read/advance the state and publish the current element.
            internal State state = State.New;
            internal TSource current;

            // Created per enumerator by GetEnumerator and cancelled/disposed by Dispose.
            // Null until GetEnumerator has run on this instance.
            private CancellationTokenSource cancellationTokenSource;

            protected AsyncIterator()
            {
                threadId = Environment.CurrentManagedThreadId;
            }

            /// <summary>
            /// Creates a fresh copy of this iterator, used by <see cref="GetEnumerator"/> when this
            /// instance cannot be reused as the enumerator.
            /// </summary>
            public abstract AsyncIterator<TSource> Clone();

            /// <summary>
            /// Returns an enumerator in the <see cref="State.Allocated"/> state with its own
            /// cancellation source.
            /// </summary>
            /// <returns>
            /// This instance if it is still <see cref="State.New"/> and the call is on the constructing
            /// thread; otherwise the result of <see cref="Clone"/>.
            /// </returns>
            public IAsyncEnumerator<TSource> GetEnumerator()
            {
                // Reusing 'this' saves an allocation for the common single-enumeration case; any other
                // case gets a separate instance so enumerations do not share state.
                var enumerator = state == State.New && threadId == Environment.CurrentManagedThreadId ? this : Clone();
                enumerator.state = State.Allocated;
                enumerator.cancellationTokenSource = new CancellationTokenSource();
                return enumerator;
            }

            /// <summary>
            /// Cancels any in-flight <see cref="MoveNextCore"/> (via this enumerator's cancellation
            /// source), releases that source, clears <see cref="Current"/> and marks the iterator
            /// <see cref="State.Disposed"/>. Safe to call more than once, and safe to call on an
            /// instance that was never returned from <see cref="GetEnumerator"/>.
            /// </summary>
            public virtual void Dispose()
            {
                // The source only exists once GetEnumerator has run; without this guard, disposing a
                // New or freshly cloned instance would throw NullReferenceException.
                var cts = cancellationTokenSource;
                if (cts != null)
                {
                    // Cancel() on an already-disposed source throws; a source that was cancelled (and
                    // possibly disposed) by an earlier Dispose call is skipped here.
                    if (!cts.IsCancellationRequested)
                    {
                        cts.Cancel();
                    }

                    cts.Dispose();
                }

                current = default(TSource);
                state = State.Disposed;
            }

            /// <summary>
            /// The element last published by the derived iterator in <c>current</c>;
            /// reset to <c>default(TSource)</c> by <see cref="Dispose"/>.
            /// </summary>
            public TSource Current => current;

            /// <summary>
            /// Advances the iterator by delegating to <see cref="MoveNextCore"/>.
            /// </summary>
            /// <param name="cancellationToken">
            /// Caller's token. Cancelling it disposes this enumerator, and it is linked with the
            /// enumerator's own source when passed to <see cref="MoveNextCore"/>.
            /// </param>
            /// <returns>
            /// False immediately if already disposed; otherwise the result of <see cref="MoveNextCore"/>.
            /// </returns>
            /// <remarks>
            /// Any exception from <see cref="MoveNextCore"/> disposes the iterator and is rethrown.
            /// NOTE(review): if <paramref name="cancellationToken"/> is already cancelled,
            /// <c>Register</c> runs <see cref="Dispose"/> synchronously before
            /// <see cref="MoveNextCore"/> is invoked, so the derived iterator sees the
            /// <see cref="State.Disposed"/> state. Confirm derived classes handle this as intended.
            /// </remarks>
            public async Task<bool> MoveNext(CancellationToken cancellationToken)
            {
                if (state == State.Disposed)
                {
                    return false;
                }

                // The linked token fires when either the caller cancels or this enumerator is disposed.
                using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, cancellationTokenSource.Token))
                using (cancellationToken.Register(Dispose))
                {
                    try
                    {
                        // Awaited inside this async method so both synchronous throws and faulted tasks
                        // from MoveNextCore reach the catch block below.
                        var result = await MoveNextCore(cts.Token).ConfigureAwait(false);
                        return result;
                    }
                    catch
                    {
                        // A failed step ends the enumeration; release resources before propagating.
                        Dispose();
                        throw;
                    }
                }
            }

            /// <summary>Operator-specific advance logic implemented by derived iterators.</summary>
            /// <param name="cancellationToken">Linked token supplied by <see cref="MoveNext"/>.</param>
            /// <returns>True if an element was produced, false at the end of the sequence.</returns>
            protected abstract Task<bool> MoveNextCore(CancellationToken cancellationToken);

            /// <summary>
            /// Projects each element with <paramref name="selector"/>. Virtual so derived iterators can
            /// supply a specialized implementation; the default wraps this iterator.
            /// </summary>
            public virtual IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
            {
                return new SelectEnumerableAsyncIterator<TSource, TResult>(this, selector);
            }

            /// <summary>
            /// Filters elements with <paramref name="predicate"/>. Virtual so derived iterators can
            /// supply a specialized implementation; the default wraps this iterator.
            /// </summary>
            public virtual IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate)
            {
                return new WhereEnumerableAsyncIterator<TSource>(this, predicate);
            }
        }
    }
}