// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Builds a lookup from an async-enumerable sequence, using the given key selector,
        /// element selector and key comparer. The operation is not cancellable
        /// (<see cref="CancellationToken.None"/> is passed on).
        /// </summary>
        /// <param name="source">Sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="elementSelector">Projects each source element into the stored element.</param>
        /// <param name="comparer">Comparer used to hash and compare keys.</param>
        /// <returns>A task producing the populated lookup.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static Task<ILookup<TKey, TElement>> ToLookup<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return ToLookup(source, keySelector, elementSelector, comparer, CancellationToken.None);
        }

        /// <summary>
        /// Builds a lookup from an async-enumerable sequence, using the given key and element
        /// selectors. Forwards to the cancellable overload with <see cref="CancellationToken.None"/>,
        /// which supplies the default key comparer.
        /// </summary>
        /// <param name="source">Sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="elementSelector">Projects each source element into the stored element.</param>
        /// <returns>A task producing the populated lookup.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static Task<ILookup<TKey, TElement>> ToLookup<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));

            return ToLookup(source, keySelector, elementSelector, CancellationToken.None);
        }

        /// <summary>
        /// Builds a lookup of the source elements themselves, grouped by the given key selector
        /// and key comparer. Forwards to the cancellable overload with
        /// <see cref="CancellationToken.None"/>.
        /// </summary>
        /// <param name="source">Sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="comparer">Comparer used to hash and compare keys.</param>
        /// <returns>A task producing the populated lookup.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static Task<ILookup<TKey, TSource>> ToLookup<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return ToLookup(source, keySelector, comparer, CancellationToken.None);
        }

        /// <summary>
        /// Builds a lookup of the source elements themselves, grouped by the given key selector.
        /// Forwards to the cancellable overload with <see cref="CancellationToken.None"/>.
        /// </summary>
        /// <param name="source">Sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <returns>A task producing the populated lookup.</returns>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public static Task<ILookup<TKey, TSource>> ToLookup<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));

            return ToLookup(source, keySelector, CancellationToken.None);
        }

        /// <summary>
        /// Builds a lookup from an async-enumerable sequence, using the given key selector,
        /// element selector and key comparer. All other <c>ToLookup</c> overloads end up here.
        /// </summary>
        /// <remarks>
        /// Arguments are validated before any asynchronous work starts, so a null argument
        /// throws from the call itself (like the sibling overloads) rather than being
        /// captured in the returned task.
        /// </remarks>
        /// <param name="source">Sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="elementSelector">Projects each source element into the stored element.</param>
        /// <param name="comparer">Comparer used to hash and compare keys.</param>
        /// <param name="cancellationToken">Token passed to every <c>MoveNext</c> call on the source enumerator.</param>
        /// <returns>A task producing the populated lookup.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/>,
        /// <paramref name="elementSelector"/> or <paramref name="comparer"/> is <c>null</c>.
        /// </exception>
        public static Task<ILookup<TKey, TElement>> ToLookup<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return ToLookupCore(source, keySelector, elementSelector, comparer, cancellationToken);
        }

        // Asynchronous part of ToLookup. The await is required because Task<T> is invariant:
        // a Task<Lookup<TKey, TElement>> cannot be returned as a Task<ILookup<TKey, TElement>>.
        private static async Task<ILookup<TKey, TElement>> ToLookupCore<TSource, TKey, TElement>(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            var lookup = await Internal.Lookup<TKey, TElement>.CreateAsync(source, keySelector, elementSelector, comparer, cancellationToken)
                                       .ConfigureAwait(false);

            return lookup;
        }

        /// <summary>
        /// Builds a lookup from an async-enumerable sequence, using the given key and element
        /// selectors and <see cref="EqualityComparer{T}.Default"/> for the keys.
        /// </summary>
        /// <param name="source">Sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="elementSelector">Projects each source element into the stored element.</param>
        /// <param name="cancellationToken">Token forwarded to the underlying enumeration.</param>
        /// <returns>A task producing the populated lookup.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/> or
        /// <paramref name="elementSelector"/> is <c>null</c>.
        /// </exception>
        public static Task<ILookup<TKey, TElement>> ToLookup<TSource, TKey, TElement>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, CancellationToken cancellationToken)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (elementSelector == null)
                throw new ArgumentNullException(nameof(elementSelector));

            return source.ToLookup(keySelector, elementSelector, EqualityComparer<TKey>.Default, cancellationToken);
        }

        /// <summary>
        /// Builds a lookup of the source elements themselves (identity element selector),
        /// grouped by the given key selector and key comparer.
        /// </summary>
        /// <param name="source">Sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="comparer">Comparer used to hash and compare keys.</param>
        /// <param name="cancellationToken">Token forwarded to the underlying enumeration.</param>
        /// <returns>A task producing the populated lookup.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/> or
        /// <paramref name="comparer"/> is <c>null</c>.
        /// </exception>
        public static Task<ILookup<TKey, TSource>> ToLookup<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return source.ToLookup(keySelector, x => x, comparer, cancellationToken);
        }

        /// <summary>
        /// Builds a lookup of the source elements themselves (identity element selector),
        /// grouped by the given key selector with <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <param name="source">Sequence whose elements are grouped.</param>
        /// <param name="keySelector">Extracts the grouping key from each source element.</param>
        /// <param name="cancellationToken">Token forwarded to the underlying enumeration.</param>
        /// <returns>A task producing the populated lookup.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/> or <paramref name="keySelector"/> is <c>null</c>.
        /// </exception>
        public static Task<ILookup<TKey, TSource>> ToLookup<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, CancellationToken cancellationToken)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (keySelector == null)
                throw new ArgumentNullException(nameof(keySelector));

            return source.ToLookup(keySelector, x => x, EqualityComparer<TKey>.Default, cancellationToken);
        }
    }
}

