// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

// NOTE(review): in this copy every generic type-argument list appears to have been stripped
// (e.g. "IAsyncEnumerable> GroupBy(", "new Dictionary>(comparer)", "Func, TResult>"),
// so the file does not compile as-is. The code tokens below are kept exactly as found;
// restore the type arguments from the upstream source before building.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Groups the elements of an async sequence by key, projecting each element with
        /// elementSelector and comparing keys with comparer.
        /// Groups are streamed rather than buffered: the outer enumerator yields a group as soon
        /// as its first element has been pulled from the source. Each group's own enumerator
        /// pulls further source elements on demand through the shared iterateSource step.
        /// </summary>
        /// <param name="source">Sequence to group.</param>
        /// <param name="keySelector">Extracts the grouping key from a source element.</param>
        /// <param name="elementSelector">Maps a source element to the value stored in its group.</param>
        /// <param name="comparer">Key equality comparer used for the key-to-group dictionary.</param>
        /// <returns>An async sequence of groups, in order of first key occurrence.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable> GroupBy(this IAsyncEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            // NOTE(review): presumably CreateEnumerable runs this factory once per
            // GetEnumerator call, so all state below is per-enumeration — confirm.
            return CreateEnumerable(() =>
            {
                var gate = new object();

                var e = source.GetEnumerator();

                // Reference count on the source enumerator: 1 for the outer enumerator, plus one
                // for every group created in iterateSource. `e` is disposed when it reaches zero.
                var count = 1;

                // Key -> group, used to route each source element to its group.
                var map = new Dictionary>(comparer);

                // Groups in discovery order. Guarded by lock (list) because it is appended to from
                // iterateSource (which group enumerators also call) and read by the outer enumerator.
                var list = new List>();

                // Position in `list` of the next group the outer enumerator has not yet yielded.
                var index = 0;

                var current = default(IAsyncGrouping);

                // Set once the source (or a selector) has thrown; later steps rethrow it.
                var faulted = default(ExceptionDispatchInfo);

                // MoveNext result of the step in progress. It is non-null only between the
                // await in iterateSource completing and that step's finally block.
                var res = default(bool?);

                // NOTE(review): cts is only disposed (via `d`); its token is never used by this
                // enumerator.
                var cts = new CancellationTokenDisposable();

                // NOTE(review): this one instance is shared by the outer enumerator and every group.
                // If Disposable runs its action only once per instance, only the first Dispose
                // decrements `count`, and `e` is never disposed once any group exists — verify.
                var refCount = new Disposable(
                    () =>
                    {
                        if (Interlocked.Decrement(ref count) == 0)
                            e.Dispose();
                    }
                );
                var d = Disposable.Create(cts, refCount);

                // Single source step, shared by the outer enumerator and every group enumerator.
                // It advances `e` once; on success it computes key and element and appends the
                // element to its group. On the first sight of a key it creates that group and
                // publishes it in `list`. Returns the MoveNext result.
                var iterateSource = default(Func>);
                iterateSource = async ct =>
                {
                    lock (gate)
                    {
                        // Another caller is in the middle of a step: reuse its MoveNext result.
                        // NOTE(review): this only covers the window after that caller's await has
                        // completed; two callers can still await e.MoveNext concurrently. The
                        // `res = null` below is redundant (res is already null at that point).
                        if (res != null)
                        {
                            return res.Value;
                        }
                        res = null;
                    }

                    faulted?.Throw();

                    try
                    {
                        res = await e.MoveNext(ct)
                                     .ConfigureAwait(false);
                        if (res == true)
                        {
                            var key = default(TKey);
                            var element = default(TElement);

                            var cur = e.Current;
                            try
                            {
                                key = keySelector(cur);
                                element = elementSelector(cur);
                            }
                            catch (Exception exception)
                            {
                                // Fault every existing group, then let the outer catch record
                                // the error (which calls Error on the groups a second time).
                                foreach (var v in map.Values)
                                    v.Error(exception);
                                throw;
                            }

                            var group = default(AsyncGrouping);
                            if (!map.TryGetValue(key, out group))
                            {
                                // First element with this key: create the group, publish it to
                                // the outer enumerator, and take a source reference for it.
                                group = new AsyncGrouping(key, iterateSource, refCount);
                                map.Add(key, group);
                                lock (list)
                                    list.Add(group);

                                Interlocked.Increment(ref count);
                            }
                            group.Add(element);
                        }

                        return res.Value;
                    }
                    catch (Exception ex)
                    {
                        // Fault all groups and remember the error so every later step rethrows it.
                        foreach (var v in map.Values)
                            v.Error(ex);

                        faulted = ExceptionDispatchInfo.Capture(ex);
                        throw;
                    }
                    finally
                    {
                        res = null;
                    }
                };

                // Outer MoveNext. It always steps the source first, even when an already-discovered
                // group is waiting in `list`. It then yields the next not-yet-yielded group if there
                // is one. Otherwise it keeps stepping until a new group appears or the source ends.
                var f = default(Func>);
                f = async ct =>
                {
                    var result = await iterateSource(ct)
                                     .ConfigureAwait(false);

                    current = null;
                    lock (list)
                    {
                        if (index < list.Count)
                            current = list[index++];
                    }

                    if (current != null)
                    {
                        return true;
                    }

                    // The source advanced but the element went to an existing group: try again.
                    return result && await f(ct)
                               .ConfigureAwait(false);
                };

                // NOTE(review): argument roles are defined by CreateEnumerator (not shown here).
                // Presumably they are MoveNext, Current, Dispose, and the underlying source
                // enumerator — confirm.
                return CreateEnumerator(
                    f,
                    () => current,
                    d.Dispose,
                    e
                );
            });
        }

        /// <summary>
        /// Groups by key and projects elements, comparing keys with EqualityComparer.Default.
        /// Delegates to the streaming elementSelector/comparer overload.
        /// </summary>
        /// <exception cref="ArgumentNullException">source, keySelector or elementSelector is null.</exception>
        public static IAsyncEnumerable> GroupBy(this IAsyncEnumerable source, Func keySelector, Func elementSelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));

            return source.GroupBy(keySelector, elementSelector, EqualityComparer.Default);
        }

        /// <summary>
        /// Groups the source elements themselves by key, using comparer.
        /// Unlike the elementSelector overloads, this returns a GroupedAsyncEnumerable. That type
        /// awaits Internal.Lookup.CreateAsync over the source before yielding any group
        /// (presumably consuming the whole source).
        /// </summary>
        /// <exception cref="ArgumentNullException">source, keySelector or comparer is null.</exception>
        public static IAsyncEnumerable> GroupBy(this IAsyncEnumerable source, Func keySelector, IEqualityComparer comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return new GroupedAsyncEnumerable(source, keySelector, comparer);
        }

        /// <summary>
        /// Groups the source elements themselves by key, using EqualityComparer.Default.
        /// Like the comparer overload, this uses the buffered GroupedAsyncEnumerable.
        /// </summary>
        /// <exception cref="ArgumentNullException">source or keySelector is null.</exception>
        public static IAsyncEnumerable> GroupBy(this IAsyncEnumerable source, Func keySelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));

            return new GroupedAsyncEnumerable(source, keySelector, EqualityComparer.Default);
        }

        /// <summary>
        /// Groups by key with an element projection and comparer, then maps each group to
        /// resultSelector(group key, group).
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable GroupBy(this IAsyncEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return source.GroupBy(keySelector, elementSelector, comparer)
                         .Select(g => resultSelector(g.Key, g));
        }

        /// <summary>
        /// Groups by key with an element projection and EqualityComparer.Default, then maps
        /// each group to resultSelector(group key, group).
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable GroupBy(this IAsyncEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));

            return source.GroupBy(keySelector, elementSelector, EqualityComparer.Default)
                         .Select(g => resultSelector(g.Key, g));
        }

        /// <summary>
        /// Groups the source elements by key using comparer, then maps each group to
        /// resultSelector(group key, group).
        /// Note: this passes the identity element selector (x => x), so it goes through the
        /// streaming overload rather than the buffered GroupedAsyncEnumerable.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable GroupBy(this IAsyncEnumerable source, Func keySelector, Func, TResult> resultSelector, IEqualityComparer comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return source.GroupBy(keySelector, x => x, comparer)
                         .Select(g => resultSelector(g.Key, g));
        }

        /// <summary>
        /// Groups the source elements by key using EqualityComparer.Default, then maps each
        /// group to resultSelector(group key, group).
        /// Like the comparer variant, this uses the identity element selector and therefore the
        /// streaming overload.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable GroupBy(this IAsyncEnumerable source, Func keySelector, Func, TResult> resultSelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));

            return source.GroupBy(keySelector, x => x, EqualityComparer.Default)
                         .Select(g => resultSelector(g.Key, g));
        }

        // Synchronous helper: groups runs of *consecutive* elements whose keys compare equal
        // (comparer.Compare(...) == 0). A key that reappears after a different key starts a new
        // group. Each group is yielded when the first element of its run is seen. Later elements
        // of the same run are added to that group as iteration continues.
        // NOTE(review): not referenced in this section; presumably used from another part of
        // this partial class.
        private static IEnumerable> GroupUntil(this IEnumerable source, Func keySelector, Func elementSelector, IComparer comparer)
        {
            var group = default(EnumerableGrouping);
            foreach (var x in source)
            {
                var key = keySelector(x);
                if (group == null || comparer.Compare(group.Key, key) != 0)
                {
                    group = new EnumerableGrouping(key);
                    yield return group;
                }
                group.Add(elementSelector(x));
            }
        }

        // Buffered GroupBy: materializes the source into an Internal.Lookup and then enumerates
        // its groups synchronously. It implements IIListProvider, so array/list/count requests can
        // be answered from a lookup directly.
        internal sealed class GroupedAsyncEnumerable : IIListProvider>
        {
            private readonly IAsyncEnumerable source;
            private readonly Func keySelector;
            private readonly IEqualityComparer comparer;

            // comparer is not null-checked here; it is passed to Lookup.CreateAsync as-is
            // (NOTE(review): presumably null means the default comparer there — confirm).
            public GroupedAsyncEnumerable(IAsyncEnumerable source, Func keySelector, IEqualityComparer comparer)
            {
                if (source == null)
                    throw new ArgumentNullException(nameof(source));
                if (keySelector == null)
                    throw new ArgumentNullException(nameof(keySelector));

                this.source = source;
                this.keySelector = keySelector;
                this.comparer = comparer;
            }

            // The lookup is built on the first MoveNext and reused by later calls.
            // Once it exists, a cancelled token makes MoveNext return false instead of throwing.
            // Dispose disposes and clears the inner enumerator. After that, Current is null, and
            // MoveNext returns false if the lookup had already been built.
            public IAsyncEnumerator> GetEnumerator()
            {
                Internal.Lookup lookup = null;
                IEnumerator> enumerator = null;

                return CreateEnumerator(
                    async ct =>
                    {
                        if (lookup == null)
                        {
                            lookup = await Internal.Lookup.CreateAsync(source, keySelector, comparer, ct).ConfigureAwait(false);
                            enumerator = lookup.GetEnumerator();
                        }

                        // By the time we get here, the lookup is sync
                        if (ct.IsCancellationRequested)
                            return false;

                        return enumerator?.MoveNext() ?? false;
                    },
                    () => (IAsyncGrouping)enumerator?.Current,
                    () =>
                    {
                        if (enumerator != null)
                        {
                            enumerator.Dispose();
                            enumerator = null;
                        }
                    });
            }

            // Builds a fresh lookup (independent of any enumerator) and delegates to its
            // IIListProvider implementation.
            public async Task[]> ToArrayAsync(CancellationToken cancellationToken)
            {
                IIListProvider> lookup = await Internal.Lookup.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                return await lookup.ToArrayAsync(cancellationToken).ConfigureAwait(false);
            }

            // Builds a fresh lookup and delegates to its IIListProvider implementation.
            public async Task>> ToListAsync(CancellationToken cancellationToken)
            {
                IIListProvider> lookup = await Internal.Lookup.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);
                return await lookup.ToListAsync(cancellationToken).ConfigureAwait(false);
            }

            // Counting requires building the lookup, so with onlyIfCheap this reports -1.
            // Otherwise it returns lookup.Count (presumably the number of groups).
            public async Task GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
            {
                if (onlyIfCheap)
                {
                    return -1;
                }

                var lookup = await Internal.Lookup.CreateAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false);

                return lookup.Count;
            }
        }

        // Group produced by the streaming GroupBy. iterateSource appends elements as they are read
        // from the source. Enumerating the group first returns what is already buffered, then pulls
        // more source elements through iterateSource.
        private class AsyncGrouping : IAsyncGrouping
        {
            // Buffered elements; Add and the Count read in MoveNext take lock (elements).
            private readonly List elements = new List();
            private readonly Func> iterateSource;

            // Shared source reference (the ref-counting disposable), released when an
            // enumerator of this group is disposed.
            private readonly IDisposable sourceDisposable;

            // Set only by Error; normal completion is detected by iterateSource returning false.
            private bool done;
            private ExceptionDispatchInfo exception;

            public AsyncGrouping(TKey key, Func> iterateSource, IDisposable sourceDisposable)
            {
                this.iterateSource = iterateSource;
                this.sourceDisposable = sourceDisposable;
                Key = key;
            }

            public TKey Key { get; }

            // Each enumerator has its own cursor (`index`), so every enumeration starts at the
            // group's first element. Disposing the enumerator disposes its cts and sourceDisposable.
            public IAsyncEnumerator GetEnumerator()
            {
                var index = -1;

                var cts = new CancellationTokenDisposable();
                var d = Disposable.Create(cts, sourceDisposable);

                // Returns true when the cursor is on a buffered element. If the group has faulted,
                // it rethrows the recorded exception. Otherwise it steps the source and re-checks,
                // since that element may have gone to another group. It returns false once the
                // source is exhausted.
                var f = default(Func>);
                f = async ct =>
                {
                    var size = 0;
                    lock (elements)
                        size = elements.Count;

                    if (index < size)
                    {
                        return true;
                    }
                    if (done)
                    {
                        exception?.Throw();
                        return false;
                    }
                    if (await iterateSource(ct)
                            .ConfigureAwait(false))
                    {
                        return await f(ct)
                                   .ConfigureAwait(false);
                    }
                    return false;
                };

                // NOTE(review): MoveNext passes cts.Token, not the `ct` it was given, so the
                // caller's token is ignored. Current reads elements[index] without lock (elements)
                // while iterateSource may be appending from another caller — confirm this is safe.
                return CreateEnumerator(
                    ct =>
                    {
                        ++index;
                        return f(cts.Token);
                    },
                    () => elements[index],
                    d.Dispose,
                    null
                );
            }

            // Appends an element; called from iterateSource.
            public void Add(TElement element)
            {
                lock (elements)
                    elements.Add(element);
            }

            // Marks the group as faulted. Enumerators still return elements already buffered, then
            // rethrow the captured exception instead of stepping the source.
            public void Error(Exception exception)
            {
                done = true;
                this.exception = ExceptionDispatchInfo.Capture(exception);
            }
        }
    }
}

