// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
public static partial class AsyncEnumerable
{
///
/// Filters the elements of an async-enumerable sequence based on a predicate.
///
/// The type of the elements in the source sequence.
/// An async-enumerable sequence whose elements to filter.
/// A function to test each source element for a condition.
/// An async-enumerable sequence that contains the elements of the input sequence that satisfy the condition.
/// Thrown (via Error.ArgumentNull) at call time when source or predicate is null.
public static IAsyncEnumerable Where(this IAsyncEnumerable source, Func predicate)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (predicate is null)
    {
        throw Error.ArgumentNull(nameof(predicate));
    }

    // If the source is already one of this library's iterators, let it build the filtered
    // sequence itself so operators can be fused (e.g. WhereEnumerableAsyncIterator.Where
    // merges both predicates into one iterator instead of stacking a second one).
    // TODO: Can we add array/list optimizations here, does it make sense?
    return source is AsyncIteratorBase iterator
        ? iterator.Where(predicate)
        : new WhereEnumerableAsyncIterator(source, predicate);
}
///
/// Filters the elements of an async-enumerable sequence based on a predicate that also receives each element's zero-based index.
///
/// The type of the elements in the source sequence.
/// An async-enumerable sequence whose elements to filter.
/// A function to test each source element for a condition; its second argument is the zero-based index of the source element.
/// An async-enumerable sequence that contains the elements of the input sequence that satisfy the condition.
/// Thrown (via Error.ArgumentNull) at call time, not on enumeration, when source or predicate is null.
public static IAsyncEnumerable Where(this IAsyncEnumerable source, Func predicate)
{
    // Validation happens here, outside the async iterator, so it runs eagerly.
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (predicate is null)
    {
        throw Error.ArgumentNull(nameof(predicate));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator Core(CancellationToken cancellationToken)
#endif
    {
        // Starts at -1 so that the first element observed gets index 0.
        var position = -1;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // Overflow past int.MaxValue throws instead of wrapping to a negative index.
            position = checked(position + 1);

            if (!predicate(item, position))
            {
                continue;
            }

            yield return item;
        }
    }
}
///
/// Filters an async-enumerable sequence using an asynchronous predicate whose result is awaited for each element.
///
/// An async-enumerable sequence whose elements to filter.
/// An asynchronous function to test each source element for a condition.
/// An async-enumerable sequence containing the source elements accepted by the predicate.
/// Thrown (via Error.ArgumentNull) at call time when source or predicate is null.
internal static IAsyncEnumerable WhereAwaitCore(this IAsyncEnumerable source, Func> predicate)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (predicate is null)
    {
        throw Error.ArgumentNull(nameof(predicate));
    }

    // Existing library iterators get the chance to fuse this filter into themselves
    // (WhereEnumerableAsyncIteratorWithTask.Where combines the predicates).
    // TODO: Can we add array/list optimizations here, does it make sense?
    return source is AsyncIteratorBase iterator
        ? iterator.Where(predicate)
        : new WhereEnumerableAsyncIteratorWithTask(source, predicate);
}
#if !NO_DEEP_CANCELLATION
///
/// Filters an async-enumerable sequence using an asynchronous, cancellable predicate.
/// The predicate is awaited for each element and receives the enumeration's cancellation token.
///
/// An async-enumerable sequence whose elements to filter.
/// An asynchronous function, taking the element and a cancellation token, to test each source element for a condition.
/// An async-enumerable sequence containing the source elements accepted by the predicate.
/// Thrown (via Error.ArgumentNull) at call time when source or predicate is null.
internal static IAsyncEnumerable WhereAwaitWithCancellationCore(this IAsyncEnumerable source, Func> predicate)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (predicate is null)
    {
        throw Error.ArgumentNull(nameof(predicate));
    }

    // Existing library iterators get the chance to fuse this filter into themselves
    // (WhereEnumerableAsyncIteratorWithTaskAndCancellation.Where combines the predicates).
    // TODO: Can we add array/list optimizations here, does it make sense?
    return source is AsyncIteratorBase iterator
        ? iterator.Where(predicate)
        : new WhereEnumerableAsyncIteratorWithTaskAndCancellation(source, predicate);
}
#endif
///
/// Filters an async-enumerable sequence using an asynchronous predicate that also receives each element's zero-based index.
///
/// An async-enumerable sequence whose elements to filter.
/// An asynchronous function to test each source element; its second argument is the zero-based index of the source element.
/// An async-enumerable sequence containing the source elements accepted by the predicate.
/// Thrown (via Error.ArgumentNull) at call time, not on enumeration, when source or predicate is null.
internal static IAsyncEnumerable WhereAwaitCore(this IAsyncEnumerable source, Func> predicate)
{
    // Validation happens here, outside the async iterator, so it runs eagerly.
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (predicate is null)
    {
        throw Error.ArgumentNull(nameof(predicate));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator Core(CancellationToken cancellationToken)
#endif
    {
        // Starts at -1 so that the first element observed gets index 0.
        var position = -1;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // Overflow past int.MaxValue throws instead of wrapping to a negative index.
            position = checked(position + 1);

            var keep = await predicate(item, position).ConfigureAwait(false);
            if (keep)
            {
                yield return item;
            }
        }
    }
}
#if !NO_DEEP_CANCELLATION
///
/// Filters an async-enumerable sequence using an asynchronous, cancellable predicate that also receives each element's zero-based index.
///
/// An async-enumerable sequence whose elements to filter.
/// An asynchronous function taking the element, its zero-based index and the enumeration's cancellation token.
/// An async-enumerable sequence containing the source elements accepted by the predicate.
/// Thrown (via Error.ArgumentNull) at call time, not on enumeration, when source or predicate is null.
internal static IAsyncEnumerable WhereAwaitWithCancellationCore(this IAsyncEnumerable source, Func> predicate)
{
    // Validation happens here, outside the async iterator, so it runs eagerly.
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (predicate is null)
    {
        throw Error.ArgumentNull(nameof(predicate));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator Core(CancellationToken cancellationToken)
#endif
    {
        // Starts at -1 so that the first element observed gets index 0.
        var position = -1;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // Overflow past int.MaxValue throws instead of wrapping to a negative index.
            position = checked(position + 1);

            // The same token that drives the enumeration is handed to the predicate.
            var keep = await predicate(item, position, cancellationToken).ConfigureAwait(false);
            if (keep)
            {
                yield return item;
            }
        }
    }
}
#endif
/// Iterator that yields the elements of a source sequence satisfying a synchronous predicate.
/// It fuses with a following Where (predicates combined) or Select (single where-select iterator).
internal sealed class WhereEnumerableAsyncIterator : AsyncIterator
{
    // Upstream sequence being filtered.
    private readonly IAsyncEnumerable _source;