// This is internal because System.Linq exposes a public Lookup that we cannot directly use here
namespace System.Linq.Internal
{
    /// <summary>
    /// Hash-based lookup from keys to groupings of elements.
    /// </summary>
    /// <remarks>
    /// Groupings are stored twice: in hash buckets (<c>_groupings</c>, chained through
    /// <c>_hashNext</c>) for key lookup, and in a circular linked list (chained through
    /// <c>_next</c>, with <c>_lastGrouping._next</c> being the first grouping created) so that
    /// enumeration visits groupings in the order in which they were created.
    /// </remarks>
    internal class Lookup<TKey, TElement> : ILookup<TKey, TElement>, IIListProvider<IAsyncGrouping<TKey, TElement>>, IIListProvider<IGrouping<TKey, TElement>>
    {
        private readonly IEqualityComparer<TKey> _comparer;
        private Grouping<TKey, TElement>[] _groupings;
        private Grouping<TKey, TElement> _lastGrouping;

        // Instances are only created through the static Create* factory methods.
        private Lookup(IEqualityComparer<TKey> comparer)
        {
            _comparer = comparer ?? EqualityComparer<TKey>.Default;
            _groupings = new Grouping<TKey, TElement>[7];
        }

        /// <summary>Number of groupings (distinct keys) in the lookup.</summary>
        public int Count { get; private set; }

        /// <summary>
        /// Gets the elements stored under <paramref name="key"/>, or an empty sequence when the
        /// key is not present.
        /// </summary>
        public IEnumerable<TElement> this[TKey key]
        {
            get
            {
                var grouping = GetGrouping(key, create: false);
                if (grouping != null)
                {
                    return grouping;
                }

#if NO_ARRAY_EMPTY
                return EmptyArray<TElement>.Value;
#else
                return Array.Empty<TElement>();
#endif
            }
        }

