// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

// This is internal because System.Linq exposes a public Lookup that we cannot directly use here.

namespace System.Linq.Internal
{
    /// <summary>
    /// Lookup from keys to groupings of elements, built from an async sequence using
    /// synchronous key and element selectors.
    /// </summary>
    /// <remarks>
    /// Each grouping is reachable two ways:
    /// <list type="bullet">
    /// <item>through the hash buckets in <c>_groupings</c> (chained via <c>_hashNext</c>) for key lookup;</item>
    /// <item>through a circular singly-linked list (chained via <c>_next</c>) that preserves the order in
    /// which keys were first seen. <c>_lastGrouping</c> is the tail and <c>_lastGrouping._next</c> is the head.</item>
    /// </list>
    /// </remarks>
    internal class Lookup<TKey, TElement> : ILookup<TKey, TElement>, IAsyncIListProvider<IAsyncGrouping<TKey, TElement>>
    {
        private readonly IEqualityComparer<TKey> _comparer;

        // Hash buckets; each bucket is a chain of groupings linked through Grouping._hashNext.
        private Grouping<TKey, TElement>[] _groupings;

        // Tail of the circular insertion-order list; null while the lookup is empty.
        private Grouping<TKey, TElement> _lastGrouping;

        private Lookup(IEqualityComparer<TKey> comparer)
        {
            _comparer = comparer ?? EqualityComparer<TKey>.Default;
            _groupings = new Grouping<TKey, TElement>[7];
        }

        /// <summary>Gets the number of distinct keys (groupings) in the lookup.</summary>
        public int Count { get; private set; }

        /// <summary>
        /// Gets the elements associated with <paramref name="key"/>, or an empty sequence when the key is absent.
        /// </summary>
        public IEnumerable<TElement> this[TKey key]
        {
            get
            {
                var grouping = GetGrouping(key, create: false);
                if (grouping != null)
                {
                    return grouping;
                }

#if NO_ARRAY_EMPTY
                return EmptyArray<TElement>.Value;
#else
                return Array.Empty<TElement>();
#endif
            }
        }

        /// <summary>Returns whether a grouping exists for <paramref name="key"/>.</summary>
        public bool Contains(TKey key)
        {
            return GetGrouping(key, create: false) != null;
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>Enumerates the groupings in the order their keys were first added.</summary>
        public IEnumerator<IGrouping<TKey, TElement>> GetEnumerator()
        {
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    // Start at the head (tail._next) and stop after yielding the tail.
                    g = g._next;
                    yield return g;
                } while (g != _lastGrouping);
            }
        }