    // Filter applied to every upstream element.
    private readonly Func _predicate;

    // Upstream enumerator: created on the first MoveNextCore call, released by DisposeAsync.
    private IAsyncEnumerator? _enumerator;

    public WhereEnumerableAsyncIterator(IAsyncEnumerable source, Func predicate)
    {
        _source = source;
        _predicate = predicate;
    }

    /// Returns a new iterator instance over the same source and predicate.
    public override AsyncIteratorBase Clone() => new WhereEnumerableAsyncIterator(_source, _predicate);

    /// Disposes the upstream enumerator (if one was created), then the base iterator.
    public override async ValueTask DisposeAsync()
    {
        if (_enumerator is { } upstream)
        {
            await upstream.DisposeAsync().ConfigureAwait(false);
            _enumerator = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    /// Fuses a following Select with this filter into one where-select iterator.
    public override IAsyncEnumerable Select(Func selector) =>
        new WhereSelectEnumerableAsyncIterator(_source, _predicate, selector);

    /// Fuses a following Where by combining both predicates into a single iterator.
    public override IAsyncEnumerable Where(Func predicate) =>
        new WhereEnumerableAsyncIterator(_source, CombinePredicates(_predicate, predicate));

    /// Advances to the next upstream element that satisfies the predicate.
    /// Returns false (after disposing) once the upstream sequence is exhausted.
    protected override async ValueTask MoveNextCore()
    {
        if (_state == AsyncIteratorState.Allocated)
        {
            // First call: start enumerating the upstream sequence.
            _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
            _state = AsyncIteratorState.Iterating;
        }

        if (_state != AsyncIteratorState.Iterating)
        {
            return false;
        }

        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
        {
            var candidate = _enumerator.Current;

            if (_predicate(candidate))
            {
                _current = candidate;
                return true;
            }
        }

        // Upstream exhausted: release resources right away.
        await DisposeAsync().ConfigureAwait(false);
        return false;
    }
}
/// Iterator that yields the elements of a source sequence for which an asynchronous predicate,
/// awaited per element, returns true. A following asynchronous Where is fused by combining predicates.
internal sealed class WhereEnumerableAsyncIteratorWithTask : AsyncIterator
{
    // Upstream sequence being filtered.
    private readonly IAsyncEnumerable _source;

    // Asynchronous filter awaited for every upstream element.
    private readonly Func> _predicate;

    // Upstream enumerator: created on the first MoveNextCore call, released by DisposeAsync.
    private IAsyncEnumerator? _enumerator;

    public WhereEnumerableAsyncIteratorWithTask(IAsyncEnumerable source, Func> predicate)
    {
        _source = source;
        _predicate = predicate;
    }

    /// Returns a new iterator instance over the same source and predicate.
    public override AsyncIteratorBase Clone() => new WhereEnumerableAsyncIteratorWithTask(_source, _predicate);

    /// Disposes the upstream enumerator (if one was created), then the base iterator.
    public override async ValueTask DisposeAsync()
    {
        if (_enumerator is { } upstream)
        {
            await upstream.DisposeAsync().ConfigureAwait(false);
            _enumerator = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    /// Fuses a following asynchronous Where by combining both predicates into a single iterator.
    public override IAsyncEnumerable Where(Func> predicate) =>
        new WhereEnumerableAsyncIteratorWithTask(_source, CombinePredicates(_predicate, predicate));

    /// Advances to the next upstream element accepted by the awaited predicate.
    /// Returns false (after disposing) once the upstream sequence is exhausted.
    protected override async ValueTask MoveNextCore()
    {
        if (_state == AsyncIteratorState.Allocated)
        {
            // First call: start enumerating the upstream sequence.
            _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
            _state = AsyncIteratorState.Iterating;
        }

        if (_state != AsyncIteratorState.Iterating)
        {
            return false;
        }

        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
        {
            var candidate = _enumerator.Current;

            if (await _predicate(candidate).ConfigureAwait(false))
            {
                _current = candidate;
                return true;
            }
        }

        // Upstream exhausted: release resources right away.
        await DisposeAsync().ConfigureAwait(false);
        return false;
    }
}
#if !NO_DEEP_CANCELLATION
/// Iterator that yields the elements of a source sequence for which an asynchronous, cancellable
/// predicate returns true. The predicate receives the iterator's cancellation token.
/// A following cancellable Where is fused by combining predicates.
internal sealed class WhereEnumerableAsyncIteratorWithTaskAndCancellation : AsyncIterator
{
    // Upstream sequence being filtered.
    private readonly IAsyncEnumerable _source;

    // Asynchronous filter, awaited per element and given the enumeration's cancellation token.
    private readonly Func> _predicate;

    // Upstream enumerator: created on the first MoveNextCore call, released by DisposeAsync.
    private IAsyncEnumerator? _enumerator;

    public WhereEnumerableAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable source, Func> predicate)
    {
        _source = source;
        _predicate = predicate;
    }

    /// Returns a new iterator instance over the same source and predicate.
    public override AsyncIteratorBase Clone() => new WhereEnumerableAsyncIteratorWithTaskAndCancellation(_source, _predicate);

    /// Disposes the upstream enumerator (if one was created), then the base iterator.
    public override async ValueTask DisposeAsync()
    {
        if (_enumerator is { } upstream)
        {
            await upstream.DisposeAsync().ConfigureAwait(false);
            _enumerator = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    /// Fuses a following cancellable Where by combining both predicates into a single iterator.
    public override IAsyncEnumerable Where(Func> predicate) =>
        new WhereEnumerableAsyncIteratorWithTaskAndCancellation(_source, CombinePredicates(_predicate, predicate));

    /// Advances to the next upstream element accepted by the awaited predicate.
    /// Returns false (after disposing) once the upstream sequence is exhausted.
    protected override async ValueTask MoveNextCore()
    {
        if (_state == AsyncIteratorState.Allocated)
        {
            // First call: start enumerating the upstream sequence.
            _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
            _state = AsyncIteratorState.Iterating;
        }

        if (_state != AsyncIteratorState.Iterating)
        {
            return false;
        }

        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
        {
            var candidate = _enumerator.Current;

            if (await _predicate(candidate, _cancellationToken).ConfigureAwait(false))
            {
                _current = candidate;
                return true;
            }
        }

        // Upstream exhausted: release resources right away.
        await DisposeAsync().ConfigureAwait(false);
        return false;
    }
}
#endif
/// Fused where-then-select iterator. It yields selector(element) for every source element that
/// satisfies the predicate. A following Select is fused by composing the selectors.
private sealed class WhereSelectEnumerableAsyncIterator : AsyncIterator
{
    // Upstream sequence being filtered and projected.
    private readonly IAsyncEnumerable _source;

    // Filter applied first to every upstream element.
    private readonly Func _predicate;

    // Projection applied only to elements that passed the filter.
    private readonly Func _selector;

    // Upstream enumerator: created on the first MoveNextCore call, released by DisposeAsync.
    private IAsyncEnumerator? _enumerator;

    public WhereSelectEnumerableAsyncIterator(IAsyncEnumerable source, Func predicate, Func selector)
    {
        _source = source;
        _predicate = predicate;
        _selector = selector;
    }

    /// Returns a new iterator instance over the same source, predicate and selector.
    public override AsyncIteratorBase Clone() => new WhereSelectEnumerableAsyncIterator(_source, _predicate, _selector);

    /// Disposes the upstream enumerator (if one was created), then the base iterator.
    public override async ValueTask DisposeAsync()
    {
        if (_enumerator is { } upstream)
        {
            await upstream.DisposeAsync().ConfigureAwait(false);
            _enumerator = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    /// Fuses a following Select by composing it after the existing selector.
    public override IAsyncEnumerable Select(Func selector) =>
        new WhereSelectEnumerableAsyncIterator(_source, _predicate, CombineSelectors(_selector, selector));

    /// Advances to the next upstream element that satisfies the predicate and exposes its projection.
    /// Returns false (after disposing) once the upstream sequence is exhausted.
    protected override async ValueTask MoveNextCore()
    {
        if (_state == AsyncIteratorState.Allocated)
        {
            // First call: start enumerating the upstream sequence.
            _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
            _state = AsyncIteratorState.Iterating;
        }

        if (_state != AsyncIteratorState.Iterating)
        {
            return false;
        }

        while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
        {
            var candidate = _enumerator.Current;

            if (_predicate(candidate))
            {
                _current = _selector(candidate);
                return true;
            }
        }

        // Upstream exhausted: release resources right away.
        await DisposeAsync().ConfigureAwait(false);
        return false;
    }
}
}
}