// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
public static partial class AsyncEnumerable
{
/// <summary>
/// Projects each element of an async-enumerable sequence to an async-enumerable sequence and merges the resulting async-enumerable sequences into one async-enumerable sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TResult">The type of the elements in the projected inner sequences and the elements in the merged result sequence.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="selector">A transform function to apply to each element.</param>
/// <returns>An async-enumerable sequence whose elements are the result of invoking the one-to-many transform function on each element of the input sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
/// <remarks>
/// Arguments are validated eagerly; the actual flattening is performed by a dedicated iterator object
/// that also supports optimized <c>ToArrayAsync</c>/<c>ToListAsync</c>/<c>GetCountAsync</c> paths.
/// </remarks>
public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TResult>> selector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (selector is null)
    {
        throw Error.ArgumentNull(nameof(selector));
    }

    return new SelectManyAsyncIterator<TSource, TResult>(source, selector);
}
// REVIEW: Should we keep these overloads whose selector returns ValueTask<IAsyncEnumerable<TResult>>? One could argue the selector is async twice.
/// <summary>
/// Projects each element of <paramref name="source"/> to an inner async-enumerable sequence obtained by
/// awaiting <paramref name="selector"/>, and merges the inner sequences into one sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TResult">The type of the elements in the inner sequences and in the result.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="selector">An asynchronous transform function producing the inner sequence for each element.</param>
/// <returns>The flattened sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
internal static IAsyncEnumerable<TResult> SelectManyAwaitCore<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<IAsyncEnumerable<TResult>>> selector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (selector is null)
    {
        throw Error.ArgumentNull(nameof(selector));
    }

    return new SelectManyAsyncIteratorWithTask<TSource, TResult>(source, selector);
}
#if !NO_DEEP_CANCELLATION
/// <summary>
/// Projects each element of <paramref name="source"/> to an inner async-enumerable sequence obtained by
/// awaiting <paramref name="selector"/> (which also receives the enumeration's cancellation token), and
/// merges the inner sequences into one sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TResult">The type of the elements in the inner sequences and in the result.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="selector">A cancellable asynchronous transform function producing the inner sequence for each element.</param>
/// <returns>The flattened sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
internal static IAsyncEnumerable<TResult> SelectManyAwaitWithCancellationCore<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<IAsyncEnumerable<TResult>>> selector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (selector is null)
    {
        throw Error.ArgumentNull(nameof(selector));
    }

    return new SelectManyAsyncIteratorWithTaskAndCancellation<TSource, TResult>(source, selector);
}
#endif
/// <summary>
/// Projects each element of an async-enumerable sequence to an async-enumerable sequence by incorporating the element's index and merges the resulting async-enumerable sequences into one async-enumerable sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TResult">The type of the elements in the projected inner sequences and the elements in the merged result sequence.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="selector">A transform function to apply to each element; the second parameter of the function represents the index of the source element.</param>
/// <returns>An async-enumerable sequence whose elements are the result of invoking the one-to-many transform function on each element of the input sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TResult>> selector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (selector is null)
    {
        throw Error.ArgumentNull(nameof(selector));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        // Starts at -1 so the first pre-increment produces index 0; checked() raises an
        // OverflowException rather than wrapping once the index passes int.MaxValue.
        var position = -1;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var innerSequence = selector(item, checked(++position));

            // The same token governs both the outer and the inner enumeration.
            await foreach (var innerItem in innerSequence.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                yield return innerItem;
            }
        }
    }
}
/// <summary>
/// Projects each element of <paramref name="source"/>, together with its zero-based index, to an inner
/// async-enumerable sequence obtained by awaiting <paramref name="selector"/>, and merges the inner sequences.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TResult">The type of the elements in the inner sequences and in the result.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="selector">An asynchronous transform function; the second parameter is the index of the source element.</param>
/// <returns>The flattened sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
internal static IAsyncEnumerable<TResult> SelectManyAwaitCore<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<IAsyncEnumerable<TResult>>> selector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (selector is null)
    {
        throw Error.ArgumentNull(nameof(selector));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        // -1 so the first pre-increment yields index 0; overflow past int.MaxValue throws.
        var position = -1;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var innerSequence = await selector(item, checked(++position)).ConfigureAwait(false);

            await foreach (var innerItem in innerSequence.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                yield return innerItem;
            }
        }
    }
}
#if !NO_DEEP_CANCELLATION
/// <summary>
/// Projects each element of <paramref name="source"/>, together with its zero-based index, to an inner
/// async-enumerable sequence obtained by awaiting <paramref name="selector"/> (which also receives the
/// enumeration's cancellation token), and merges the inner sequences.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TResult">The type of the elements in the inner sequences and in the result.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="selector">A cancellable asynchronous transform function; the second parameter is the index of the source element.</param>
/// <returns>The flattened sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
internal static IAsyncEnumerable<TResult> SelectManyAwaitWithCancellationCore<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<IAsyncEnumerable<TResult>>> selector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (selector is null)
    {
        throw Error.ArgumentNull(nameof(selector));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        // -1 so the first pre-increment yields index 0; overflow past int.MaxValue throws.
        var position = -1;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var innerSequence = await selector(item, checked(++position), cancellationToken).ConfigureAwait(false);

            await foreach (var innerItem in innerSequence.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                yield return innerItem;
            }
        }
    }
}
#endif
/// <summary>
/// Projects each element of an async-enumerable sequence to an async-enumerable sequence, invokes the result selector for the source element and each of the corresponding inner sequence's elements, and merges the results into one async-enumerable sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TCollection">The type of the elements in the projected intermediate sequences.</typeparam>
/// <typeparam name="TResult">The type of the elements in the result sequence, obtained by using the selector to combine source sequence elements with their corresponding intermediate sequence elements.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="collectionSelector">A transform function to apply to each element.</param>
/// <param name="resultSelector">A transform function to apply to each element of the intermediate sequence.</param>
/// <returns>An async-enumerable sequence whose elements are the result of invoking the one-to-many transform function collectionSelector on each element of the input sequence and then mapping each of those sequence elements and their corresponding source element to a result element.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="collectionSelector"/> or <paramref name="resultSelector"/> is null.</exception>
public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (collectionSelector is null)
    {
        throw Error.ArgumentNull(nameof(collectionSelector));
    }

    if (resultSelector is null)
    {
        throw Error.ArgumentNull(nameof(resultSelector));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // collectionSelector runs once per source element; each intermediate element is paired
            // with the source element that produced it.
            await foreach (var collectionItem in collectionSelector(item).WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                yield return resultSelector(item, collectionItem);
            }
        }
    }
}
/// <summary>
/// Projects each source element to an intermediate sequence obtained by awaiting
/// <paramref name="collectionSelector"/>, then awaits <paramref name="resultSelector"/> for every
/// (source element, intermediate element) pair and yields the results in order.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TCollection">The type of the elements in the intermediate sequences.</typeparam>
/// <typeparam name="TResult">The type of the elements in the result sequence.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="collectionSelector">An asynchronous transform function producing the intermediate sequence for each element.</param>
/// <param name="resultSelector">An asynchronous function combining a source element with one of its intermediate elements.</param>
/// <returns>The flattened, projected sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="collectionSelector"/> or <paramref name="resultSelector"/> is null.</exception>
internal static IAsyncEnumerable<TResult> SelectManyAwaitCore<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<IAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, ValueTask<TResult>> resultSelector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (collectionSelector is null)
    {
        throw Error.ArgumentNull(nameof(collectionSelector));
    }

    if (resultSelector is null)
    {
        throw Error.ArgumentNull(nameof(resultSelector));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var collection = await collectionSelector(item).ConfigureAwait(false);

            await foreach (var collectionItem in collection.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var projected = await resultSelector(item, collectionItem).ConfigureAwait(false);
                yield return projected;
            }
        }
    }
}
#if !NO_DEEP_CANCELLATION
/// <summary>
/// Projects each source element to an intermediate sequence obtained by awaiting
/// <paramref name="collectionSelector"/>, then awaits <paramref name="resultSelector"/> for every
/// (source element, intermediate element) pair. Both selectors receive the enumeration's cancellation token.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TCollection">The type of the elements in the intermediate sequences.</typeparam>
/// <typeparam name="TResult">The type of the elements in the result sequence.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="collectionSelector">A cancellable asynchronous transform function producing the intermediate sequence for each element.</param>
/// <param name="resultSelector">A cancellable asynchronous function combining a source element with one of its intermediate elements.</param>
/// <returns>The flattened, projected sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="collectionSelector"/> or <paramref name="resultSelector"/> is null.</exception>
internal static IAsyncEnumerable<TResult> SelectManyAwaitWithCancellationCore<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<IAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, CancellationToken, ValueTask<TResult>> resultSelector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (collectionSelector is null)
    {
        throw Error.ArgumentNull(nameof(collectionSelector));
    }

    if (resultSelector is null)
    {
        throw Error.ArgumentNull(nameof(resultSelector));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var collection = await collectionSelector(item, cancellationToken).ConfigureAwait(false);

            await foreach (var collectionItem in collection.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var projected = await resultSelector(item, collectionItem, cancellationToken).ConfigureAwait(false);
                yield return projected;
            }
        }
    }
}
#endif
/// <summary>
/// Projects each element of an async-enumerable sequence to an async-enumerable sequence by incorporating the element's index, invokes the result selector for the source element and each of the corresponding inner sequence's elements, and merges the results into one async-enumerable sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TCollection">The type of the elements in the projected intermediate sequences.</typeparam>
/// <typeparam name="TResult">The type of the elements in the result sequence, obtained by using the selector to combine source sequence elements with their corresponding intermediate sequence elements.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="collectionSelector">A transform function to apply to each element; the second parameter of the function represents the index of the source element.</param>
/// <param name="resultSelector">A transform function to apply to each element of the intermediate sequence; its first parameter is the source element and its second parameter is the intermediate element produced from it. No index is passed to this function.</param>
/// <returns>An async-enumerable sequence whose elements are the result of invoking the one-to-many transform function collectionSelector on each element of the input sequence and then mapping each of those sequence elements and their corresponding source element to a result element.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="collectionSelector"/> or <paramref name="resultSelector"/> is null.</exception>
public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
{
    if (source == null)
        throw Error.ArgumentNull(nameof(source));
    if (collectionSelector == null)
        throw Error.ArgumentNull(nameof(collectionSelector));
    if (resultSelector == null)
        throw Error.ArgumentNull(nameof(resultSelector));

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        // -1 so the first increment yields index 0; checked so overflow past int.MaxValue throws
        // instead of silently wrapping to a negative index.
        var index = -1;

        await foreach (var element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            checked
            {
                index++;
            }

            var inner = collectionSelector(element, index);

            await foreach (var subElement in inner.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                yield return resultSelector(element, subElement);
            }
        }
    }
}
/// <summary>
/// Projects each source element, together with its zero-based index, to an intermediate sequence obtained
/// by awaiting <paramref name="collectionSelector"/>, then awaits <paramref name="resultSelector"/> for
/// every (source element, intermediate element) pair.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TCollection">The type of the elements in the intermediate sequences.</typeparam>
/// <typeparam name="TResult">The type of the elements in the result sequence.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="collectionSelector">An asynchronous transform function; the second parameter is the index of the source element.</param>
/// <param name="resultSelector">An asynchronous function combining a source element with one of its intermediate elements.</param>
/// <returns>The flattened, projected sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="collectionSelector"/> or <paramref name="resultSelector"/> is null.</exception>
internal static IAsyncEnumerable<TResult> SelectManyAwaitCore<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, ValueTask<IAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, ValueTask<TResult>> resultSelector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (collectionSelector is null)
    {
        throw Error.ArgumentNull(nameof(collectionSelector));
    }

    if (resultSelector is null)
    {
        throw Error.ArgumentNull(nameof(resultSelector));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        // -1 so the first pre-increment yields index 0; overflow past int.MaxValue throws.
        var position = -1;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var collection = await collectionSelector(item, checked(++position)).ConfigureAwait(false);

            await foreach (var collectionItem in collection.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var projected = await resultSelector(item, collectionItem).ConfigureAwait(false);
                yield return projected;
            }
        }
    }
}
#if !NO_DEEP_CANCELLATION
/// <summary>
/// Projects each source element, together with its zero-based index, to an intermediate sequence obtained
/// by awaiting <paramref name="collectionSelector"/>, then awaits <paramref name="resultSelector"/> for
/// every (source element, intermediate element) pair. Both selectors receive the enumeration's cancellation token.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <typeparam name="TCollection">The type of the elements in the intermediate sequences.</typeparam>
/// <typeparam name="TResult">The type of the elements in the result sequence.</typeparam>
/// <param name="source">An async-enumerable sequence of elements to project.</param>
/// <param name="collectionSelector">A cancellable asynchronous transform function; the second parameter is the index of the source element.</param>
/// <param name="resultSelector">A cancellable asynchronous function combining a source element with one of its intermediate elements.</param>
/// <returns>The flattened, projected sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="collectionSelector"/> or <paramref name="resultSelector"/> is null.</exception>
internal static IAsyncEnumerable<TResult> SelectManyAwaitWithCancellationCore<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, ValueTask<IAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, CancellationToken, ValueTask<TResult>> resultSelector)
{
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    if (collectionSelector is null)
    {
        throw Error.ArgumentNull(nameof(collectionSelector));
    }

    if (resultSelector is null)
    {
        throw Error.ArgumentNull(nameof(resultSelector));
    }

#if HAS_ASYNC_ENUMERABLE_CANCELLATION
    return Core();

    async IAsyncEnumerable<TResult> Core([System.Runtime.CompilerServices.EnumeratorCancellation]CancellationToken cancellationToken = default)
#else
    return Create(Core);

    async IAsyncEnumerator<TResult> Core(CancellationToken cancellationToken)
#endif
    {
        // -1 so the first pre-increment yields index 0; overflow past int.MaxValue throws.
        var position = -1;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var collection = await collectionSelector(item, checked(++position), cancellationToken).ConfigureAwait(false);

            await foreach (var collectionItem in collection.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var projected = await resultSelector(item, collectionItem, cancellationToken).ConfigureAwait(false);
                yield return projected;
            }
        }
    }
}
#endif
/// <summary>
/// Hand-written iterator behind <c>SelectMany</c> with a synchronous selector. It advances the source
/// enumerator, obtains the inner sequence for the current source element, drains it, and repeats.
/// It also implements <c>IAsyncIListProvider</c> so that <c>ToArrayAsync</c>, <c>ToListAsync</c> and
/// <c>GetCountAsync</c> can enumerate without going through <see cref="MoveNextCore"/>.
/// </summary>
private sealed class SelectManyAsyncIterator<TSource, TResult> : AsyncIterator<TResult>, IAsyncIListProvider<TResult>
{
    // Sub-states used while _state is Iterating.
    private const int State_Source = 1; // next step: advance the source enumerator
    private const int State_Result = 2; // next step: advance the current inner enumerator