// Note: The type here has to be internal as System.Linq has its own public copy we're not using

namespace System.Linq.Internal
{
    /// Adapted from System.Linq.Grouping from .NET Framework
    /// Source: https://github.com/dotnet/corefx/blob/b90532bc97b07234a7d18073819d019645285f1c/src/System.Linq/src/System/Linq/Grouping.cs#L64
    ///
    /// A key plus its elements, stored in a growable array (the first _count slots of _elements).
    /// The collection interfaces expose the group read-only.
    internal class Grouping : IGrouping, IList, IAsyncGrouping
    {
        internal int _count;
        internal TElement[] _elements;

        // _hashCode, _hashNext and _next are not used in this class. Presumably Internal.Lookup
        // maintains them for its hash buckets and group ordering — confirm.
        internal int _hashCode;
        internal Grouping _hashNext;
        internal TKey _key;
        internal Grouping _next;

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        // Yields only the first _count slots; spare array capacity is never exposed.
        public IEnumerator GetEnumerator()
        {
            for (var i = 0; i < _count; i++)
            {
                yield return _elements[i];
            }
        }

        // DDB195907: implement IGrouping<>.Key implicitly
        // so that WPF binding works on this property.
        public TKey Key
        {
            get { return _key; }
        }

        int ICollection.Count
        {
            get { return _count; }
        }

