// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Groups the elements of an async sequence by key, projecting each element with
        /// <paramref name="elementSelector"/> and comparing keys with <paramref name="comparer"/>.
        /// </summary>
        /// <param name="source">The sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="elementSelector">Projects each source element into the value stored in its group.</param>
        /// <param name="comparer">Equality comparer used for keys.</param>
        /// <returns>An async sequence of groups.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static IAsyncEnumerable<IAsyncGrouping<TKey, TElement>> GroupBy<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return new GroupedAsyncEnumerable<TSource, TKey, TElement>(source, keySelector, elementSelector, comparer);
        }

        /// <summary>
        /// Groups the elements of an async sequence by key, projecting each element with
        /// <paramref name="elementSelector"/> and comparing keys with <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <param name="source">The sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="elementSelector">Projects each source element into the value stored in its group.</param>
        /// <returns>An async sequence of groups.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static IAsyncEnumerable<IAsyncGrouping<TKey, TElement>> GroupBy<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));

            return source.GroupBy(keySelector, elementSelector, EqualityComparer<TKey>.Default);
        }

        /// <summary>
        /// Groups the elements of an async sequence by key, comparing keys with <paramref name="comparer"/>.
        /// </summary>
        /// <param name="source">The sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="comparer">Equality comparer used for keys.</param>
        /// <returns>An async sequence of groups containing the original source elements.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static IAsyncEnumerable<IAsyncGrouping<TKey, TSource>> GroupBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return new GroupedAsyncEnumerable<TSource, TKey>(source, keySelector, comparer);
        }

        /// <summary>
        /// Groups the elements of an async sequence by key, comparing keys with <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <param name="source">The sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <returns>An async sequence of groups containing the original source elements.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static IAsyncEnumerable<IAsyncGrouping<TKey, TSource>> GroupBy<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));

            return new GroupedAsyncEnumerable<TSource, TKey>(source, keySelector, EqualityComparer<TKey>.Default);
        }

        /// <summary>
        /// Groups the elements of an async sequence by key (projected with <paramref name="elementSelector"/>,
        /// keys compared with <paramref name="comparer"/>) and maps each group to
        /// <c>resultSelector(group.Key, group)</c>.
        /// </summary>
        /// <param name="source">The sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="elementSelector">Projects each source element into the value stored in its group.</param>
        /// <param name="resultSelector">Produces one result from a group's key and its elements.</param>
        /// <param name="comparer">Equality comparer used for keys.</param>
        /// <returns>An async sequence with one result per group.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static IAsyncEnumerable<TResult> GroupBy<TSource, TKey, TElement, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector, IEqualityComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return source.GroupBy(keySelector, elementSelector, comparer)
                         .Select(g => resultSelector(g.Key, g));
        }

        /// <summary>
        /// Groups the elements of an async sequence by key (projected with <paramref name="elementSelector"/>,
        /// keys compared with <see cref="EqualityComparer{T}.Default"/>) and maps each group to
        /// <c>resultSelector(group.Key, group)</c>.
        /// </summary>
        /// <param name="source">The sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="elementSelector">Projects each source element into the value stored in its group.</param>
        /// <param name="resultSelector">Produces one result from a group's key and its elements.</param>
        /// <returns>An async sequence with one result per group.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static IAsyncEnumerable<TResult> GroupBy<TSource, TKey, TElement, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));

            return source.GroupBy(keySelector, elementSelector, EqualityComparer<TKey>.Default)
                         .Select(g => resultSelector(g.Key, g));
        }

        /// <summary>
        /// Groups the elements of an async sequence by key (keys compared with <paramref name="comparer"/>)
        /// and maps each group to <c>resultSelector(key, elements)</c>.
        /// </summary>
        /// <param name="source">The sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="resultSelector">Produces one result from a group's key and its elements.</param>
        /// <param name="comparer">Equality comparer used for keys.</param>
        /// <returns>An async sequence with one result per group.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static IAsyncEnumerable<TResult> GroupBy<TSource, TKey, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, IAsyncEnumerable<TSource>, TResult> resultSelector, IEqualityComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return new GroupedResultAsyncEnumerable<TSource, TKey, TResult>(source, keySelector, resultSelector, comparer);
        }

        /// <summary>
        /// Groups the elements of an async sequence by key (keys compared with
        /// <see cref="EqualityComparer{T}.Default"/>) and maps each group to <c>resultSelector(key, elements)</c>.
        /// </summary>
        /// <param name="source">The sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="resultSelector">Produces one result from a group's key and its elements.</param>
        /// <returns>An async sequence with one result per group.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static IAsyncEnumerable<TResult> GroupBy<TSource, TKey, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, IAsyncEnumerable<TSource>, TResult> resultSelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));

            return GroupBy(source, keySelector, resultSelector, EqualityComparer<TKey>.Default);
        }

        /// <summary>
        /// Iterator behind the key + result-selector <c>GroupBy</c> overloads. On the first
        /// <c>MoveNextCore</c> call it builds an <c>Internal.Lookup</c> from the source, then yields
        /// one <typeparamref name="TResult"/> per group produced by <c>ApplyResultSelector</c>.
        /// </summary>
        internal sealed class GroupedResultAsyncEnumerable<TSource, TKey, TResult> : AsyncIterator<TResult>, IIListProvider<TResult>
        {
            private readonly IAsyncEnumerable<TSource> source;
            private readonly Func<TSource, TKey> keySelector;
            private readonly Func<TKey, IAsyncEnumerable<TSource>, TResult> resultSelector;
            private readonly IEqualityComparer<TKey> comparer;

            // Set on the first MoveNextCore call; released by Dispose.
            private Internal.Lookup<TKey, TSource> lookup;
            private IEnumerator<TResult> enumerator;

            public GroupedResultAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, IAsyncEnumerable<TSource>, TResult> resultSelector, IEqualityComparer<TKey> comparer)
            {
                Debug.Assert(source != null);
                Debug.Assert(keySelector != null);
                Debug.Assert(resultSelector != null);
                Debug.Assert(comparer != null);

                this.source = source;
                this.keySelector = keySelector;
                this.resultSelector = resultSelector;
                this.comparer = comparer;
            }

            public override AsyncIterator<TResult> Clone()
            {
                return new GroupedResultAsyncEnumerable<TSource, TKey, TResult>(source, keySelector, resultSelector, comparer);
            }

            public override void Dispose()
            {
                if (enumerator != null)
                {
                    enumerator.Dispose();
                    enumerator = null;
                    lookup = null;
                }

                base.Dispose();
            }

            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        // The lookup is fully built from the source before the first result is yielded.
                        lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                        enumerator = lookup.ApplyResultSelector(resultSelector).GetEnumerator();
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        if (enumerator.MoveNext())
                        {
                            current = enumerator.Current;
                            return true;
                        }

                        // Exhausted: release the enumerator and lookup right away.
                        Dispose();
                        break;
                }

                return false;
            }

            public async Task<TResult[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                var lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                return lookup.ToArray(resultSelector);
            }

            public async Task<List<TResult>> ToListAsync(CancellationToken cancellationToken)
            {
                var lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                return lookup.ToList(resultSelector);
            }

            /// <summary>
            /// Returns the number of groups. Never cheap: it requires building the lookup from the
            /// source, so <c>-1</c> is returned when <paramref name="onlyIfCheap"/> is set.
            /// </summary>
            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                if (onlyIfCheap)
                {
                    return -1;
                }

                var lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                return lookup.Count;
            }
        }

        /// <summary>
        /// Iterator behind the key + element-selector <c>GroupBy</c> overloads. On the first
        /// <c>MoveNextCore</c> call it builds an <c>Internal.Lookup</c> from the source, then yields its
        /// groups as <see cref="IAsyncGrouping{TKey, TElement}"/>.
        /// </summary>
        internal sealed class GroupedAsyncEnumerable<TSource, TKey, TElement> : AsyncIterator<IAsyncGrouping<TKey, TElement>>, IIListProvider<IAsyncGrouping<TKey, TElement>>
        {
            private readonly IAsyncEnumerable<TSource> source;
            private readonly Func<TSource, TKey> keySelector;
            private readonly Func<TSource, TElement> elementSelector;
            private readonly IEqualityComparer<TKey> comparer;

            // Set on the first MoveNextCore call; released by Dispose.
            private Internal.Lookup<TKey, TElement> lookup;
            private IEnumerator<IGrouping<TKey, TElement>> enumerator;

            public GroupedAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)
            {
                Debug.Assert(source != null);
                Debug.Assert(keySelector != null);
                Debug.Assert(elementSelector != null);
                Debug.Assert(comparer != null);

                this.source = source;
                this.keySelector = keySelector;
                this.elementSelector = elementSelector;
                this.comparer = comparer;
            }

            public override AsyncIterator<IAsyncGrouping<TKey, TElement>> Clone()
            {
                return new GroupedAsyncEnumerable<TSource, TKey, TElement>(source, keySelector, elementSelector, comparer);
            }

            public override void Dispose()
            {
                if (enumerator != null)
                {
                    enumerator.Dispose();
                    enumerator = null;
                    lookup = null;
                }

                base.Dispose();
            }

            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        // The lookup is fully built from the source before the first group is yielded.
                        lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false);
                        enumerator = lookup.GetEnumerator();
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        if (enumerator.MoveNext())
                        {
                            // The lookup hands out IGrouping; the concrete groups also implement IAsyncGrouping.
                            current = (IAsyncGrouping<TKey, TElement>)enumerator.Current;
                            return true;
                        }

                        // Exhausted: release the enumerator and lookup right away.
                        Dispose();
                        break;
                }

                return false;
            }

            public async Task<IAsyncGrouping<TKey, TElement>[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                // Typed as the interface so the lookup's IIListProvider implementation is used.
                IIListProvider<IAsyncGrouping<TKey, TElement>> lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false);
                return await lookup.ToArrayAsync(cancellationToken).ConfigureAwait(false);
            }

            public async Task<List<IAsyncGrouping<TKey, TElement>>> ToListAsync(CancellationToken cancellationToken)
            {
                // Typed as the interface so the lookup's IIListProvider implementation is used.
                IIListProvider<IAsyncGrouping<TKey, TElement>> lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false);
                return await lookup.ToListAsync(cancellationToken).ConfigureAwait(false);
            }

            /// <summary>
            /// Returns the number of groups. Never cheap: it requires building the lookup from the
            /// source, so <c>-1</c> is returned when <paramref name="onlyIfCheap"/> is set.
            /// </summary>
            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                if (onlyIfCheap)
                {
                    return -1;
                }

                var lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false);
                return lookup.Count;
            }
        }

        /// <summary>
        /// Iterator behind the key-only <c>GroupBy</c> overloads. On the first <c>MoveNextCore</c> call
        /// it builds an <c>Internal.Lookup</c> from the source, then yields its groups of source
        /// elements as <see cref="IAsyncGrouping{TKey, TElement}"/>.
        /// </summary>
        internal sealed class GroupedAsyncEnumerable<TSource, TKey> : AsyncIterator<IAsyncGrouping<TKey, TSource>>, IIListProvider<IAsyncGrouping<TKey, TSource>>
        {
            private readonly IAsyncEnumerable<TSource> source;
            private readonly Func<TSource, TKey> keySelector;
            private readonly IEqualityComparer<TKey> comparer;

            // Set on the first MoveNextCore call; released by Dispose.
            private Internal.Lookup<TKey, TSource> lookup;
            private IEnumerator<IGrouping<TKey, TSource>> enumerator;

            public GroupedAsyncEnumerable(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
            {
                Debug.Assert(source != null);
                Debug.Assert(keySelector != null);
                Debug.Assert(comparer != null);

                this.source = source;
                this.keySelector = keySelector;
                this.comparer = comparer;
            }

            public override AsyncIterator<IAsyncGrouping<TKey, TSource>> Clone()
            {
                return new GroupedAsyncEnumerable<TSource, TKey>(source, keySelector, comparer);
            }

            public override void Dispose()
            {
                if (enumerator != null)
                {
                    enumerator.Dispose();
                    enumerator = null;
                    lookup = null;
                }

                base.Dispose();
            }

            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
            {
                switch (state)
                {
                    case AsyncIteratorState.Allocated:
                        // The lookup is fully built from the source before the first group is yielded.
                        lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                        enumerator = lookup.GetEnumerator();
                        state = AsyncIteratorState.Iterating;
                        goto case AsyncIteratorState.Iterating;

                    case AsyncIteratorState.Iterating:
                        if (enumerator.MoveNext())
                        {
                            // The lookup hands out IGrouping; the concrete groups also implement IAsyncGrouping.
                            current = (IAsyncGrouping<TKey, TSource>)enumerator.Current;
                            return true;
                        }

                        // Exhausted: release the enumerator and lookup right away.
                        Dispose();
                        break;
                }

                return false;
            }

            public async Task<IAsyncGrouping<TKey, TSource>[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                // Typed as the interface so the lookup's IIListProvider implementation is used.
                IIListProvider<IAsyncGrouping<TKey, TSource>> lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                return await lookup.ToArrayAsync(cancellationToken).ConfigureAwait(false);
            }

            public async Task<List<IAsyncGrouping<TKey, TSource>>> ToListAsync(CancellationToken cancellationToken)
            {
                // Typed as the interface so the lookup's IIListProvider implementation is used.
                IIListProvider<IAsyncGrouping<TKey, TSource>> lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                return await lookup.ToListAsync(cancellationToken).ConfigureAwait(false);
            }

            /// <summary>
            /// Returns the number of groups. Never cheap: it requires building the lookup from the
            /// source, so <c>-1</c> is returned when <paramref name="onlyIfCheap"/> is set.
            /// </summary>
            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                if (onlyIfCheap)
                {
                    return -1;
                }

                var lookup = await Internal.Lookup<TKey, TSource>.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                return lookup.Count;
            }
        }
    }
}