    private readonly Func<TSource, IAsyncEnumerable<TResult>> _selector;
    private readonly IAsyncEnumerable<TSource> _source;

    private int _mode;
    private IAsyncEnumerator<TResult>? _resultEnumerator;
    private IAsyncEnumerator<TSource>? _sourceEnumerator;

    public SelectManyAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TResult>> selector)
    {
        _source = source;
        _selector = selector;
    }

    public override AsyncIteratorBase<TResult> Clone()
    {
        return new SelectManyAsyncIterator<TSource, TResult>(_source, _selector);
    }

    /// <summary>
    /// Disposes the inner enumerator (if any), then the source enumerator (if any), then the base iterator.
    /// </summary>
    public override async ValueTask DisposeAsync()
    {
        if (_resultEnumerator != null)
        {
            await _resultEnumerator.DisposeAsync().ConfigureAwait(false);
            _resultEnumerator = null;
        }

        if (_sourceEnumerator != null)
        {
            await _sourceEnumerator.DisposeAsync().ConfigureAwait(false);
            _sourceEnumerator = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Counts the flattened elements by counting each inner sequence. When <paramref name="onlyIfCheap"/>
    /// is true, returns -1 without enumerating anything.
    /// </summary>
    public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
    {
        if (onlyIfCheap)
        {
            return new ValueTask<int>(-1);
        }

        return Core(cancellationToken);

        async ValueTask<int> Core(CancellationToken cancellationToken)
        {
            var count = 0;

            await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                checked
                {
                    // Flow the token into the inner count so a long inner sequence can be cancelled too.
                    count += await _selector(element).CountAsync(cancellationToken).ConfigureAwait(false);
                }
            }

            return count;
        }
    }

    public async ValueTask<TResult[]> ToArrayAsync(CancellationToken cancellationToken)
    {
        // REVIEW: Substitute for SparseArrayBuilder logic once we have access to that.
        var list = await ToListAsync(cancellationToken).ConfigureAwait(false);
        return list.ToArray();
    }

    public async ValueTask<List<TResult>> ToListAsync(CancellationToken cancellationToken)
    {
        var list = new List<TResult>();

        await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var items = _selector(element);
            await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
        }

        return list;
    }

    protected override async ValueTask<bool> MoveNextCore()
    {
        switch (_state)
        {
            case AsyncIteratorState.Allocated:
                _sourceEnumerator = _source.GetAsyncEnumerator(_cancellationToken);
                _mode = State_Source;
                _state = AsyncIteratorState.Iterating;
                goto case AsyncIteratorState.Iterating;

            case AsyncIteratorState.Iterating:
                switch (_mode)
                {
                    case State_Source:
                        if (await _sourceEnumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            // Any previous inner enumerator was disposed and cleared when it ran out
                            // (see State_Result), so a throwing selector cannot leave a stale,
                            // already-disposed enumerator behind for DisposeAsync to dispose again.
                            var inner = _selector(_sourceEnumerator.Current);
                            _resultEnumerator = inner.GetAsyncEnumerator(_cancellationToken);
                            _mode = State_Result;
                            goto case State_Result;
                        }

                        break;

                    case State_Result:
                        if (await _resultEnumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            _current = _resultEnumerator.Current;
                            return true;
                        }

                        // Inner sequence exhausted: release it now rather than holding it until the next
                        // source element arrives. This matches the disposal order of the
                        // `await foreach`-based SelectMany overloads.
                        var exhausted = _resultEnumerator;
                        _resultEnumerator = null;
                        await exhausted.DisposeAsync().ConfigureAwait(false);

                        _mode = State_Source;
                        goto case State_Source; // loop
                }

                break;
        }

        await DisposeAsync().ConfigureAwait(false);
        return false;
    }
}
/// <summary>
/// Hand-written iterator behind the awaiting <c>SelectMany</c> variant whose selector returns a
/// <see cref="ValueTask{TResult}"/> of the inner sequence. It advances the source enumerator, awaits the
/// selector for the current element, drains the resulting inner sequence, and repeats. It also implements
/// <c>IAsyncIListProvider</c> for direct <c>ToArrayAsync</c>/<c>ToListAsync</c>/<c>GetCountAsync</c> enumeration.
/// </summary>
private sealed class SelectManyAsyncIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>, IAsyncIListProvider<TResult>
{
    // Sub-states used while _state is Iterating.
    private const int State_Source = 1; // next step: advance the source enumerator
    private const int State_Result = 2; // next step: advance the current inner enumerator