        /// <summary>
        /// Lazily projects each grouping, in insertion order, through <paramref name="resultSelector"/>.
        /// The selector receives the key and the grouping's elements as an async sequence.
        /// </summary>
        public IEnumerable<TResult> ApplyResultSelector<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector)
        {
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;

                    // NOTE(review): Trim presumably shrinks the backing array to the element count so the
                    // exposed array carries no unused slots - confirm against Grouping.Trim.
                    g.Trim();

                    var result = resultSelector(g._key, g._elements.ToAsyncEnumerable());
                    yield return result;
                } while (g != _lastGrouping);
            }
        }

        /// <summary>
        /// Builds a lookup by enumerating <paramref name="source"/> and grouping
        /// <paramref name="elementSelector"/>(item) under <paramref name="keySelector"/>(item).
        /// </summary>
        /// <param name="comparer">Key comparer; <c>null</c> selects <see cref="EqualityComparer{T}.Default"/>.</param>
        /// <param name="cancellationToken">Flowed into the enumeration of <paramref name="source"/>.</param>
        internal static async Task<Lookup<TKey, TElement>> CreateAsync<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);
            Debug.Assert(elementSelector != null);

            var lookup = new Lookup<TKey, TElement>(comparer);

            await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = keySelector(item);
                var group = lookup.GetGrouping(key, create: true);

                var element = elementSelector(item);
                group.Add(element);
            }

            return lookup;
        }

        /// <summary>
        /// Builds a lookup by enumerating <paramref name="source"/> and grouping each item under
        /// <paramref name="keySelector"/>(item).
        /// </summary>
        /// <param name="comparer">Key comparer; <c>null</c> selects <see cref="EqualityComparer{T}.Default"/>.</param>
        /// <param name="cancellationToken">Flowed into the enumeration of <paramref name="source"/>.</param>
        internal static async Task<Lookup<TKey, TElement>> CreateAsync(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);

            var lookup = new Lookup<TKey, TElement>(comparer);

            await foreach (TElement item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = keySelector(item);
                lookup.GetGrouping(key, create: true).Add(item);
            }

            return lookup;
        }

        /// <summary>
        /// Like <see cref="CreateAsync(IAsyncEnumerable{TElement}, Func{TElement, TKey}, IEqualityComparer{TKey}, CancellationToken)"/>,
        /// but items whose key is <c>null</c> are not added (presumably so they never match in join operators).
        /// </summary>
        internal static async Task<Lookup<TKey, TElement>> CreateForJoinAsync(IAsyncEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);

            var lookup = new Lookup<TKey, TElement>(comparer);

            await foreach (TElement item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = keySelector(item);
                if (key != null)
                {
                    lookup.GetGrouping(key, create: true).Add(item);
                }
            }

            return lookup;
        }

        /// <summary>
        /// Finds the grouping for <paramref name="key"/>. When it is missing and <paramref name="create"/>
        /// is true, a new empty grouping is inserted (into its hash bucket and at the tail of the
        /// insertion-order list) and returned; otherwise returns <c>null</c>.
        /// </summary>
        internal Grouping<TKey, TElement> GetGrouping(TKey key, bool create)
        {
            var hashCode = InternalGetHashCode(key);

            for (var g = _groupings[hashCode % _groupings.Length]; g != null; g = g._hashNext)
            {
                if (g._hashCode == hashCode && _comparer.Equals(g._key, key))
                {
                    return g;
                }
            }

            if (create)
            {
                // Grow once the number of groupings reaches the bucket count.
                if (Count == _groupings.Length)
                {
                    Resize();
                }

                var index = hashCode % _groupings.Length;
                var g = new Grouping<TKey, TElement>
                {
                    _key = key,
                    _hashCode = hashCode,
                    _elements = new TElement[1],
                    _hashNext = _groupings[index]
                };
                _groupings[index] = g;

                // Append to the circular list: the new node becomes the tail and points at the head.
                if (_lastGrouping == null)
                {
                    g._next = g;
                }
                else
                {
                    g._next = _lastGrouping._next;
                    _lastGrouping._next = g;
                }

                _lastGrouping = g;
                Count++;
                return g;
            }

            return null;
        }

        /// <summary>
        /// Returns a non-negative hash for <paramref name="key"/>; <c>null</c> keys hash to 0.
        /// </summary>
        internal int InternalGetHashCode(TKey key)
        {
            // Handle comparer implementations that throw when passed null
            return (key == null) ? 0 : _comparer.GetHashCode(key) & 0x7FFFFFFF;
        }

        /// <summary>
        /// Projects every grouping, in insertion order, through <paramref name="resultSelector"/> into an array.
        /// </summary>
        internal TResult[] ToArray<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector)
        {
            var array = new TResult[Count];
            var index = 0;
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    g.Trim();
                    array[index] = resultSelector(g._key, g._elements.ToAsyncEnumerable());
                    ++index;
                } while (g != _lastGrouping);
            }

            return array;
        }

        /// <summary>
        /// Projects every grouping, in insertion order, through <paramref name="resultSelector"/> into a list.
        /// </summary>
        internal List<TResult> ToList<TResult>(Func<TKey, IAsyncEnumerable<TElement>, TResult> resultSelector)
        {
            var list = new List<TResult>(Count);
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    g.Trim();
                    var result = resultSelector(g._key, g._elements.ToAsyncEnumerable());
                    list.Add(result);
                } while (g != _lastGrouping);
            }

            return list;
        }

        /// <summary>
        /// Rehashes every grouping into a bucket array of size <c>2 * Count + 1</c>.
        /// Only called from <see cref="GetGrouping"/> when <c>Count == _groupings.Length</c> (so the lookup is non-empty).
        /// </summary>
        private void Resize()
        {
            var newSize = checked((Count * 2) + 1);
            var newGroupings = new Grouping<TKey, TElement>[newSize];
            var g = _lastGrouping;
            do
            {
                g = g._next;
                var index = g._hashCode % newSize;
                g._hashNext = newGroupings[index];
                newGroupings[index] = g;
            } while (g != _lastGrouping);

            _groupings = newGroupings;
        }

        /// <summary>Returns <see cref="Count"/> synchronously; the count is always cheap to obtain.</summary>
        public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
        {
            return new ValueTask<int>(Count);
        }

        IAsyncEnumerator<IAsyncGrouping<TKey, TElement>> IAsyncEnumerable<IAsyncGrouping<TKey, TElement>>.GetAsyncEnumerator(CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested(); // NB: [LDM-2018-11-28] Equivalent to async iterator behavior.

            // Enumerable.Cast is spelled out explicitly because this type is both IEnumerable and IAsyncEnumerable.
            return Enumerable.Cast<IAsyncGrouping<TKey, TElement>>(this).ToAsyncEnumerable().GetAsyncEnumerator(cancellationToken);
        }

        ValueTask<List<IAsyncGrouping<TKey, TElement>>> IAsyncIListProvider<IAsyncGrouping<TKey, TElement>>.ToListAsync(CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var list = new List<IAsyncGrouping<TKey, TElement>>(Count);
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    list.Add(g);
                } while (g != _lastGrouping);
            }

            return new ValueTask<List<IAsyncGrouping<TKey, TElement>>>(list);
        }

        ValueTask<IAsyncGrouping<TKey, TElement>[]> IAsyncIListProvider<IAsyncGrouping<TKey, TElement>>.ToArrayAsync(CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var array = new IAsyncGrouping<TKey, TElement>[Count];
            var index = 0;
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    array[index] = g;
                    ++index;
                } while (g != _lastGrouping);
            }

            return new ValueTask<IAsyncGrouping<TKey, TElement>[]>(array);
        }
    }

    /// <summary>
    /// Variant of <see cref="Lookup{TKey, TElement}"/> whose factories accept asynchronous
    /// (<see cref="ValueTask{TResult}"/>-returning) key, element and result selectors, optionally
    /// receiving the cancellation token ("deep cancellation").
    /// </summary>
    /// <remarks>
    /// Uses the same storage layout as <see cref="Lookup{TKey, TElement}"/>: hash buckets in <c>_groupings</c>
    /// chained via <c>_hashNext</c>, plus a circular insertion-order list via <c>_next</c> whose tail is
    /// <c>_lastGrouping</c>.
    /// </remarks>
    internal class LookupWithTask<TKey, TElement> : ILookup<TKey, TElement>, IAsyncIListProvider<IAsyncGrouping<TKey, TElement>>
    {
        private readonly IEqualityComparer<TKey> _comparer;

        // Hash buckets; each bucket is a chain of groupings linked through Grouping._hashNext.
        private Grouping<TKey, TElement>[] _groupings;

        // Tail of the circular insertion-order list; null while the lookup is empty.
        private Grouping<TKey, TElement> _lastGrouping;

        private LookupWithTask(IEqualityComparer<TKey> comparer)
        {
            _comparer = comparer ?? EqualityComparer<TKey>.Default;
            _groupings = new Grouping<TKey, TElement>[7];
        }

        /// <summary>Gets the number of distinct keys (groupings) in the lookup.</summary>
        public int Count { get; private set; }

        /// <summary>
        /// Gets the elements associated with <paramref name="key"/>, or an empty sequence when the key is absent.
        /// </summary>
        public IEnumerable<TElement> this[TKey key]
        {
            get
            {
                var grouping = GetGrouping(key, create: false);
                if (grouping != null)
                {
                    return grouping;
                }

#if NO_ARRAY_EMPTY
                return EmptyArray<TElement>.Value;
#else
                return Array.Empty<TElement>();
#endif
            }
        }

        /// <summary>Returns whether a grouping exists for <paramref name="key"/>.</summary>
        public bool Contains(TKey key)
        {
            return GetGrouping(key, create: false) != null;
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>Enumerates the groupings in the order their keys were first added.</summary>
        public IEnumerator<IGrouping<TKey, TElement>> GetEnumerator()
        {
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    yield return g;
                } while (g != _lastGrouping);
            }
        }

        /// <summary>
        /// Builds a lookup by enumerating <paramref name="source"/>, awaiting the key and element
        /// selectors for each item in turn.
        /// </summary>
        internal static async Task<LookupWithTask<TKey, TElement>> CreateAsync<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<TKey>> keySelector, Func<TSource, ValueTask<TElement>> elementSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);
            Debug.Assert(elementSelector != null);

            var lookup = new LookupWithTask<TKey, TElement>(comparer);

            await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = await keySelector(item).ConfigureAwait(false);
                var group = lookup.GetGrouping(key, create: true);

                var element = await elementSelector(item).ConfigureAwait(false);
                group.Add(element);
            }

            return lookup;
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Same as the non-cancellable overload, but passes <paramref name="cancellationToken"/> to both selectors.
        /// </summary>
        internal static async Task<LookupWithTask<TKey, TElement>> CreateAsync<TSource>(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<TKey>> keySelector, Func<TSource, CancellationToken, ValueTask<TElement>> elementSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);
            Debug.Assert(elementSelector != null);

            var lookup = new LookupWithTask<TKey, TElement>(comparer);

            await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = await keySelector(item, cancellationToken).ConfigureAwait(false);
                var group = lookup.GetGrouping(key, create: true);

                var element = await elementSelector(item, cancellationToken).ConfigureAwait(false);
                group.Add(element);
            }

            return lookup;
        }