        bool ICollection.IsReadOnly
        {
            get { return true; }
        }

        void ICollection.Add(TElement item)
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        void ICollection.Clear()
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        // Searches only the populated prefix of the backing array.
        bool ICollection.Contains(TElement item)
        {
            return Array.IndexOf(_elements, item, 0, _count) >= 0;
        }

        void ICollection.CopyTo(TElement[] array, int arrayIndex)
        {
            Array.Copy(_elements, 0, array, arrayIndex, _count);
        }

        bool ICollection.Remove(TElement item)
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        int IList.IndexOf(TElement item)
        {
            return Array.IndexOf(_elements, item, 0, _count);
        }

        void IList.Insert(int index, TElement item)
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        void IList.RemoveAt(int index)
        {
            throw new NotSupportedException(Strings.NOT_SUPPORTED);
        }

        // Range-checked against _count (not the array length) so spare capacity is unreachable.
        TElement IList.this[int index]
        {
            get
            {
                if (index < 0 || index >= _count)
                {
                    throw new ArgumentOutOfRangeException(nameof(index));
                }

                return _elements[index];
            }
            set { throw new NotSupportedException(Strings.NOT_SUPPORTED); }
        }

        // Appends an element, doubling the backing array when it is full.
        // NOTE(review): this assumes the creator allocated a non-empty _elements. With a
        // zero-length array, checked(_count*2) is 0, so the store below would throw
        // IndexOutOfRangeException.
        internal void Add(TElement element)
        {
            if (_elements.Length == _count)
            {
                Array.Resize(ref _elements, checked(_count*2));
            }

            _elements[_count] = element;
            _count++;
        }

        // Shrinks the backing array to exactly _count elements.
        internal void Trim()
        {
            if (_elements.Length != _count)
            {
                Array.Resize(ref _elements, _count);
            }
        }

        // Exposes the buffered elements as an async sequence via AsyncEnumerableAdapter
        // (defined elsewhere in AsyncEnumerable).
        IAsyncEnumerator IAsyncEnumerable.GetEnumerator()
        {
            var adapter = new AsyncEnumerable.AsyncEnumerableAdapter(this);
            return adapter.GetEnumerator();
        }
    }
}