// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Projects each element of <paramref name="source"/> to an async-enumerable sequence and
        /// flattens the resulting sequences into one sequence.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TResult">The type of the elements of the inner sequences.</typeparam>
        /// <param name="source">The sequence whose elements are projected.</param>
        /// <param name="selector">Maps a source element to the inner sequence to enumerate.</param>
        /// <returns>
        /// A lazily evaluated sequence yielding every element of every inner sequence, in order.
        /// Arguments are validated eagerly; <paramref name="selector"/> runs during enumeration.
        /// </returns>
        public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TResult>> selector)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return new SelectManyAsyncIterator<TSource, TResult>(source, selector);
        }

        /// <summary>
        /// Projects each element of <paramref name="source"/> to a task producing an async-enumerable
        /// sequence, awaits it, and flattens the resulting sequences into one sequence.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TResult">The type of the elements of the inner sequences.</typeparam>
        /// <param name="source">The sequence whose elements are projected.</param>
        /// <param name="selector">Asynchronously maps a source element to the inner sequence to enumerate.</param>
        /// <returns>A lazily evaluated sequence yielding every element of every inner sequence, in order.</returns>
        public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<IAsyncEnumerable<TResult>>> selector)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return new SelectManyAsyncIteratorWithTask<TSource, TResult>(source, selector);
        }

        /// <summary>
        /// Projects each element of <paramref name="source"/>, together with its zero-based index, to an
        /// async-enumerable sequence and flattens the resulting sequences into one sequence.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TResult">The type of the elements of the inner sequences.</typeparam>
        /// <param name="source">The sequence whose elements are projected.</param>
        /// <param name="selector">
        /// Maps a source element and its index to the inner sequence to enumerate. The index is
        /// incremented in a <c>checked</c> context, so it overflows with an exception rather than wrapping.
        /// </param>
        /// <returns>A lazily evaluated sequence yielding every element of every inner sequence, in order.</returns>
        public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TResult>> selector)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return new SelectManyWithIndexAsyncIterator<TSource, TResult>(source, selector);
        }

        /// <summary>
        /// Projects each element of <paramref name="source"/>, together with its zero-based index, to a task
        /// producing an async-enumerable sequence, awaits it, and flattens the resulting sequences.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TResult">The type of the elements of the inner sequences.</typeparam>
        /// <param name="source">The sequence whose elements are projected.</param>
        /// <param name="selector">Asynchronously maps a source element and its index to the inner sequence.</param>
        /// <returns>A lazily evaluated sequence yielding every element of every inner sequence, in order.</returns>
        public static IAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, Task<IAsyncEnumerable<TResult>>> selector)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return new SelectManyWithIndexAsyncIteratorWithTask<TSource, TResult>(source, selector);
        }

        /// <summary>
        /// Projects each element of <paramref name="source"/> to an async-enumerable sequence, flattens the
        /// result, and combines every inner element with the source element that produced it.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TCollection">The type of the elements of the inner sequences.</typeparam>
        /// <typeparam name="TResult">The type of the elements of the returned sequence.</typeparam>
        /// <param name="source">The sequence whose elements are projected.</param>
        /// <param name="selector">Maps a source element to the inner sequence to enumerate.</param>
        /// <param name="resultSelector">Combines a source element with one of its inner elements.</param>
        /// <returns>A lazily evaluated sequence of <paramref name="resultSelector"/> results, in order.</returns>
        public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TCollection>> selector, Func<TSource, TCollection, TResult> resultSelector)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));
            if (resultSelector == null)
                throw Error.ArgumentNull(nameof(resultSelector));

            return new SelectManyAsyncIterator<TSource, TCollection, TResult>(source, selector, resultSelector);
        }

        /// <summary>
        /// Asynchronous variant of the result-selector overload: both <paramref name="selector"/> and
        /// <paramref name="resultSelector"/> return tasks that are awaited during enumeration.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TCollection">The type of the elements of the inner sequences.</typeparam>
        /// <typeparam name="TResult">The type of the elements of the returned sequence.</typeparam>
        /// <param name="source">The sequence whose elements are projected.</param>
        /// <param name="selector">Asynchronously maps a source element to the inner sequence.</param>
        /// <param name="resultSelector">Asynchronously combines a source element with one of its inner elements.</param>
        /// <returns>A lazily evaluated sequence of <paramref name="resultSelector"/> results, in order.</returns>
        public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<IAsyncEnumerable<TCollection>>> selector, Func<TSource, TCollection, Task<TResult>> resultSelector)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));
            if (resultSelector == null)
                throw Error.ArgumentNull(nameof(resultSelector));

            return new SelectManyAsyncIteratorWithTask<TSource, TCollection, TResult>(source, selector, resultSelector);
        }

        /// <summary>
        /// Indexed variant of the result-selector overload: <paramref name="selector"/> also receives the
        /// zero-based index of the source element (incremented in a <c>checked</c> context).
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TCollection">The type of the elements of the inner sequences.</typeparam>
        /// <typeparam name="TResult">The type of the elements of the returned sequence.</typeparam>
        /// <param name="source">The sequence whose elements are projected.</param>
        /// <param name="selector">Maps a source element and its index to the inner sequence.</param>
        /// <param name="resultSelector">Combines a source element with one of its inner elements.</param>
        /// <returns>A lazily evaluated sequence of <paramref name="resultSelector"/> results, in order.</returns>
        public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TCollection>> selector, Func<TSource, TCollection, TResult> resultSelector)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));
            if (resultSelector == null)
                throw Error.ArgumentNull(nameof(resultSelector));

            return new SelectManyWithIndexAsyncIterator<TSource, TCollection, TResult>(source, selector, resultSelector);
        }

        /// <summary>
        /// Indexed, asynchronous variant of the result-selector overload: both selectors return tasks
        /// that are awaited during enumeration, and <paramref name="selector"/> receives the element index.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TCollection">The type of the elements of the inner sequences.</typeparam>
        /// <typeparam name="TResult">The type of the elements of the returned sequence.</typeparam>
        /// <param name="source">The sequence whose elements are projected.</param>
        /// <param name="selector">Asynchronously maps a source element and its index to the inner sequence.</param>
        /// <param name="resultSelector">Asynchronously combines a source element with one of its inner elements.</param>
        /// <returns>A lazily evaluated sequence of <paramref name="resultSelector"/> results, in order.</returns>
        public static IAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, Task<IAsyncEnumerable<TCollection>>> selector, Func<TSource, TCollection, Task<TResult>> resultSelector)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));
            if (resultSelector == null)
                throw Error.ArgumentNull(nameof(resultSelector));

            return new SelectManyWithIndexAsyncIteratorWithTask<TSource, TCollection, TResult>(source, selector, resultSelector);
        }

        // All iterators below share one state machine:
        //   State_Source - advance the outer enumerator and open the inner sequence for the new element;
        //   State_Result - yield from the inner enumerator until it is exhausted.
        // An exhausted inner enumerator is disposed before the outer enumerator is advanced, and every
        // enumerator field is cleared *before* its DisposeAsync is awaited, so no enumerator is ever
        // disposed twice (e.g. when a selector throws and DisposeAsync runs afterwards).

        /// <summary>Iterator for <c>SelectMany(source, selector)</c>.</summary>
        private sealed class SelectManyAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
        {
            private const int State_Source = 1;
            private const int State_Result = 2;

            private readonly Func<TSource, IAsyncEnumerable<TResult>> _selector;
            private readonly IAsyncEnumerable<TSource> _source;

            private int _mode;
            private IAsyncEnumerator<TResult> _resultEnumerator;
            private IAsyncEnumerator<TSource> _sourceEnumerator;

            public SelectManyAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TResult>> selector)
            {
                Debug.Assert(source != null);
                Debug.Assert(selector != null);

                _source = source;
                _selector = selector;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new SelectManyAsyncIterator<TSource, TResult>(_source, _selector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_sourceEnumerator != null)
                {
                    var sourceEnumerator = _sourceEnumerator;
                    _sourceEnumerator = null;
                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                if (_resultEnumerator != null)
                {
                    var resultEnumerator = _resultEnumerator;
                    _resultEnumerator = null;
                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _sourceEnumerator = _source.GetAsyncEnumerator(cancellationToken);
                        _mode = State_Source;
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        switch (_mode)
                        {
                            case State_Source:
                                if (await _sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    var inner = _selector(_sourceEnumerator.Current);
                                    _resultEnumerator = inner.GetAsyncEnumerator(cancellationToken);

                                    _mode = State_Result;
                                    goto case State_Result;
                                }

                                break;

                            case State_Result:
                                if (await _resultEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    current = _resultEnumerator.Current;
                                    return true;
                                }

                                // Inner sequence exhausted: release it now (field cleared first so it
                                // cannot be disposed again by DisposeAsync).
                                var finished = _resultEnumerator;
                                _resultEnumerator = null;
                                await finished.DisposeAsync().ConfigureAwait(false);

                                _mode = State_Source;
                                goto case State_Source; // loop
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }

        /// <summary>Iterator for <c>SelectMany(source, selector)</c> where the selector returns a task.</summary>
        private sealed class SelectManyAsyncIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>
        {
            private const int State_Source = 1;
            private const int State_Result = 2;

            private readonly Func<TSource, Task<IAsyncEnumerable<TResult>>> _selector;
            private readonly IAsyncEnumerable<TSource> _source;

            private int _mode;
            private IAsyncEnumerator<TResult> _resultEnumerator;
            private IAsyncEnumerator<TSource> _sourceEnumerator;

            public SelectManyAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<IAsyncEnumerable<TResult>>> selector)
            {
                Debug.Assert(source != null);
                Debug.Assert(selector != null);

                _source = source;
                _selector = selector;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new SelectManyAsyncIteratorWithTask<TSource, TResult>(_source, _selector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_sourceEnumerator != null)
                {
                    var sourceEnumerator = _sourceEnumerator;
                    _sourceEnumerator = null;
                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                if (_resultEnumerator != null)
                {
                    var resultEnumerator = _resultEnumerator;
                    _resultEnumerator = null;
                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _sourceEnumerator = _source.GetAsyncEnumerator(cancellationToken);
                        _mode = State_Source;
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        switch (_mode)
                        {
                            case State_Source:
                                if (await _sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    var inner = await _selector(_sourceEnumerator.Current).ConfigureAwait(false);
                                    _resultEnumerator = inner.GetAsyncEnumerator(cancellationToken);

                                    _mode = State_Result;
                                    goto case State_Result;
                                }

                                break;

                            case State_Result:
                                if (await _resultEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    current = _resultEnumerator.Current;
                                    return true;
                                }

                                // Inner sequence exhausted: release it now (field cleared first so it
                                // cannot be disposed again by DisposeAsync).
                                var finished = _resultEnumerator;
                                _resultEnumerator = null;
                                await finished.DisposeAsync().ConfigureAwait(false);

                                _mode = State_Source;
                                goto case State_Source; // loop
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }

        /// <summary>Iterator for <c>SelectMany(source, selector, resultSelector)</c>.</summary>
        private sealed class SelectManyAsyncIterator<TSource, TCollection, TResult> : AsyncIterator<TResult>
        {
            private const int State_Source = 1;
            private const int State_Result = 2;

            private readonly Func<TSource, IAsyncEnumerable<TCollection>> _collectionSelector;
            private readonly Func<TSource, TCollection, TResult> _resultSelector;
            private readonly IAsyncEnumerable<TSource> _source;

            // Source element whose inner sequence is currently being enumerated; passed to _resultSelector.
            private TSource _currentSource;
            private int _mode;
            private IAsyncEnumerator<TCollection> _resultEnumerator;
            private IAsyncEnumerator<TSource> _sourceEnumerator;

            public SelectManyAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
            {
                Debug.Assert(source != null);
                Debug.Assert(collectionSelector != null);
                Debug.Assert(resultSelector != null);

                _source = source;
                _collectionSelector = collectionSelector;
                _resultSelector = resultSelector;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new SelectManyAsyncIterator<TSource, TCollection, TResult>(_source, _collectionSelector, _resultSelector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_sourceEnumerator != null)
                {
                    var sourceEnumerator = _sourceEnumerator;
                    _sourceEnumerator = null;
                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                if (_resultEnumerator != null)
                {
                    var resultEnumerator = _resultEnumerator;
                    _resultEnumerator = null;
                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                _currentSource = default;

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _sourceEnumerator = _source.GetAsyncEnumerator(cancellationToken);
                        _mode = State_Source;
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        switch (_mode)
                        {
                            case State_Source:
                                if (await _sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    _currentSource = _sourceEnumerator.Current;
                                    var inner = _collectionSelector(_currentSource);
                                    _resultEnumerator = inner.GetAsyncEnumerator(cancellationToken);

                                    _mode = State_Result;
                                    goto case State_Result;
                                }

                                break;

                            case State_Result:
                                if (await _resultEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    current = _resultSelector(_currentSource, _resultEnumerator.Current);
                                    return true;
                                }

                                // Inner sequence exhausted: release it now (field cleared first so it
                                // cannot be disposed again by DisposeAsync).
                                var finished = _resultEnumerator;
                                _resultEnumerator = null;
                                await finished.DisposeAsync().ConfigureAwait(false);

                                _mode = State_Source;
                                goto case State_Source; // loop
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }

        /// <summary>Iterator for <c>SelectMany(source, selector, resultSelector)</c> with task-returning selectors.</summary>
        private sealed class SelectManyAsyncIteratorWithTask<TSource, TCollection, TResult> : AsyncIterator<TResult>
        {
            private const int State_Source = 1;
            private const int State_Result = 2;

            private readonly Func<TSource, Task<IAsyncEnumerable<TCollection>>> _collectionSelector;
            private readonly Func<TSource, TCollection, Task<TResult>> _resultSelector;
            private readonly IAsyncEnumerable<TSource> _source;

            // Source element whose inner sequence is currently being enumerated; passed to _resultSelector.
            private TSource _currentSource;
            private int _mode;
            private IAsyncEnumerator<TCollection> _resultEnumerator;
            private IAsyncEnumerator<TSource> _sourceEnumerator;

            public SelectManyAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<IAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, Task<TResult>> resultSelector)
            {
                Debug.Assert(source != null);
                Debug.Assert(collectionSelector != null);
                Debug.Assert(resultSelector != null);

                _source = source;
                _collectionSelector = collectionSelector;
                _resultSelector = resultSelector;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new SelectManyAsyncIteratorWithTask<TSource, TCollection, TResult>(_source, _collectionSelector, _resultSelector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_sourceEnumerator != null)
                {
                    var sourceEnumerator = _sourceEnumerator;
                    _sourceEnumerator = null;
                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                if (_resultEnumerator != null)
                {
                    var resultEnumerator = _resultEnumerator;
                    _resultEnumerator = null;
                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                _currentSource = default;

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _sourceEnumerator = _source.GetAsyncEnumerator(cancellationToken);
                        _mode = State_Source;
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        switch (_mode)
                        {
                            case State_Source:
                                if (await _sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    _currentSource = _sourceEnumerator.Current;
                                    var inner = await _collectionSelector(_currentSource).ConfigureAwait(false);
                                    _resultEnumerator = inner.GetAsyncEnumerator(cancellationToken);

                                    _mode = State_Result;
                                    goto case State_Result;
                                }

                                break;

                            case State_Result:
                                if (await _resultEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    current = await _resultSelector(_currentSource, _resultEnumerator.Current).ConfigureAwait(false);
                                    return true;
                                }

                                // Inner sequence exhausted: release it now (field cleared first so it
                                // cannot be disposed again by DisposeAsync).
                                var finished = _resultEnumerator;
                                _resultEnumerator = null;
                                await finished.DisposeAsync().ConfigureAwait(false);

                                _mode = State_Source;
                                goto case State_Source; // loop
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }

        /// <summary>Iterator for the indexed <c>SelectMany(source, selector, resultSelector)</c>.</summary>
        private sealed class SelectManyWithIndexAsyncIterator<TSource, TCollection, TResult> : AsyncIterator<TResult>
        {
            private const int State_Source = 1;
            private const int State_Result = 2;

            private readonly Func<TSource, int, IAsyncEnumerable<TCollection>> _collectionSelector;
            private readonly Func<TSource, TCollection, TResult> _resultSelector;
            private readonly IAsyncEnumerable<TSource> _source;

            // Source element whose inner sequence is currently being enumerated; passed to _resultSelector.
            private TSource _currentSource;
            // Index of _currentSource; starts at -1 so the first element gets index 0.
            private int _index;
            private int _mode;
            private IAsyncEnumerator<TCollection> _resultEnumerator;
            private IAsyncEnumerator<TSource> _sourceEnumerator;

            public SelectManyWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
            {
                Debug.Assert(source != null);
                Debug.Assert(collectionSelector != null);
                Debug.Assert(resultSelector != null);

                _source = source;
                _collectionSelector = collectionSelector;
                _resultSelector = resultSelector;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new SelectManyWithIndexAsyncIterator<TSource, TCollection, TResult>(_source, _collectionSelector, _resultSelector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_sourceEnumerator != null)
                {
                    var sourceEnumerator = _sourceEnumerator;
                    _sourceEnumerator = null;
                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                if (_resultEnumerator != null)
                {
                    var resultEnumerator = _resultEnumerator;
                    _resultEnumerator = null;
                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                _currentSource = default;

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _sourceEnumerator = _source.GetAsyncEnumerator(cancellationToken);
                        _index = -1;
                        _mode = State_Source;
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        switch (_mode)
                        {
                            case State_Source:
                                if (await _sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    _currentSource = _sourceEnumerator.Current;

                                    checked
                                    {
                                        _index++;
                                    }

                                    var inner = _collectionSelector(_currentSource, _index);
                                    _resultEnumerator = inner.GetAsyncEnumerator(cancellationToken);

                                    _mode = State_Result;
                                    goto case State_Result;
                                }

                                break;

                            case State_Result:
                                if (await _resultEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    current = _resultSelector(_currentSource, _resultEnumerator.Current);
                                    return true;
                                }

                                // Inner sequence exhausted: release it now (field cleared first so it
                                // cannot be disposed again by DisposeAsync).
                                var finished = _resultEnumerator;
                                _resultEnumerator = null;
                                await finished.DisposeAsync().ConfigureAwait(false);

                                _mode = State_Source;
                                goto case State_Source; // loop
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }

        /// <summary>Iterator for the indexed <c>SelectMany(source, selector, resultSelector)</c> with task-returning selectors.</summary>
        private sealed class SelectManyWithIndexAsyncIteratorWithTask<TSource, TCollection, TResult> : AsyncIterator<TResult>
        {
            private const int State_Source = 1;
            private const int State_Result = 2;

            private readonly Func<TSource, int, Task<IAsyncEnumerable<TCollection>>> _collectionSelector;
            private readonly Func<TSource, TCollection, Task<TResult>> _resultSelector;
            private readonly IAsyncEnumerable<TSource> _source;

            // Source element whose inner sequence is currently being enumerated; passed to _resultSelector.
            private TSource _currentSource;
            // Index of _currentSource; starts at -1 so the first element gets index 0.
            private int _index;
            private int _mode;
            private IAsyncEnumerator<TCollection> _resultEnumerator;
            private IAsyncEnumerator<TSource> _sourceEnumerator;

            public SelectManyWithIndexAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, int, Task<IAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, Task<TResult>> resultSelector)
            {
                Debug.Assert(source != null);
                Debug.Assert(collectionSelector != null);
                Debug.Assert(resultSelector != null);

                _source = source;
                _collectionSelector = collectionSelector;
                _resultSelector = resultSelector;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new SelectManyWithIndexAsyncIteratorWithTask<TSource, TCollection, TResult>(_source, _collectionSelector, _resultSelector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_sourceEnumerator != null)
                {
                    var sourceEnumerator = _sourceEnumerator;
                    _sourceEnumerator = null;
                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                if (_resultEnumerator != null)
                {
                    var resultEnumerator = _resultEnumerator;
                    _resultEnumerator = null;
                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                _currentSource = default;

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _sourceEnumerator = _source.GetAsyncEnumerator(cancellationToken);
                        _index = -1;
                        _mode = State_Source;
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        switch (_mode)
                        {
                            case State_Source:
                                if (await _sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    _currentSource = _sourceEnumerator.Current;

                                    checked
                                    {
                                        _index++;
                                    }

                                    var inner = await _collectionSelector(_currentSource, _index).ConfigureAwait(false);
                                    _resultEnumerator = inner.GetAsyncEnumerator(cancellationToken);

                                    _mode = State_Result;
                                    goto case State_Result;
                                }

                                break;

                            case State_Result:
                                if (await _resultEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    current = await _resultSelector(_currentSource, _resultEnumerator.Current).ConfigureAwait(false);
                                    return true;
                                }

                                // Inner sequence exhausted: release it now (field cleared first so it
                                // cannot be disposed again by DisposeAsync).
                                var finished = _resultEnumerator;
                                _resultEnumerator = null;
                                await finished.DisposeAsync().ConfigureAwait(false);

                                _mode = State_Source;
                                goto case State_Source; // loop
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }

        /// <summary>Iterator for the indexed <c>SelectMany(source, selector)</c>.</summary>
        private sealed class SelectManyWithIndexAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
        {
            private const int State_Source = 1;
            private const int State_Result = 2;

            private readonly Func<TSource, int, IAsyncEnumerable<TResult>> _selector;
            private readonly IAsyncEnumerable<TSource> _source;

            // Index of the current source element; starts at -1 so the first element gets index 0.
            private int _index;
            private int _mode;
            private IAsyncEnumerator<TResult> _resultEnumerator;
            private IAsyncEnumerator<TSource> _sourceEnumerator;

            public SelectManyWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, IAsyncEnumerable<TResult>> selector)
            {
                Debug.Assert(source != null);
                Debug.Assert(selector != null);

                _source = source;
                _selector = selector;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new SelectManyWithIndexAsyncIterator<TSource, TResult>(_source, _selector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_sourceEnumerator != null)
                {
                    var sourceEnumerator = _sourceEnumerator;
                    _sourceEnumerator = null;
                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                if (_resultEnumerator != null)
                {
                    var resultEnumerator = _resultEnumerator;
                    _resultEnumerator = null;
                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _sourceEnumerator = _source.GetAsyncEnumerator(cancellationToken);
                        _index = -1;
                        _mode = State_Source;
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        switch (_mode)
                        {
                            case State_Source:
                                if (await _sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    checked
                                    {
                                        _index++;
                                    }

                                    var inner = _selector(_sourceEnumerator.Current, _index);
                                    _resultEnumerator = inner.GetAsyncEnumerator(cancellationToken);

                                    _mode = State_Result;
                                    goto case State_Result;
                                }

                                break;

                            case State_Result:
                                if (await _resultEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    current = _resultEnumerator.Current;
                                    return true;
                                }

                                // Inner sequence exhausted: release it now (field cleared first so it
                                // cannot be disposed again by DisposeAsync).
                                var finished = _resultEnumerator;
                                _resultEnumerator = null;
                                await finished.DisposeAsync().ConfigureAwait(false);

                                _mode = State_Source;
                                goto case State_Source; // loop
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }

        /// <summary>Iterator for the indexed <c>SelectMany(source, selector)</c> where the selector returns a task.</summary>
        private sealed class SelectManyWithIndexAsyncIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>
        {
            private const int State_Source = 1;
            private const int State_Result = 2;

            private readonly Func<TSource, int, Task<IAsyncEnumerable<TResult>>> _selector;
            private readonly IAsyncEnumerable<TSource> _source;

            // Index of the current source element; starts at -1 so the first element gets index 0.
            private int _index;
            private int _mode;
            private IAsyncEnumerator<TResult> _resultEnumerator;
            private IAsyncEnumerator<TSource> _sourceEnumerator;

            public SelectManyWithIndexAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, int, Task<IAsyncEnumerable<TResult>>> selector)
            {
                Debug.Assert(source != null);
                Debug.Assert(selector != null);

                _source = source;
                _selector = selector;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new SelectManyWithIndexAsyncIteratorWithTask<TSource, TResult>(_source, _selector);
            }

            public override async ValueTask DisposeAsync()
            {
                if (_sourceEnumerator != null)
                {
                    var sourceEnumerator = _sourceEnumerator;
                    _sourceEnumerator = null;
                    await sourceEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                if (_resultEnumerator != null)
                {
                    var resultEnumerator = _resultEnumerator;
                    _resultEnumerator = null;
                    await resultEnumerator.DisposeAsync().ConfigureAwait(false);
                }

                await base.DisposeAsync().ConfigureAwait(false);
            }

            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        _sourceEnumerator = _source.GetAsyncEnumerator(cancellationToken);
                        _index = -1;
                        _mode = State_Source;
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        switch (_mode)
                        {
                            case State_Source:
                                if (await _sourceEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    checked
                                    {
                                        _index++;
                                    }

                                    var inner = await _selector(_sourceEnumerator.Current, _index).ConfigureAwait(false);
                                    _resultEnumerator = inner.GetAsyncEnumerator(cancellationToken);

                                    _mode = State_Result;
                                    goto case State_Result;
                                }

                                break;

                            case State_Result:
                                if (await _resultEnumerator.MoveNextAsync().ConfigureAwait(false))
                                {
                                    current = _resultEnumerator.Current;
                                    return true;
                                }

                                // Inner sequence exhausted: release it now (field cleared first so it
                                // cannot be disposed again by DisposeAsync).
                                var finished = _resultEnumerator;
                                _resultEnumerator = null;
                                await finished.DisposeAsync().ConfigureAwait(false);

                                _mode = State_Source;
                                goto case State_Source; // loop
                        }

                        break;
                }

                await DisposeAsync().ConfigureAwait(false);
                return false;
            }
        }
    }
}