#endif

        /// <summary>
        /// Builds a lookup by enumerating <paramref name="source"/> and grouping each item under its awaited key.
        /// </summary>
        internal static async Task<LookupWithTask<TKey, TElement>> CreateAsync(IAsyncEnumerable<TElement> source, Func<TElement, ValueTask<TKey>> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);

            var lookup = new LookupWithTask<TKey, TElement>(comparer);

            await foreach (TElement item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = await keySelector(item).ConfigureAwait(false);
                lookup.GetGrouping(key, create: true).Add(item);
            }

            return lookup;
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Same as the non-cancellable overload, but passes <paramref name="cancellationToken"/> to the key selector.
        /// </summary>
        internal static async Task<LookupWithTask<TKey, TElement>> CreateAsync(IAsyncEnumerable<TElement> source, Func<TElement, CancellationToken, ValueTask<TKey>> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);

            var lookup = new LookupWithTask<TKey, TElement>(comparer);

            await foreach (TElement item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = await keySelector(item, cancellationToken).ConfigureAwait(false);
                lookup.GetGrouping(key, create: true).Add(item);
            }

            return lookup;
        }
#endif

        /// <summary>
        /// Builds a lookup from <paramref name="source"/>, skipping items whose awaited key is <c>null</c>
        /// (presumably so they never match in join operators).
        /// </summary>
        internal static async Task<LookupWithTask<TKey, TElement>> CreateForJoinAsync(IAsyncEnumerable<TElement> source, Func<TElement, ValueTask<TKey>> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);

            var lookup = new LookupWithTask<TKey, TElement>(comparer);

            await foreach (TElement item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = await keySelector(item).ConfigureAwait(false);
                if (key != null)
                {
                    lookup.GetGrouping(key, create: true).Add(item);
                }
            }

            return lookup;
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Same as the non-cancellable join overload, but passes <paramref name="cancellationToken"/> to the key selector.
        /// </summary>
        internal static async Task<LookupWithTask<TKey, TElement>> CreateForJoinAsync(IAsyncEnumerable<TElement> source, Func<TElement, CancellationToken, ValueTask<TKey>> keySelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
        {
            Debug.Assert(source != null);
            Debug.Assert(keySelector != null);

            var lookup = new LookupWithTask<TKey, TElement>(comparer);

            await foreach (TElement item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var key = await keySelector(item, cancellationToken).ConfigureAwait(false);
                if (key != null)
                {
                    lookup.GetGrouping(key, create: true).Add(item);
                }
            }

            return lookup;
        }
#endif

        /// <summary>
        /// Finds the grouping for <paramref name="key"/>. When it is missing and <paramref name="create"/>
        /// is true, a new empty grouping is inserted and returned; otherwise returns <c>null</c>.
        /// </summary>
        internal Grouping<TKey, TElement> GetGrouping(TKey key, bool create)
        {
            var hashCode = InternalGetHashCode(key);

            for (var g = _groupings[hashCode % _groupings.Length]; g != null; g = g._hashNext)
            {
                if (g._hashCode == hashCode && _comparer.Equals(g._key, key))
                {
                    return g;
                }
            }

            if (create)
            {
                // Grow once the number of groupings reaches the bucket count.
                if (Count == _groupings.Length)
                {
                    Resize();
                }

                var index = hashCode % _groupings.Length;
                var g = new Grouping<TKey, TElement>
                {
                    _key = key,
                    _hashCode = hashCode,
                    _elements = new TElement[1],
                    _hashNext = _groupings[index]
                };
                _groupings[index] = g;

                // Append to the circular list: the new node becomes the tail and points at the head.
                if (_lastGrouping == null)
                {
                    g._next = g;
                }
                else
                {
                    g._next = _lastGrouping._next;
                    _lastGrouping._next = g;
                }

                _lastGrouping = g;
                Count++;
                return g;
            }

            return null;
        }

        /// <summary>
        /// Returns a non-negative hash for <paramref name="key"/>; <c>null</c> keys hash to 0.
        /// </summary>
        internal int InternalGetHashCode(TKey key)
        {
            // Handle comparer implementations that throw when passed null
            return (key == null) ? 0 : _comparer.GetHashCode(key) & 0x7FFFFFFF;
        }

        /// <summary>
        /// Projects every grouping, in insertion order, through the awaited <paramref name="resultSelector"/>
        /// into an array. Selectors are awaited one at a time.
        /// </summary>
        internal async Task<TResult[]> ToArray<TResult>(Func<TKey, IAsyncEnumerable<TElement>, ValueTask<TResult>> resultSelector)
        {
            var array = new TResult[Count];
            var index = 0;
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    g.Trim();
                    array[index] = await resultSelector(g._key, g._elements.ToAsyncEnumerable()).ConfigureAwait(false);
                    ++index;
                } while (g != _lastGrouping);
            }

            return array;
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Same as the non-cancellable overload, but passes <paramref name="cancellationToken"/> to the selector
        /// and throws up front if it is already cancelled.
        /// </summary>
        internal async Task<TResult[]> ToArray<TResult>(Func<TKey, IAsyncEnumerable<TElement>, CancellationToken, ValueTask<TResult>> resultSelector, CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var array = new TResult[Count];
            var index = 0;
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    g.Trim();
                    array[index] = await resultSelector(g._key, g._elements.ToAsyncEnumerable(), cancellationToken).ConfigureAwait(false);
                    ++index;
                } while (g != _lastGrouping);
            }

            return array;
        }