// Note: The type here has to be internal as System.Linq has its own public copy we're not using
namespace System.Linq.Internal
{
    /// <summary>
    /// A single group: a key plus a growable buffer of elements, exposed both as a read-only
    /// <see cref="IList{T}"/> and as an async sequence.
    /// </summary>
    /// <remarks>
    /// Adapted from System.Linq.Grouping from .NET Framework.
    /// Source: https://github.com/dotnet/corefx/blob/b90532bc97b07234a7d18073819d019645285f1c/src/System.Linq/src/System/Linq/Grouping.cs#L64
    /// </remarks>
    internal class Grouping<TKey, TElement> : IGrouping<TKey, TElement>, IList<TElement>, IAsyncGrouping<TKey, TElement>
    {
        // Number of valid entries at the front of _elements.
        internal int _count;

        // Element buffer; may be longer than _count until Trim() is called.
        internal TElement[] _elements;

        // NOTE(review): _hashCode/_hashNext/_next are not used in this class; presumably maintained by
        // Internal.Lookup for its hash chains and group list — confirm against Lookup.
        internal int _hashCode;
        internal Grouping<TKey, TElement> _hashNext;
        internal TKey _key;
        internal Grouping<TKey, TElement> _next;

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>Yields the first <c>_count</c> buffered elements in insertion order.</summary>
        public IEnumerator<TElement> GetEnumerator()
        {
            for (var i = 0; i < _count; i++)
            {
                yield return _elements[i];
            }
        }

