// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
#if INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
        // https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.where?view=net-9.0-pp#system-linq-asyncenumerable-where-1(system-collections-generic-iasyncenumerable((-0))-system-func((-0-system-boolean)))

        /// <summary>
        /// Filters the elements of an async-enumerable sequence based on a predicate.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
        /// <param name="predicate">A function to test each source element for a condition.</param>
        /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            // Library iterators get a chance to fuse the filter into themselves (for example,
            // WhereEnumerableAsyncIterator combines consecutive predicates) instead of being wrapped.
            if (source is AsyncIteratorBase<TSource> iterator)
            {
                return iterator.Where(predicate);
            }

            // TODO: Can we add array/list optimizations here, does it make sense?

            return new WhereEnumerableAsyncIterator<TSource>(source, predicate);
        }

        // https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.where?view=net-9.0-pp#system-linq-asyncenumerable-where-1(system-collections-generic-iasyncenumerable((-0))-system-func((-0-system-int32-system-boolean)))

        /// <summary>
        /// Filters the elements of an async-enumerable sequence based on a predicate by incorporating the element's index.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
        /// <param name="predicate">A function to test each source element for a condition; the second parameter of the function represents the index of the source element.</param>
        /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            // Validate eagerly here; the body of the async iterator Core only runs once enumeration starts.
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Core(source, predicate);

            static async IAsyncEnumerable<TSource> Core(IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
            {
                // Starts at -1 so the first element gets index 0. The increment is checked so that a
                // source with more than int.MaxValue elements throws OverflowException instead of
                // handing the predicate a wrapped-around negative index.
                var index = -1;

                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    checked
                    {
                        index++;
                    }

                    if (predicate(element, index))
                    {
                        yield return element;
                    }
                }
            }
        }