#endif

        /// <summary>
        /// Projects every grouping, in insertion order, through the awaited <paramref name="resultSelector"/>
        /// into a list. Selectors are awaited one at a time.
        /// </summary>
        internal async Task<List<TResult>> ToList<TResult>(Func<TKey, IAsyncEnumerable<TElement>, ValueTask<TResult>> resultSelector)
        {
            var list = new List<TResult>(Count);
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    g.Trim();
                    var result = await resultSelector(g._key, g._elements.ToAsyncEnumerable()).ConfigureAwait(false);
                    list.Add(result);
                } while (g != _lastGrouping);
            }

            return list;
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Same as the non-cancellable overload, but passes <paramref name="cancellationToken"/> to the selector
        /// and throws up front if it is already cancelled.
        /// </summary>
        internal async Task<List<TResult>> ToList<TResult>(Func<TKey, IAsyncEnumerable<TElement>, CancellationToken, ValueTask<TResult>> resultSelector, CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var list = new List<TResult>(Count);
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    g.Trim();
                    var result = await resultSelector(g._key, g._elements.ToAsyncEnumerable(), cancellationToken).ConfigureAwait(false);
                    list.Add(result);
                } while (g != _lastGrouping);
            }

            return list;
        }
#endif

        /// <summary>
        /// Rehashes every grouping into a bucket array of size <c>2 * Count + 1</c>.
        /// Only called from <see cref="GetGrouping"/> when <c>Count == _groupings.Length</c> (so the lookup is non-empty).
        /// </summary>
        private void Resize()
        {
            var newSize = checked((Count * 2) + 1);
            var newGroupings = new Grouping<TKey, TElement>[newSize];
            var g = _lastGrouping;
            do
            {
                g = g._next;
                var index = g._hashCode % newSize;
                g._hashNext = newGroupings[index];
                newGroupings[index] = g;
            } while (g != _lastGrouping);

            _groupings = newGroupings;
        }

        /// <summary>Returns <see cref="Count"/> synchronously; the count is always cheap to obtain.</summary>
        public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
        {
            return new ValueTask<int>(Count);
        }

        IAsyncEnumerator<IAsyncGrouping<TKey, TElement>> IAsyncEnumerable<IAsyncGrouping<TKey, TElement>>.GetAsyncEnumerator(CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested(); // NB: [LDM-2018-11-28] Equivalent to async iterator behavior.

            // Enumerable.Cast is spelled out explicitly because this type is both IEnumerable and IAsyncEnumerable.
            return Enumerable.Cast<IAsyncGrouping<TKey, TElement>>(this).ToAsyncEnumerable().GetAsyncEnumerator(cancellationToken);
        }

        ValueTask<List<IAsyncGrouping<TKey, TElement>>> IAsyncIListProvider<IAsyncGrouping<TKey, TElement>>.ToListAsync(CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var list = new List<IAsyncGrouping<TKey, TElement>>(Count);
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    list.Add(g);
                } while (g != _lastGrouping);
            }

            return new ValueTask<List<IAsyncGrouping<TKey, TElement>>>(list);
        }

        ValueTask<IAsyncGrouping<TKey, TElement>[]> IAsyncIListProvider<IAsyncGrouping<TKey, TElement>>.ToArrayAsync(CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var array = new IAsyncGrouping<TKey, TElement>[Count];
            var index = 0;
            var g = _lastGrouping;
            if (g != null)
            {
                do
                {
                    g = g._next;
                    array[index] = g;
                    ++index;
                } while (g != _lastGrouping);
            }

            return new ValueTask<IAsyncGrouping<TKey, TElement>[]>(array);
        }
    }
}