    private readonly Func<TSource, ValueTask<IAsyncEnumerable<TResult>>> _selector;
    private readonly IAsyncEnumerable<TSource> _source;

    private int _mode;
    private IAsyncEnumerator<TResult>? _resultEnumerator;
    private IAsyncEnumerator<TSource>? _sourceEnumerator;

    public SelectManyAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<IAsyncEnumerable<TResult>>> selector)
    {
        _source = source;
        _selector = selector;
    }

    public override AsyncIteratorBase<TResult> Clone()
    {
        return new SelectManyAsyncIteratorWithTask<TSource, TResult>(_source, _selector);
    }

    /// <summary>
    /// Disposes the inner enumerator (if any), then the source enumerator (if any), then the base iterator.
    /// </summary>
    public override async ValueTask DisposeAsync()
    {
        if (_resultEnumerator != null)
        {
            await _resultEnumerator.DisposeAsync().ConfigureAwait(false);
            _resultEnumerator = null;
        }

        if (_sourceEnumerator != null)
        {
            await _sourceEnumerator.DisposeAsync().ConfigureAwait(false);
            _sourceEnumerator = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Counts the flattened elements by awaiting the selector for each source element and counting the
    /// returned inner sequence. When <paramref name="onlyIfCheap"/> is true, returns -1 without enumerating.
    /// </summary>
    public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
    {
        if (onlyIfCheap)
        {
            return new ValueTask<int>(-1);
        }

        return Core(cancellationToken);

        async ValueTask<int> Core(CancellationToken cancellationToken)
        {
            var count = 0;

            await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var items = await _selector(element).ConfigureAwait(false);

                checked
                {
                    // Flow the token into the inner count so a long inner sequence can be cancelled too.
                    count += await items.CountAsync(cancellationToken).ConfigureAwait(false);
                }
            }

            return count;
        }
    }

    public async ValueTask<TResult[]> ToArrayAsync(CancellationToken cancellationToken)
    {
        // REVIEW: Substitute for SparseArrayBuilder logic once we have access to that.
        var list = await ToListAsync(cancellationToken).ConfigureAwait(false);
        return list.ToArray();
    }

    public async ValueTask<List<TResult>> ToListAsync(CancellationToken cancellationToken)
    {
        var list = new List<TResult>();

        await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var items = await _selector(element).ConfigureAwait(false);
            await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
        }

        return list;
    }

    protected override async ValueTask<bool> MoveNextCore()
    {
        switch (_state)
        {
            case AsyncIteratorState.Allocated:
                _sourceEnumerator = _source.GetAsyncEnumerator(_cancellationToken);
                _mode = State_Source;
                _state = AsyncIteratorState.Iterating;
                goto case AsyncIteratorState.Iterating;

            case AsyncIteratorState.Iterating:
                switch (_mode)
                {
                    case State_Source:
                        if (await _sourceEnumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            // Any previous inner enumerator was disposed and cleared when it ran out
                            // (see State_Result), so a throwing/faulting selector cannot leave a stale,
                            // already-disposed enumerator behind for DisposeAsync to dispose again.
                            var inner = await _selector(_sourceEnumerator.Current).ConfigureAwait(false);
                            _resultEnumerator = inner.GetAsyncEnumerator(_cancellationToken);
                            _mode = State_Result;
                            goto case State_Result;
                        }

                        break;

                    case State_Result:
                        if (await _resultEnumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            _current = _resultEnumerator.Current;
                            return true;
                        }

                        // Inner sequence exhausted: release it now rather than holding it until the next
                        // source element arrives. This matches the disposal order of the
                        // `await foreach`-based SelectMany overloads.
                        var exhausted = _resultEnumerator;
                        _resultEnumerator = null;
                        await exhausted.DisposeAsync().ConfigureAwait(false);

                        _mode = State_Source;
                        goto case State_Source; // loop
                }

                break;
        }

        await DisposeAsync().ConfigureAwait(false);
        return false;
    }
}
#if !NO_DEEP_CANCELLATION
/// <summary>
/// Hand-written iterator behind the awaiting <c>SelectMany</c> variant whose selector takes a
/// <see cref="CancellationToken"/> and returns a <see cref="ValueTask{TResult}"/> of the inner sequence.
/// It advances the source enumerator, awaits the selector (passing the iterator's token), drains the
/// resulting inner sequence, and repeats. It also implements <c>IAsyncIListProvider</c> for direct
/// <c>ToArrayAsync</c>/<c>ToListAsync</c>/<c>GetCountAsync</c> enumeration.
/// </summary>
private sealed class SelectManyAsyncIteratorWithTaskAndCancellation<TSource, TResult> : AsyncIterator<TResult>, IAsyncIListProvider<TResult>
{
    // Sub-states used while _state is Iterating.
    private const int State_Source = 1; // next step: advance the source enumerator
    private const int State_Result = 2; // next step: advance the current inner enumerator

