// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Bypasses the first <paramref name="count"/> elements of <paramref name="source"/> and
        /// yields the remaining elements.
        /// </summary>
        /// <typeparam name="TSource">The element type of the sequence.</typeparam>
        /// <param name="source">The sequence to skip elements from.</param>
        /// <param name="count">
        /// The number of leading elements to skip. Values less than or equal to zero skip nothing.
        /// </param>
        /// <returns>A sequence with the elements of <paramref name="source"/> after the first <paramref name="count"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        public static IAsyncEnumerable<TSource> Skip<TSource>(this IAsyncEnumerable<TSource> source, int count)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));

            if (count <= 0)
            {
                // Return source if not actually skipping, but only if it's a type from here, to avoid
                // issues if collections are used as keys or otherwise must not be aliased.
                if (source is AsyncIterator<TSource>)
                {
                    return source;
                }

                count = 0;
            }

            return new SkipAsyncIterator<TSource>(source, count);
        }

        /// <summary>
        /// Yields every element of <paramref name="source"/> except the last <paramref name="count"/> elements.
        /// </summary>
        /// <typeparam name="TSource">The element type of the sequence.</typeparam>
        /// <param name="source">The sequence to take elements from.</param>
        /// <param name="count">
        /// The number of trailing elements to omit. Values less than or equal to zero omit nothing.
        /// </param>
        /// <returns>A sequence with all but the last <paramref name="count"/> elements of <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        public static IAsyncEnumerable<TSource> SkipLast<TSource>(this IAsyncEnumerable<TSource> source, int count)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));

            if (count <= 0)
            {
                // Return source if not actually skipping, but only if it's a type from here, to avoid
                // issues if collections are used as keys or otherwise must not be aliased.
                if (source is AsyncIterator<TSource>)
                {
                    return source;
                }

                count = 0;
            }

            return new SkipLastAsyncIterator<TSource>(source, count);
        }

        /// <summary>
        /// Skips elements of <paramref name="source"/> while <paramref name="predicate"/> returns <c>true</c>.
        /// Then yields the first element for which it returned <c>false</c>, followed by every later
        /// element. The predicate is not evaluated again after that point.
        /// </summary>
        /// <typeparam name="TSource">The element type of the sequence.</typeparam>
        /// <param name="source">The sequence to skip elements from.</param>
        /// <param name="predicate">Test applied to each leading element.</param>
        /// <returns>The elements of <paramref name="source"/> starting at the first one that fails <paramref name="predicate"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/> or <paramref name="predicate"/> is <c>null</c>.
        /// </exception>
        public static IAsyncEnumerable<TSource> SkipWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (predicate == null)
                throw new ArgumentNullException(nameof(predicate));

            return new SkipWhileAsyncIterator<TSource>(source, predicate);
        }

        /// <summary>
        /// Skips elements of <paramref name="source"/> while <paramref name="predicate"/> returns <c>true</c>.
        /// The predicate receives each element together with its zero-based position in
        /// <paramref name="source"/>. Then yields the first element for which it returned <c>false</c>,
        /// followed by every later element.
        /// </summary>
        /// <typeparam name="TSource">The element type of the sequence.</typeparam>
        /// <param name="source">The sequence to skip elements from.</param>
        /// <param name="predicate">Test applied to each leading element and its index.</param>
        /// <returns>The elements of <paramref name="source"/> starting at the first one that fails <paramref name="predicate"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/> or <paramref name="predicate"/> is <c>null</c>.
        /// </exception>
        public static IAsyncEnumerable<TSource> SkipWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (predicate == null)
                throw new ArgumentNullException(nameof(predicate));

            return new SkipWhileWithIndexAsyncIterator<TSource>(source, predicate);
        }

        // Iterator behind Skip.
        // On the first MoveNextCore call it advances the source enumerator up to `count` times.
        // After that it forwards each remaining source element.
        // `state` and `current` are inherited from AsyncIterator<TSource>, which is declared elsewhere.
        private sealed class SkipAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            private readonly int count;
            private readonly IAsyncEnumerable<TSource> source;

            // Number of elements still to be skipped; decremented only in the Allocated phase.
            private int currentCount;
            private IAsyncEnumerator<TSource> enumerator;

            public SkipAsyncIterator(IAsyncEnumerable<TSource> source, int count)
            {
                Debug.Assert(source != null);

                this.source = source;
                this.count = count;
                currentCount = count;
            }

            // Builds a fresh iterator over the same source and count; enumeration progress is not copied.
            public override AsyncIterator<TSource> Clone()
            {
                return new SkipAsyncIterator<TSource>(source, count);
            }

            public override void Dispose()
            {
                if (enumerator != null)
                {
                    enumerator.Dispose();
                    enumerator = null;
                }

                base.Dispose();
            }

            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        enumerator = source.GetEnumerator();

                        // skip elements as requested
                        while (currentCount > 0 && await enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
                        {
                            currentCount--;
                        }

                        // If the source ended before all requested elements were skipped,
                        // fall out of the switch and report an empty result.
                        if (currentCount <= 0)
                        {
                            state = AsyncIteratorState.Iterating;
                            goto case AsyncIteratorState.Iterating;
                        }

                        break;

                    case AsyncIteratorState.Iterating:
                        if (await enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
                        {
                            current = enumerator.Current;
                            return true;
                        }

                        break;
                }

                // Source exhausted: release the source enumerator and signal completion.
                Dispose();
                return false;
            }
        }

        // Iterator behind SkipLast.
        // Elements are buffered in a FIFO queue. An element is yielded only once more than `count`
        // newer elements have arrived behind it. Therefore the last `count` elements are never emitted.
        // The queue holds at most count + 1 items.
        private sealed class SkipLastAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            private readonly int count;
            private readonly IAsyncEnumerable<TSource> source;

            private IAsyncEnumerator<TSource> enumerator;
            private Queue<TSource> queue;

            public SkipLastAsyncIterator(IAsyncEnumerable<TSource> source, int count)
            {
                Debug.Assert(source != null);

                this.source = source;
                this.count = count;
            }

            // Builds a fresh iterator over the same source and count; buffered items are not copied.
            public override AsyncIterator<TSource> Clone()
            {
                return new SkipLastAsyncIterator<TSource>(source, count);
            }

            public override void Dispose()
            {
                if (enumerator != null)
                {
                    enumerator.Dispose();
                    enumerator = null;
                }

                queue = null; // release the memory

                base.Dispose();
            }

            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        enumerator = source.GetEnumerator();
                        queue = new Queue<TSource>();

                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        while (await enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
                        {
                            var item = enumerator.Current;
                            queue.Enqueue(item);

                            // Once the buffer holds more than `count` items, the oldest one cannot be
                            // among the last `count` elements, so it is safe to emit.
                            if (queue.Count > count)
                            {
                                current = queue.Dequeue();
                                return true;
                            }
                        }

                        break;
                }

                // Whatever remains in the queue is the trailing `count` elements; drop it.
                Dispose();
                return false;
            }
        }

        // Iterator behind SkipWhile(source, Func<TSource, bool>).
        private sealed class SkipWhileAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            private readonly Func<TSource, bool> predicate;
            private readonly IAsyncEnumerable<TSource> source;

            // Set to false when the skip phase stops on an element that fails the predicate.
            // That element is already enumerator.Current, so the next yield must not advance the
            // enumerator. The flag is set to true after that first yield.
            private bool doMoveNext;
            private IAsyncEnumerator<TSource> enumerator;

            public SkipWhileAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
            {
                Debug.Assert(predicate != null);
                Debug.Assert(source != null);

                this.source = source;
                this.predicate = predicate;
            }

            // Builds a fresh iterator over the same source and predicate; enumeration progress is not copied.
            public override AsyncIterator<TSource> Clone()
            {
                return new SkipWhileAsyncIterator<TSource>(source, predicate);
            }

            public override void Dispose()
            {
                if (enumerator != null)
                {
                    enumerator.Dispose();
                    enumerator = null;
                }

                base.Dispose();
            }

            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        enumerator = source.GetEnumerator();

                        // skip elements as requested
                        while (await enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
                        {
                            var element = enumerator.Current;
                            if (!predicate(element))
                            {
                                doMoveNext = false;
                                state = AsyncIteratorState.Iterating;
                                goto case AsyncIteratorState.Iterating;
                            }
                        }

                        // Every element satisfied the predicate: nothing to yield.
                        break;

                    case AsyncIteratorState.Iterating:
                        if (doMoveNext && await enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
                        {
                            current = enumerator.Current;
                            return true;
                        }

                        // First yield after the skip phase: emit the element that stopped it.
                        if (!doMoveNext)
                        {
                            current = enumerator.Current;
                            doMoveNext = true;
                            return true;
                        }

                        break;
                }

                Dispose();
                return false;
            }
        }

        // Iterator behind SkipWhile(source, Func<TSource, int, bool>).
        // It is identical to SkipWhileAsyncIterator except that the predicate also receives the
        // element's zero-based index.
        private sealed class SkipWhileWithIndexAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            private readonly Func<TSource, int, bool> predicate;
            private readonly IAsyncEnumerable<TSource> source;

            // See SkipWhileAsyncIterator: false means "yield enumerator.Current without advancing".
            private bool doMoveNext;
            private IAsyncEnumerator<TSource> enumerator;

            // Index of the element most recently passed to the predicate; starts at -1 so the first is 0.
            private int index;

            public SkipWhileWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
            {
                Debug.Assert(predicate != null);
                Debug.Assert(source != null);

                this.source = source;
                this.predicate = predicate;
            }

            // Builds a fresh iterator over the same source and predicate; enumeration progress is not copied.
            public override AsyncIterator<TSource> Clone()
            {
                return new SkipWhileWithIndexAsyncIterator<TSource>(source, predicate);
            }

            public override void Dispose()
            {
                if (enumerator != null)
                {
                    enumerator.Dispose();
                    enumerator = null;
                }

                base.Dispose();
            }

            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        enumerator = source.GetEnumerator();
                        index = -1;

                        // skip elements as requested
                        while (await enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
                        {
                            // checked: an index past int.MaxValue throws OverflowException instead of wrapping.
                            checked
                            {
                                index++;
                            }

                            var element = enumerator.Current;
                            if (!predicate(element, index))
                            {
                                doMoveNext = false;
                                state = AsyncIteratorState.Iterating;
                                goto case AsyncIteratorState.Iterating;
                            }
                        }

                        // Every element satisfied the predicate: nothing to yield.
                        break;

                    case AsyncIteratorState.Iterating:
                        if (doMoveNext && await enumerator.MoveNext(cancellationToken).ConfigureAwait(false))
                        {
                            current = enumerator.Current;
                            return true;
                        }

                        // First yield after the skip phase: emit the element that stopped it.
                        if (!doMoveNext)
                        {
                            current = enumerator.Current;
                            doMoveNext = true;
                            return true;
                        }

                        break;
                }

                Dispose();
                return false;
            }
        }
    }
}