#endif // INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES

        /// <summary>
        /// Filters the elements of an async-enumerable sequence based on an asynchronous predicate.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
        /// <param name="predicate">An asynchronous predicate to test each source element for a condition.</param>
        /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        [GenerateAsyncOverload]
        [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwait functionality now exists as overloads of Where. You will need to modify your callback to take an additional CancellationToken argument.")]
        private static IAsyncEnumerable<TSource> WhereAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            // Let library iterators fuse the asynchronous filter instead of wrapping them.
            if (source is AsyncIteratorBase<TSource> iterator)
            {
                return iterator.Where(predicate);
            }

            // TODO: Can we add array/list optimizations here, does it make sense?

            return new WhereEnumerableAsyncIteratorWithTask<TSource>(source, predicate);
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Filters the elements of an async-enumerable sequence based on an asynchronous predicate that
        /// also receives a cancellation token.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
        /// <param name="predicate">
        /// An asynchronous predicate to test each source element for a condition; the second parameter
        /// receives the same cancellation token that is passed to the enumerator of <paramref name="source"/>.
        /// </param>
        /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        [GenerateAsyncOverload]
        [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwaitWithCancellation functionality now exists as overloads of Where.")]
        private static IAsyncEnumerable<TSource> WhereAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            // Let library iterators fuse the asynchronous filter instead of wrapping them.
            if (source is AsyncIteratorBase<TSource> iterator)
            {
                return iterator.Where(predicate);
            }

            // TODO: Can we add array/list optimizations here, does it make sense?

            return new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(source, predicate);
        }
#endif

        /// <summary>
        /// Filters the elements of an async-enumerable sequence based on an asynchronous predicate that incorporates the element's index.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
        /// <param name="predicate">An asynchronous predicate to test each source element for a condition; the second parameter of the function represents the index of the source element.</param>
        /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        [GenerateAsyncOverload]
        [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwait functionality now exists as overloads of Where. You will need to modify your callback to take an additional CancellationToken argument.")]
        private static IAsyncEnumerable<TSource> WhereAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<bool>> predicate)
        {
            // Validate eagerly here; the body of the async iterator Core only runs once enumeration starts.
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Core(source, predicate);

            static async IAsyncEnumerable<TSource> Core(IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<bool>> predicate, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
            {
                // Starts at -1 so the first element gets index 0; checked so an overlong source
                // throws OverflowException rather than producing a negative index.
                var index = -1;

                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    checked
                    {
                        index++;
                    }

                    if (await predicate(element, index).ConfigureAwait(false))
                    {
                        yield return element;
                    }
                }
            }
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Filters the elements of an async-enumerable sequence based on an asynchronous predicate that
        /// incorporates the element's index and also receives a cancellation token.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">An async-enumerable sequence whose elements to filter.</param>
        /// <param name="predicate">
        /// An asynchronous predicate to test each source element for a condition; the second parameter
        /// represents the index of the source element, and the third receives the same cancellation token
        /// that is used to enumerate <paramref name="source"/>.
        /// </param>
        /// <returns>An async-enumerable sequence that contains elements from the input sequence that satisfy the condition.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        [GenerateAsyncOverload]
        [Obsolete("Use Where. IAsyncEnumerable LINQ is now in System.Linq.AsyncEnumerable, and the WhereAwaitWithCancellation functionality now exists as overloads of Where.")]
        private static IAsyncEnumerable<TSource> WhereAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<bool>> predicate)
        {
            // Validate eagerly here; the body of the async iterator Core only runs once enumeration starts.
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Core(source, predicate);

            static async IAsyncEnumerable<TSource> Core(IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<bool>> predicate, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
            {
                // Starts at -1 so the first element gets index 0; checked so an overlong source
                // throws OverflowException rather than producing a negative index.
                var index = -1;

                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    checked
                    {
                        index++;
                    }

                    if (await predicate(element, index, cancellationToken).ConfigureAwait(false))
                    {
                        yield return element;
                    }
                }
            }
        }
#endif

        /// <summary>
        /// Iterator that yields the source elements for which a synchronous predicate returns <c>true</c>.
        /// </summary>
        internal sealed class WhereEnumerableAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            private readonly Func<TSource, bool> _predicate;
            private readonly IAsyncEnumerable<TSource> _source;
            private IAsyncEnumerator<TSource>? _enumerator;

            public WhereEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
            {
                _source = source;
                _predicate = predicate;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new WhereEnumerableAsyncIterator<TSource>(_source, _predicate);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            // Where(...).Select(...) becomes a single iterator that filters and projects in one pass.
            public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
            {
                return new WhereSelectEnumerableAsyncIterator<TSource, TResult>(_source, _predicate, selector);
            }

            // Where(...).Where(...) becomes a single iterator over the original source with combined predicates.
            public override IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate)
            {
                return new WhereEnumerableAsyncIterator<TSource>(_source, CombinePredicates(_predicate, predicate));
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        // First call: obtain the source enumerator, then fall through to iteration.
                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                        _state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            var item = _enumerator.Current;
                            if (_predicate(item))
                            {
                                _current = item;
                                return true;
                            }
                        }

                        // Source exhausted: release the source enumerator right away.
                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }

        /// <summary>
        /// Iterator that yields the source elements for which an asynchronous predicate returns <c>true</c>.
        /// </summary>
        internal sealed class WhereEnumerableAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
        {
            private readonly Func<TSource, ValueTask<bool>> _predicate;
            private readonly IAsyncEnumerable<TSource> _source;
            private IAsyncEnumerator<TSource>? _enumerator;

            public WhereEnumerableAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
            {
                _source = source;
                _predicate = predicate;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new WhereEnumerableAsyncIteratorWithTask<TSource>(_source, _predicate);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            // Consecutive asynchronous filters are combined into a single iterator over the original source.
            public override IAsyncEnumerable<TSource> Where(Func<TSource, ValueTask<bool>> predicate)
            {
                return new WhereEnumerableAsyncIteratorWithTask<TSource>(_source, CombinePredicates(_predicate, predicate));
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        // First call: obtain the source enumerator, then fall through to iteration.
                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                        _state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            var item = _enumerator.Current;
                            if (await _predicate(item).ConfigureAwait(false))
                            {
                                _current = item;
                                return true;
                            }
                        }

                        // Source exhausted: release the source enumerator right away.
                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Iterator that yields the source elements for which an asynchronous, cancellation-aware predicate
        /// returns <c>true</c>. The predicate receives the same token used to enumerate the source.
        /// </summary>
        internal sealed class WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource> : AsyncIterator<TSource>
        {
            private readonly Func<TSource, CancellationToken, ValueTask<bool>> _predicate;
            private readonly IAsyncEnumerable<TSource> _source;
            private IAsyncEnumerator<TSource>? _enumerator;

            public WhereEnumerableAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
            {
                _source = source;
                _predicate = predicate;
            }

            public override AsyncIteratorBase<TSource> Clone()
            {
                return new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(_source, _predicate);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            // Consecutive cancellation-aware filters are combined into a single iterator over the original source.
            public override IAsyncEnumerable<TSource> Where(Func<TSource, CancellationToken, ValueTask<bool>> predicate)
            {
                return new WhereEnumerableAsyncIteratorWithTaskAndCancellation<TSource>(_source, CombinePredicates(_predicate, predicate));
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        // First call: obtain the source enumerator, then fall through to iteration.
                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                        _state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            var item = _enumerator.Current;
                            if (await _predicate(item, _cancellationToken).ConfigureAwait(false))
                            {
                                _current = item;
                                return true;
                            }
                        }

                        // Source exhausted: release the source enumerator right away.
                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }
#endif

        /// <summary>
        /// Fused Where+Select iterator: yields <c>selector(item)</c> for each source item that satisfies
        /// the predicate.
        /// </summary>
        private sealed class WhereSelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
        {
            private readonly Func<TSource, bool> _predicate;
            private readonly Func<TSource, TResult> _selector;
            private readonly IAsyncEnumerable<TSource> _source;
            private IAsyncEnumerator<TSource>? _enumerator;

            public WhereSelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, Func<TSource, TResult> selector)
            {
                _source = source;
                _predicate = predicate;
                _selector = selector;
            }

            public override AsyncIteratorBase<TResult> Clone()
            {
                return new WhereSelectEnumerableAsyncIterator<TSource, TResult>(_source, _predicate, _selector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_enumerator != null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            // A further Select is composed with the existing selector rather than adding another iterator layer.
            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
            {
                return new WhereSelectEnumerableAsyncIterator<TSource, TResult1>(_source, _predicate, CombineSelectors(_selector, selector));
            }

            protected override async ValueTask<bool> MoveNextCore()
            {
                switch (_state)
                {
                    case AsyncIteratorState.Allocated:
                        // First call: obtain the source enumerator, then fall through to iteration.
                        _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                        _state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            var item = _enumerator.Current;
                            if (_predicate(item))
                            {
                                _current = _selector(item);
                                return true;
                            }
                        }

                        // Source exhausted: release the source enumerator right away.
                        await DisposeAsync().ConfigureAwait(false);
                        break;
                }

                return false;
            }
        }
    }
}