        /// <summary>Returns whether a grouping exists for <paramref name="key"/>.</summary>
        public bool Contains(TKey key)
        {
            return GetGrouping(key, create: false) != null;
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>Enumerates the groupings in creation order.</summary>
        public IEnumerator<IGrouping<TKey, TElement>> GetEnumerator()
        {
            var g = _lastGrouping;
            if (g != null)
            {
                // Circular list: start at the first grouping (_lastGrouping._next) and stop
                // once the last grouping has been yielded.
                do
                {
                    g = g._next;
                    yield return g;
                } while (g != _lastGrouping);
            }
        }

        /// <summary>
        /// Lazily applies <paramref name="resultSelector"/> to each grouping's key and elements,
        /// in creation order.
        /// </summary>
        public IEnumerable<TResult> ApplyResultSelector<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector)
        {
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    // NOTE(review): Trim is defined on Grouping (not shown here); presumably it
                    // shrinks _elements to the number of added items - confirm.
                    g.Trim();
                    yield return resultSelector(g._key, g._elements.ToAsyncEnumerable());
                } while (g != _lastGrouping);
            }
        }

        /// <summary>
        /// Enumerates <paramref name="source"/> and builds a lookup, adding
        /// <c>elementSelector(item)</c> to the grouping for <c>keySelector(item)</c>.
        /// </summary>
        internal static async Task<Lookup<TKey, TElement>> CreateAsync<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);
            Debug.Assert(elementSelector != null);

            var lookup = new Lookup<TKey, TElement>(comparer);

            using (var enu = source.GetEnumerator())
            {
                while (await enu.MoveNext(cancellationToken)
                                .ConfigureAwait(false))
                {
                    lookup.GetGrouping(keySelector(enu.Current), create: true)
                          .Add(elementSelector(enu.Current));
                }
            }

            return lookup;
        }

        /// <summary>
        /// Enumerates <paramref name="source"/> and builds a lookup of the source elements
        /// themselves, grouped by <c>keySelector(item)</c>.
        /// </summary>
        internal static async Task<Lookup<TKey, TElement>> CreateAsync(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);

            var lookup = new Lookup<TKey, TElement>(comparer);

            using (var enu = source.GetEnumerator())
            {
                while (await enu.MoveNext(cancellationToken)
                                .ConfigureAwait(false))
                {
                    lookup.GetGrouping(keySelector(enu.Current), create: true)
                          .Add(enu.Current);
                }
            }

            return lookup;
        }

        /// <summary>
        /// Same as the selector-less <c>CreateAsync</c>, except that elements whose key is
        /// <c>null</c> are skipped instead of being grouped.
        /// </summary>
        internal static async Task<Lookup<TKey, TElement>> CreateForJoinAsync(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            var lookup = new Lookup<TKey, TElement>(comparer);

            using (var enu = source.GetEnumerator())
            {
                while (await enu.MoveNext(cancellationToken)
                                .ConfigureAwait(false))
                {
                    var key = keySelector(enu.Current);
                    if (key != null)
                    {
                        lookup.GetGrouping(key, create: true)
                              .Add(enu.Current);
                    }
                }
            }

            return lookup;
        }

        /// <summary>
        /// Finds the grouping for <paramref name="key"/>. When none exists, either creates and
        /// registers a new one (<paramref name="create"/> is <c>true</c>) or returns <c>null</c>.
        /// </summary>
        internal Grouping<TKey, TElement> GetGrouping(TKey key, bool create)
        {
            var hashCode = InternalGetHashCode(key);

            for (var g = _groupings[hashCode%_groupings.Length]; g != null; g = g._hashNext)
            {
                if (g._hashCode == hashCode && _comparer.Equals(g._key, key))
                {
                    return g;
                }
            }

            if (create)
            {
                // Grow the bucket array once there are as many groupings as buckets.
                if (Count == _groupings.Length)
                {
                    Resize();
                }

                var index = hashCode%_groupings.Length;
                var g = new Grouping<TKey, TElement>();
                g._key = key;
                g._hashCode = hashCode;
                g._elements = new TElement[1];

                // Push onto the front of the bucket chain.
                g._hashNext = _groupings[index];
                _groupings[index] = g;

                // Append to the circular creation-order list.
                if (_lastGrouping == null)
                {
                    g._next = g;
                }
                else
                {
                    g._next = _lastGrouping._next;
                    _lastGrouping._next = g;
                }

                _lastGrouping = g;
                Count++;
                return g;
            }

            return null;
        }

        /// <summary>
        /// Computes a non-negative hash code for <paramref name="key"/>; a <c>null</c> key
        /// hashes to 0.
        /// </summary>
        internal int InternalGetHashCode(TKey key)
        {
            // Handle comparer implementations that throw when passed null
            // Masking off the sign bit keeps the later "% _groupings.Length" non-negative.
            return (key == null) ? 0 : _comparer.GetHashCode(key) & 0x7FFFFFFF;
        }

        /// <summary>
        /// Eagerly applies <paramref name="resultSelector"/> to each grouping, in creation
        /// order, and returns the results as an array of length <see cref="Count"/>.
        /// </summary>
        internal TResult[] ToArray<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector)
        {
            var array = new TResult[Count];
            var index = 0;
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    g.Trim();
                    array[index] = resultSelector(g._key, g._elements.ToAsyncEnumerable());
                    ++index;
                } while (g != _lastGrouping);
            }

            return array;
        }

        /// <summary>
        /// Eagerly applies <paramref name="resultSelector"/> to each grouping, in creation
        /// order, and returns the results as a list.
        /// </summary>
        internal List<TResult> ToList<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector)
        {
            var list = new List<TResult>(Count);
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    g.Trim();
                    list.Add(resultSelector(g._key, g._elements.ToAsyncEnumerable()));
                } while (g != _lastGrouping);
            }

            return list;
        }

        // Rebuilds the bucket array with (Count * 2) + 1 buckets and rehashes every grouping.
        // Only called from GetGrouping when Count == _groupings.Length; the bucket array starts
        // at 7 entries, so at least one grouping exists and _lastGrouping is non-null here.
        private void Resize()
        {
            var newSize = checked((Count*2) + 1);
            var newGroupings = new Grouping<TKey, TElement>[newSize];
            var g = _lastGrouping;
            do
            {
                g = g._next;
                var index = g._hashCode%newSize;
                g._hashNext = newGroupings[index];
                newGroupings[index] = g;
            } while (g != _lastGrouping);

            _groupings = newGroupings;
        }

        IAsyncEnumerator<IGrouping<TKey, TElement>> IAsyncEnumerable<IGrouping<TKey, TElement>>.GetEnumerator()
        {
            return this.ToAsyncEnumerable().GetEnumerator();
        }

        /// <summary>Returns the groupings, in creation order, as an already-completed task.</summary>
        public Task<IGrouping<TKey, TElement>[]> ToArrayAsync(CancellationToken cancellationToken)
        {
            var array = new IGrouping<TKey, TElement>[Count];
            var index = 0;
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    array[index] = g;
                    ++index;
                } while (g != _lastGrouping);
            }

            return Task.FromResult(array);
        }

        /// <summary>Returns the groupings, in creation order, as an already-completed task.</summary>
        public Task<List<IGrouping<TKey, TElement>>> ToListAsync(CancellationToken cancellationToken)
        {
            var list = new List<IGrouping<TKey, TElement>>(Count);
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    list.Add(g);
                } while (g != _lastGrouping);
            }

            return Task.FromResult(list);
        }

        /// <summary>
        /// Returns <see cref="Count"/> as an already-completed task; <paramref name="onlyIfCheap"/>
        /// and <paramref name="cancellationToken"/> are not consulted.
        /// </summary>
        public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
        {
            return Task.FromResult(Count);
        }

        IAsyncEnumerator<IAsyncGrouping<TKey, TElement>> IAsyncEnumerable<IAsyncGrouping<TKey, TElement>>.GetEnumerator()
        {
            return Enumerable.Cast<IAsyncGrouping<TKey, TElement>>(this).ToAsyncEnumerable().GetEnumerator();
        }

        Task<List<IAsyncGrouping<TKey, TElement>>> IIListProvider<IAsyncGrouping<TKey, TElement>>.ToListAsync(CancellationToken cancellationToken)
        {
            var list = new List<IAsyncGrouping<TKey, TElement>>(Count);
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    list.Add(g);
                } while (g != _lastGrouping);
            }

            return Task.FromResult(list);
        }

        Task<IAsyncGrouping<TKey, TElement>[]> IIListProvider<IAsyncGrouping<TKey, TElement>>.ToArrayAsync(CancellationToken cancellationToken)
        {
            var array = new IAsyncGrouping<TKey, TElement>[Count];
            var index = 0;
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    array[index] = g;
                    ++index;
                } while (g != _lastGrouping);
            }

            return Task.FromResult(array);
        }
    }
}