        // DDB195907: implement IGrouping<>.Key implicitly
        // so that WPF binding works on this property.
        public TKey Key
        {
            get { return _key; }
        }

        int ICollection<TElement>.Count
        {
            get { return _count; }
        }

        bool ICollection<TElement>.IsReadOnly
        {
            get { return true; }
        }

        void ICollection<TElement>.Add(TElement item)
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        void ICollection<TElement>.Clear()
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        bool ICollection<TElement>.Contains(TElement item)
        {
            return Array.IndexOf(_elements, item, 0, _count) >= 0;
        }

        void ICollection<TElement>.CopyTo(TElement[] array, int arrayIndex)
        {
            Array.Copy(_elements, 0, array, arrayIndex, _count);
        }

        bool ICollection<TElement>.Remove(TElement item)
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        int IList<TElement>.IndexOf(TElement item)
        {
            return Array.IndexOf(_elements, item, 0, _count);
        }

        void IList<TElement>.Insert(int index, TElement item)
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        void IList<TElement>.RemoveAt(int index)
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        TElement IList<TElement>.this[int index]
        {
            get
            {
                if (index < 0 || index >= _count)
                {
                    throw new ArgumentOutOfRangeException(nameof(index));
                }

                return _elements[index];
            }
            set
            {
                throw new NotSupportedException(Strings.NOT_SUPPORTED);
            }
        }

        /// <summary>
        /// Appends <paramref name="element"/>, doubling the buffer when it is full.
        /// </summary>
        internal void Add(TElement element)
        {
            if (_elements == null || _elements.Length == _count)
            {
                // Doubling an empty buffer would yield zero slots and the store below would throw,
                // so an empty (or missing) buffer grows to one slot instead.
                Array.Resize(ref _elements, _count == 0 ? 1 : checked(_count * 2));
            }

            _elements[_count] = element;
            _count++;
        }

        /// <summary>Shrinks the buffer so its length equals the element count.</summary>
        internal void Trim()
        {
            if (_elements.Length != _count)
            {
                Array.Resize(ref _elements, _count);
            }
        }

        IAsyncEnumerator<TElement> IAsyncEnumerable<TElement>.GetEnumerator()
        {
            return this.ToAsyncEnumerable().GetEnumerator();
        }
    }
}