// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
public static partial class AsyncEnumerable
{
#if INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
// https://learn.microsoft.com/en-us/dotnet/api/system.linq.asyncenumerable.distinct?view=net-9.0-pp
// System.Linq.AsyncEnumerable covers these two overloads with a single method that takes a comparer
// parameter whose default value is null.
/// <summary>
/// Returns an async-enumerable sequence that contains only distinct elements.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <param name="source">An async-enumerable sequence to retain distinct elements for.</param>
/// <returns>An async-enumerable sequence only containing the distinct elements from the source sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
/// <remarks>
/// This overload forwards a <see langword="null"/> comparer to the comparer-taking overload.
/// Usage of this operator should be considered carefully due to the maintenance of an internal lookup structure which can grow large.
/// </remarks>
public static IAsyncEnumerable<TSource> Distinct<TSource>(this IAsyncEnumerable<TSource> source) => Distinct(source, comparer: null);
/// <summary>
/// Returns an async-enumerable sequence that contains only distinct elements according to the comparer.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <param name="source">An async-enumerable sequence to retain distinct elements for.</param>
/// <param name="comparer">
/// Equality comparer for source elements, or <see langword="null"/>. A null comparer is not rejected;
/// it is passed through unchanged (the overload without a comparer relies on this).
/// </param>
/// <returns>An async-enumerable sequence only containing the distinct elements from the source sequence.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
/// <remarks>
/// Execution is deferred: this method only validates <paramref name="source"/> and wraps it; the source is
/// enumerated when the returned sequence is.
/// Usage of this operator should be considered carefully due to the maintenance of an internal lookup structure which can grow large.
/// </remarks>
public static IAsyncEnumerable<TSource> Distinct<TSource>(this IAsyncEnumerable<TSource> source, IEqualityComparer<TSource>? comparer)
{
    // Validate eagerly so the caller sees the exception at the call site, not on the first MoveNextAsync.
    if (source is null)
    {
        throw Error.ArgumentNull(nameof(source));
    }

    return new DistinctAsyncIterator<TSource>(source, comparer);
}
/// <summary>
/// Deferred iterator behind <c>Distinct</c>: yields each source element only when adding it to an internal
/// <c>Set</c> (built with the supplied comparer) succeeds.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
private sealed class DistinctAsyncIterator<TSource> : AsyncIterator<TSource>, IAsyncIListProvider<TSource>
{
    private readonly IEqualityComparer<TSource>? _comparer;
    private readonly IAsyncEnumerable<TSource> _source;

    // Per-enumeration state: created on the first MoveNextCore call and cleared again by DisposeAsync.
    private IAsyncEnumerator<TSource>? _enumerator;
    private Set<TSource>? _set;

    public DistinctAsyncIterator(IAsyncEnumerable<TSource> source, IEqualityComparer<TSource>? comparer)
    {
        _source = source;
        _comparer = comparer;
    }

    /// <summary>Enumerates the whole source and returns the distinct elements as an array.</summary>
    public async ValueTask<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
    {
        var s = await FillSetAsync(cancellationToken).ConfigureAwait(false);
        return s.ToArray();
    }

    /// <summary>Enumerates the whole source and returns the distinct elements as a list.</summary>
    public async ValueTask<List<TSource>> ToListAsync(CancellationToken cancellationToken)
    {
        var s = await FillSetAsync(cancellationToken).ConfigureAwait(false);
        return s.ToList();
    }

    /// <summary>
    /// Returns the number of distinct elements, or -1 when <paramref name="onlyIfCheap"/> is true,
    /// because the count can only be obtained by enumerating the entire source.
    /// </summary>
    public ValueTask<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
    {
        if (onlyIfCheap)
        {
            return new ValueTask<int>(-1);
        }

        return Core();

        async ValueTask<int> Core()
        {
            var s = await FillSetAsync(cancellationToken).ConfigureAwait(false);
            return s.Count;
        }
    }

    // A clone shares the source and comparer but starts with fresh enumeration state.
    public override AsyncIteratorBase<TSource> Clone()
    {
        return new DistinctAsyncIterator<TSource>(_source, _comparer);
    }

    /// <summary>Disposes the underlying enumerator (if one was started), drops the set, then disposes the base iterator.</summary>
    public override async ValueTask DisposeAsync()
    {
        if (_enumerator != null)
        {
            await _enumerator.DisposeAsync().ConfigureAwait(false);
            _enumerator = null;
            _set = null;
        }

        await base.DisposeAsync().ConfigureAwait(false);
    }

    protected override async ValueTask<bool> MoveNextCore()
    {
        switch (_state)
        {
            case AsyncIteratorState.Allocated:
            {
                _enumerator = _source.GetAsyncEnumerator(_cancellationToken);
                if (!await _enumerator.MoveNextAsync().ConfigureAwait(false))
                {
                    await DisposeAsync().ConfigureAwait(false);
                    return false;
                }

                // The first element cannot be a duplicate: seed the set with it and yield it directly.
                var first = _enumerator.Current;
                _set = new Set<TSource>(_comparer);
                _set.Add(first);
                _current = first;
                _state = AsyncIteratorState.Iterating;
                return true;
            }

            case AsyncIteratorState.Iterating:
            {
                while (await _enumerator!.MoveNextAsync().ConfigureAwait(false))
                {
                    var element = _enumerator.Current;

                    // Only yield elements that Add reports as newly added; others were already seen.
                    if (_set!.Add(element))
                    {
                        _current = element;
                        return true;
                    }
                }

                break;
            }
        }

        // Source exhausted (or an unexpected state): release resources and signal the end of the sequence.
        await DisposeAsync().ConfigureAwait(false);
        return false;
    }

    // Eagerly enumerates the entire source into a set; shared by ToArrayAsync, ToListAsync and GetCountAsync.
    private Task<Set<TSource>> FillSetAsync(CancellationToken cancellationToken)
    {
        return AsyncEnumerableHelpers.ToSet(_source, _comparer, cancellationToken);
    }
}
#endif // INCLUDE_SYSTEM_LINQ_ASYNCENUMERABLE_DUPLICATES
}
}