    private readonly Func<TSource, CancellationToken, ValueTask<IAsyncEnumerable<TResult>>> _selector;
    private readonly IAsyncEnumerable<TSource> _source;

    private int _mode;
    private IAsyncEnumerator<TResult>? _resultEnumerator;
    private IAsyncEnumerator<TSource>? _sourceEnumerator;

    public SelectManyAsyncIteratorWithTaskAndCancellation(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<IAsyncEnumerable<TResult>>> selector)
    {
        _source = source;
        _selector = selector;
    }

    public override AsyncIteratorBase<TResult> Clone()
    {
        return new SelectManyAsyncIteratorWithTaskAndCancellation<TSource, TResult>(_source, _selector);
    }

    /// <summary>
    /// Disposes the inner enumerator (if any), then the source enumerator (if any), then the base iterator.
    /// </summary>
    public override async ValueTask DisposeAsync()
    {
        if (_resultEnumerator != null)
        {
            await _resultEnumerator.DisposeAsync().ConfigureAwait(false);
            _resultEnumerator = null;
        }

        if (_sourceEnumerator != null)
        {
            await _sourceEnumerator.DisposeAsync().ConfigureAwait(false);
            _sourceEnumerator = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Counts the flattened elements by awaiting the selector for each source element and counting the
    /// returned inner sequence. When <paramref name="onlyIfCheap"/> is true, returns -1 without enumerating.
    /// </summary>
    public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
    {
        if (onlyIfCheap)
        {
            return new ValueTask<int>(-1);
        }

        return Core(cancellationToken);

        async ValueTask<int> Core(CancellationToken cancellationToken)
        {
            var count = 0;

            await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var items = await _selector(element, cancellationToken).ConfigureAwait(false);

                checked
                {
                    // Flow the token into the inner count so a long inner sequence can be cancelled too.
                    count += await items.CountAsync(cancellationToken).ConfigureAwait(false);
                }
            }

            return count;
        }
    }

    public async ValueTask<TResult[]> ToArrayAsync(CancellationToken cancellationToken)
    {
        // REVIEW: Substitute for SparseArrayBuilder logic once we have access to that.
        var list = await ToListAsync(cancellationToken).ConfigureAwait(false);
        return list.ToArray();
    }

    public async ValueTask<List<TResult>> ToListAsync(CancellationToken cancellationToken)
    {
        var list = new List<TResult>();

        await foreach (var element in _source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            var items = await _selector(element, cancellationToken).ConfigureAwait(false);
            await list.AddRangeAsync(items, cancellationToken).ConfigureAwait(false);
        }

        return list;
    }

    protected override async ValueTask<bool> MoveNextCore()
    {
        switch (_state)
        {
            case AsyncIteratorState.Allocated:
                _sourceEnumerator = _source.GetAsyncEnumerator(_cancellationToken);
                _mode = State_Source;
                _state = AsyncIteratorState.Iterating;
                goto case AsyncIteratorState.Iterating;

            case AsyncIteratorState.Iterating:
                switch (_mode)
                {
                    case State_Source:
                        if (await _sourceEnumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            // Any previous inner enumerator was disposed and cleared when it ran out
                            // (see State_Result), so a throwing/faulting selector cannot leave a stale,
                            // already-disposed enumerator behind for DisposeAsync to dispose again.
                            var inner = await _selector(_sourceEnumerator.Current, _cancellationToken).ConfigureAwait(false);
                            _resultEnumerator = inner.GetAsyncEnumerator(_cancellationToken);
                            _mode = State_Result;
                            goto case State_Result;
                        }

                        break;

                    case State_Result:
                        if (await _resultEnumerator!.MoveNextAsync().ConfigureAwait(false))
                        {
                            _current = _resultEnumerator.Current;
                            return true;
                        }

                        // Inner sequence exhausted: release it now rather than holding it until the next
                        // source element arrives. This matches the disposal order of the
                        // `await foreach`-based SelectMany overloads.
                        var exhausted = _resultEnumerator;
                        _resultEnumerator = null;
                        await exhausted.DisposeAsync().ConfigureAwait(false);

                        _mode = State_Source;
                        goto case State_Source; // loop
                }

                break;
        }

        await DisposeAsync().ConfigureAwait(false);
        return false;
    }
}
#endif
}
}