// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Returns elements from an async-enumerable sequence as long as a specified condition is true.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">A sequence to return elements from.</param>
        /// <param name="predicate">A function to test each element for a condition.</param>
        /// <returns>An async-enumerable sequence that contains the elements from the input sequence that occur before the element at which the test no longer passes.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Create(Core);

            // The enumeration's cancellation token is forwarded to the source via WithCancellation.
            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    // Stop at the first element that fails the test; that element is not yielded.
                    if (!predicate(element))
                    {
                        break;
                    }

                    yield return element;
                }
            }
        }

        /// <summary>
        /// Returns elements from an async-enumerable sequence as long as a specified condition is true.
        /// The element's index is used in the logic of the predicate function.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">A sequence to return elements from.</param>
        /// <param name="predicate">A function to test each element for a condition; the second parameter of the function represents the index of the source element.</param>
        /// <returns>An async-enumerable sequence that contains the elements from the input sequence that occur before the element at which the test no longer passes.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        /// <remarks>
        /// Enumeration throws <see cref="OverflowException"/> if the element index would exceed <see cref="int.MaxValue"/>.
        /// </remarks>
        public static IAsyncEnumerable<TSource> TakeWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                // Start at -1 so the pre-increment below gives the first element index 0.
                var index = -1;

                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    // checked: overflow throws instead of silently wrapping to a negative index.
                    checked
                    {
                        index++;
                    }

                    if (!predicate(element, index))
                    {
                        break;
                    }

                    yield return element;
                }
            }
        }

        /// <summary>
        /// Returns elements from an async-enumerable sequence as long as an asynchronous predicate
        /// returns <c>true</c>; enumeration stops at the first element for which it returns <c>false</c>.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">A sequence to return elements from.</param>
        /// <param name="predicate">An asynchronous function to test each element for a condition.</param>
        /// <returns>The elements that occur before the first element failing the test.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        internal static IAsyncEnumerable<TSource> TakeWhileAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<bool>> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    if (!await predicate(element).ConfigureAwait(false))
                    {
                        break;
                    }

                    yield return element;
                }
            }
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Returns elements from an async-enumerable sequence as long as an asynchronous, cancellable
        /// predicate returns <c>true</c>; enumeration stops at the first element for which it returns <c>false</c>.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">A sequence to return elements from.</param>
        /// <param name="predicate">
        /// An asynchronous function to test each element for a condition; it receives the same
        /// cancellation token that was passed to the enumeration.
        /// </param>
        /// <returns>The elements that occur before the first element failing the test.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        internal static IAsyncEnumerable<TSource> TakeWhileAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<bool>> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    if (!await predicate(element, cancellationToken).ConfigureAwait(false))
                    {
                        break;
                    }

                    yield return element;
                }
            }
        }
#endif

        /// <summary>
        /// Returns elements from an async-enumerable sequence as long as an asynchronous, index-aware
        /// predicate returns <c>true</c>; enumeration stops at the first element for which it returns <c>false</c>.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">A sequence to return elements from.</param>
        /// <param name="predicate">An asynchronous function to test each element; the second parameter is the zero-based index of the element.</param>
        /// <returns>The elements that occur before the first element failing the test.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        /// <remarks>
        /// Enumeration throws <see cref="OverflowException"/> if the element index would exceed <see cref="int.MaxValue"/>.
        /// </remarks>
        internal static IAsyncEnumerable<TSource> TakeWhileAwaitCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<bool>> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                // Start at -1 so the pre-increment below gives the first element index 0.
                var index = -1;

                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    // checked: overflow throws instead of silently wrapping to a negative index.
                    checked
                    {
                        index++;
                    }

                    if (!await predicate(element, index).ConfigureAwait(false))
                    {
                        break;
                    }

                    yield return element;
                }
            }
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Returns elements from an async-enumerable sequence as long as an asynchronous, index-aware,
        /// cancellable predicate returns <c>true</c>; enumeration stops at the first element for which it returns <c>false</c>.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">A sequence to return elements from.</param>
        /// <param name="predicate">
        /// An asynchronous function to test each element; it receives the element, its zero-based index,
        /// and the same cancellation token that was passed to the enumeration.
        /// </param>
        /// <returns>The elements that occur before the first element failing the test.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        /// <remarks>
        /// Enumeration throws <see cref="OverflowException"/> if the element index would exceed <see cref="int.MaxValue"/>.
        /// </remarks>
        internal static IAsyncEnumerable<TSource> TakeWhileAwaitWithCancellationCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<bool>> predicate)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (predicate == null)
                throw Error.ArgumentNull(nameof(predicate));

            return Create(Core);

            async IAsyncEnumerator<TSource> Core(CancellationToken cancellationToken)
            {
                // Start at -1 so the pre-increment below gives the first element index 0.
                var index = -1;

                await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
                {
                    // checked: overflow throws instead of silently wrapping to a negative index.
                    checked
                    {
                        index++;
                    }

                    if (!await predicate(element, index, cancellationToken).ConfigureAwait(false))
                    {
                        break;
                    }

                    yield return element;
                }
            }
        }
#endif
    }
}