// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Threading;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Concatenates two async sequences: yields every element of <paramref name="first"/>,
        /// then every element of <paramref name="second"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Concat<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));

            return Create(() =>
            {
                var switched = false;
                var e = first.GetEnumerator();

                var cts = new CancellationTokenDisposable();
                // 'a' always holds the enumerator currently being drained, so disposing 'd'
                // releases whichever of the two enumerators is active.
                var a = new AssignableDisposable { Disposable = e };
                var d = Disposable.Create(cts, a);

                Func<CancellationToken, Task<bool>> f = async ct =>
                {
                    while (true)
                    {
                        if (await e.MoveNext(ct).ConfigureAwait(false))
                            return true;

                        if (switched)
                            return false;

                        // 'first' is exhausted: move on to 'second' exactly once.
                        switched = true;
                        e = second.GetEnumerator();
                        a.Disposable = e;
                    }
                };

                return Create(
                    f,
                    () => e.Current,
                    d.Dispose,
                    e
                );
            });
        }

        /// <summary>
        /// Merges two async sequences pairwise with <paramref name="selector"/>, stopping as soon as
        /// either sequence is exhausted. Both sources are advanced concurrently on each step.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TResult> Zip<TFirst, TSecond, TResult>(this IAsyncEnumerable<TFirst> first, IAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> selector)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));
            if (selector == null)
                throw new ArgumentNullException(nameof(selector));

            return Create(() =>
            {
                var e1 = first.GetEnumerator();
                var e2 = second.GetEnumerator();
                var current = default(TResult);

                var cts = new CancellationTokenDisposable();
                var d = Disposable.Create(cts, e1, e2);

                return Create(
                    async ct =>
                    {
                        // Honor both the caller's token and the token that is cancelled on Dispose;
                        // previously the caller's 'ct' was ignored entirely.
                        using (var linked = CancellationTokenSource.CreateLinkedTokenSource(ct, cts.Token))
                        {
                            var token = linked.Token;

                            return await e1.MoveNext(token).Zip(e2.MoveNext(token), (f, s) =>
                            {
                                var result = f && s;
                                if (result)
                                    current = selector(e1.Current, e2.Current);
                                return result;
                            }).ConfigureAwait(false);
                        }
                    },
                    () => current,
                    d.Dispose
                );
            });
        }

        /// <summary>
        /// Produces the set difference: the distinct elements of <paramref name="first"/> (in source
        /// order) that do not occur in <paramref name="second"/>, compared with <paramref name="comparer"/>.
        /// <paramref name="second"/> is fully buffered on the first MoveNext, concurrently with the first
        /// pull from <paramref name="first"/>. Duplicate and null elements in either sequence are allowed.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Except<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return Create(() =>
            {
                var e = first.GetEnumerator();

                var cts = new CancellationTokenDisposable();
                var d = Disposable.Create(cts, e);

                // Created lazily; holds the elements of 'second' plus every element already yielded.
                var setTask = default(Task<HashSet<TSource>>);

                Func<CancellationToken, Task<bool>> f = async ct =>
                {
                    setTask = setTask ?? ToHashSet_(second, comparer, ct);

                    while (await e.MoveNext(ct).Zip(setTask, (hasNext, _) => hasNext).ConfigureAwait(false))
                    {
                        // Add returns false when the element is in 'second' or was already yielded.
                        if (setTask.Result.Add(e.Current))
                            return true;
                    }

                    return false;
                };

                return Create(
                    f,
                    () => e.Current,
                    d.Dispose,
                    e
                );
            });
        }

        /// <summary>
        /// Produces the set difference of two sequences using <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Except<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));

            return first.Except(second, EqualityComparer<TSource>.Default);
        }

        /// <summary>
        /// Produces the set intersection: the distinct elements of <paramref name="first"/> (in source
        /// order) that also occur in <paramref name="second"/>, compared with <paramref name="comparer"/>.
        /// <paramref name="second"/> is fully buffered on the first MoveNext, concurrently with the first
        /// pull from <paramref name="first"/>. Duplicate and null elements in either sequence are allowed.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Intersect<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return Create(() =>
            {
                var e = first.GetEnumerator();

                var cts = new CancellationTokenDisposable();
                var d = Disposable.Create(cts, e);

                // Created lazily; holds the elements of 'second' that have not been yielded yet.
                var setTask = default(Task<HashSet<TSource>>);

                Func<CancellationToken, Task<bool>> f = async ct =>
                {
                    setTask = setTask ?? ToHashSet_(second, comparer, ct);

                    while (await e.MoveNext(ct).Zip(setTask, (hasNext, _) => hasNext).ConfigureAwait(false))
                    {
                        // Remove succeeds only the first time a matching element is seen.
                        if (setTask.Result.Remove(e.Current))
                            return true;
                    }

                    return false;
                };

                return Create(
                    f,
                    () => e.Current,
                    d.Dispose,
                    e
                );
            });
        }

        /// <summary>
        /// Produces the set intersection of two sequences using <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Intersect<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));

            return first.Intersect(second, EqualityComparer<TSource>.Default);
        }

        /// <summary>
        /// Drains <paramref name="source"/> into a <see cref="HashSet{T}"/> that uses <paramref name="comparer"/>.
        /// Unlike a dictionary keyed by element, the set tolerates duplicate and null elements.
        /// </summary>
        private static async Task<HashSet<TSource>> ToHashSet_<TSource>(IAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
        {
            var set = new HashSet<TSource>(comparer);

            using (var e = source.GetEnumerator())
            {
                while (await e.MoveNext(cancellationToken).ConfigureAwait(false))
                {
                    set.Add(e.Current);
                }
            }

            return set;
        }

        /// <summary>
        /// Produces the set union: the distinct elements of <paramref name="first"/> followed by those of
        /// <paramref name="second"/>, compared with <paramref name="comparer"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return first.Concat(second).Distinct(comparer);
        }

        /// <summary>
        /// Produces the set union of two sequences using <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Union<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));

            return first.Union(second, EqualityComparer<TSource>.Default);
        }

        /// <summary>
        /// Determines whether two sequences have the same length and pairwise-equal elements
        /// according to <paramref name="comparer"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="first"/>, <paramref name="second"/> or <paramref name="comparer"/> is null.</exception>
        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            // Arguments are validated eagerly; the comparison itself runs in the async helper.
            return SequenceEqual_(first, second, comparer, cancellationToken);
        }

        private static async Task<bool> SequenceEqual_<TSource>(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
        {
            using (var e1 = first.GetEnumerator())
            using (var e2 = second.GetEnumerator())
            {
                while (await e1.MoveNext(cancellationToken).ConfigureAwait(false))
                {
                    if (!(await e2.MoveNext(cancellationToken).ConfigureAwait(false) && comparer.Equals(e1.Current, e2.Current)))
                    {
                        return false;
                    }
                }

                // 'first' is exhausted; the sequences are equal only if 'second' is too.
                return !await e2.MoveNext(cancellationToken).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Determines whether two sequences are equal using <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, CancellationToken cancellationToken)
        {
            if (first == null)
                throw new ArgumentNullException(nameof(first));
            if (second == null)
                throw new ArgumentNullException(nameof(second));

            return first.SequenceEqual(second, EqualityComparer<TSource>.Default, cancellationToken);
        }

        /// <summary>
        /// Correlates each element of <paramref name="outer"/> with the group of matching
        /// <paramref name="inner"/> elements (by key, using <paramref name="comparer"/>) and projects
        /// each pair with <paramref name="resultSelector"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TResult> GroupJoin<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector, IEqualityComparer<TKey> comparer)
        {
            if (outer == null)
                throw new ArgumentNullException(nameof(outer));
            if (inner == null)
                throw new ArgumentNullException(nameof(inner));
            if (outerKeySelector == null)
                throw new ArgumentNullException(nameof(outerKeySelector));
            if (innerKeySelector == null)
                throw new ArgumentNullException(nameof(innerKeySelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return new GroupJoinAsyncEnumerable<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
        }

        /// <summary>
        /// Group-joins two sequences using <see cref="EqualityComparer{T}.Default"/> for keys.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TResult> GroupJoin<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector)
        {
            if (outer == null)
                throw new ArgumentNullException(nameof(outer));
            if (inner == null)
                throw new ArgumentNullException(nameof(inner));
            if (outerKeySelector == null)
                throw new ArgumentNullException(nameof(outerKeySelector));
            if (innerKeySelector == null)
                throw new ArgumentNullException(nameof(innerKeySelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));

            return outer.GroupJoin(inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
        }

        private sealed class GroupJoinAsyncEnumerable<TOuter, TInner, TKey, TResult> : IAsyncEnumerable<TResult>
        {
            private readonly IAsyncEnumerable<TOuter> _outer;
            private readonly IAsyncEnumerable<TInner> _inner;
            private readonly Func<TOuter, TKey> _outerKeySelector;
            private readonly Func<TInner, TKey> _innerKeySelector;
            private readonly Func<TOuter, IAsyncEnumerable<TInner>, TResult> _resultSelector;
            private readonly IEqualityComparer<TKey> _comparer;

            public GroupJoinAsyncEnumerable(
                IAsyncEnumerable<TOuter> outer,
                IAsyncEnumerable<TInner> inner,
                Func<TOuter, TKey> outerKeySelector,
                Func<TInner, TKey> innerKeySelector,
                Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector,
                IEqualityComparer<TKey> comparer)
            {
                _outer = outer;
                _inner = inner;
                _outerKeySelector = outerKeySelector;
                _innerKeySelector = innerKeySelector;
                _resultSelector = resultSelector;
                _comparer = comparer;
            }

            public IAsyncEnumerator<TResult> GetEnumerator()
                => new GroupJoinAsyncEnumerator(
                    _outer.GetEnumerator(),
                    _inner,
                    _outerKeySelector,
                    _innerKeySelector,
                    _resultSelector,
                    _comparer);

            private sealed class GroupJoinAsyncEnumerator : IAsyncEnumerator<TResult>
            {
                private readonly IAsyncEnumerator<TOuter> _outer;
                private readonly IAsyncEnumerable<TInner> _inner;
                private readonly Func<TOuter, TKey> _outerKeySelector;
                private readonly Func<TInner, TKey> _innerKeySelector;
                private readonly Func<TOuter, IAsyncEnumerable<TInner>, TResult> _resultSelector;
                private readonly IEqualityComparer<TKey> _comparer;

                // Inner sequence grouped by key; built on the first MoveNext that finds an outer element.
                private Internal.Lookup<TKey, TInner> _lookup;

                public GroupJoinAsyncEnumerator(
                    IAsyncEnumerator<TOuter> outer,
                    IAsyncEnumerable<TInner> inner,
                    Func<TOuter, TKey> outerKeySelector,
                    Func<TInner, TKey> innerKeySelector,
                    Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector,
                    IEqualityComparer<TKey> comparer)
                {
                    _outer = outer;
                    _inner = inner;
                    _outerKeySelector = outerKeySelector;
                    _innerKeySelector = innerKeySelector;
                    _resultSelector = resultSelector;
                    _comparer = comparer;
                }

                public async Task<bool> MoveNext(CancellationToken cancellationToken)
                {
                    // nothing to do
                    if (!await _outer.MoveNext(cancellationToken).ConfigureAwait(false))
                    {
                        return false;
                    }

                    if (_lookup == null)
                    {
                        _lookup = await Internal.Lookup<TKey, TInner>.CreateForJoinAsync(_inner, _innerKeySelector, _comparer, cancellationToken).ConfigureAwait(false);
                    }

                    var item = _outer.Current;
                    Current = _resultSelector(item, new AsyncEnumerableAdapter<TInner>(_lookup[_outerKeySelector(item)]));
                    return true;
                }

                public TResult Current { get; private set; }

                public void Dispose()
                {
                    _outer.Dispose();
                }
            }
        }

        /// <summary>
        /// Exposes a synchronous <see cref="IEnumerable{T}"/> as an <see cref="IAsyncEnumerable{T}"/>
        /// whose MoveNext completes synchronously.
        /// </summary>
        private sealed class AsyncEnumerableAdapter<T> : IAsyncEnumerable<T>
        {
            private readonly IEnumerable<T> _source;

            public AsyncEnumerableAdapter(IEnumerable<T> source)
            {
                _source = source;
            }

            public IAsyncEnumerator<T> GetEnumerator()
                => new AsyncEnumeratorAdapter(_source.GetEnumerator());

            private sealed class AsyncEnumeratorAdapter : IAsyncEnumerator<T>
            {
                private readonly IEnumerator<T> _enumerator;

                public AsyncEnumeratorAdapter(IEnumerator<T> enumerator)
                {
                    _enumerator = enumerator;
                }

                public Task<bool> MoveNext(CancellationToken cancellationToken)
                {
                    cancellationToken.ThrowIfCancellationRequested();

#if HAS_AWAIT
                    return Task.FromResult(_enumerator.MoveNext());
#else
                    return TaskEx.FromResult(_enumerator.MoveNext());
#endif
                }

                public T Current => _enumerator.Current;

                public void Dispose() => _enumerator.Dispose();
            }
        }

        /// <summary>
        /// Inner-joins two sequences on matching keys (compared with <paramref name="comparer"/>) and
        /// projects each matching pair with <paramref name="resultSelector"/>.
        /// Both sources are pulled alternately and buffered by key, so a result is produced as soon as
        /// either side's element meets a previously seen match. Results are therefore not necessarily
        /// grouped by outer element. Elements whose key is null never match, as in LINQ to Objects.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector, IEqualityComparer<TKey> comparer)
        {
            if (outer == null)
                throw new ArgumentNullException(nameof(outer));
            if (inner == null)
                throw new ArgumentNullException(nameof(inner));
            if (outerKeySelector == null)
                throw new ArgumentNullException(nameof(outerKeySelector));
            if (innerKeySelector == null)
                throw new ArgumentNullException(nameof(innerKeySelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));
            if (comparer == null)
                throw new ArgumentNullException(nameof(comparer));

            return Create(() =>
            {
                var oe = outer.GetEnumerator();
                var ie = inner.GetEnumerator();

                var cts = new CancellationTokenDisposable();
                var d = Disposable.Create(cts, oe, ie);

                var current = default(TResult);
                var useOuter = true;
                var outerMap = new Dictionary<TKey, List<TOuter>>(comparer);
                var innerMap = new Dictionary<TKey, List<TInner>>(comparer);
                // Results discovered but not yet handed out.
                var q = new Queue<TResult>();

                Func<CancellationToken, Task<bool>> f = async ct =>
                {
                    while (true)
                    {
                        if (q.Count > 0)
                        {
                            current = q.Dequeue();
                            return true;
                        }

                        // Each enumerator is disposed and set to null once it is exhausted.
                        if (ie == null && oe == null)
                        {
                            return false;
                        }

                        // Alternate between the sides; once one side is exhausted, keep pulling the other.
                        var pullOuter = useOuter;
                        if (ie == null)
                            pullOuter = true;
                        else if (oe == null)
                            pullOuter = false;
                        useOuter = !useOuter;

                        if (pullOuter)
                        {
                            if (await oe.MoveNext(ct).ConfigureAwait(false))
                            {
                                var element = oe.Current;
                                var key = outerKeySelector(element);

                                // Dictionary rejects null keys; null keys never match anyway.
                                if (key != null)
                                {
                                    List<TOuter> outerList;
                                    if (!outerMap.TryGetValue(key, out outerList))
                                    {
                                        outerList = new List<TOuter>();
                                        outerMap.Add(key, outerList);
                                    }
                                    outerList.Add(element);

                                    List<TInner> innerList;
                                    if (innerMap.TryGetValue(key, out innerList))
                                    {
                                        foreach (var v in innerList)
                                        {
                                            q.Enqueue(resultSelector(element, v));
                                        }
                                    }
                                }
                            }
                            else
                            {
                                oe.Dispose();
                                oe = null;
                            }
                        }
                        else
                        {
                            if (await ie.MoveNext(ct).ConfigureAwait(false))
                            {
                                var element = ie.Current;
                                var key = innerKeySelector(element);

                                // Dictionary rejects null keys; null keys never match anyway.
                                if (key != null)
                                {
                                    List<TInner> innerList;
                                    if (!innerMap.TryGetValue(key, out innerList))
                                    {
                                        innerList = new List<TInner>();
                                        innerMap.Add(key, innerList);
                                    }
                                    innerList.Add(element);

                                    List<TOuter> outerList;
                                    if (outerMap.TryGetValue(key, out outerList))
                                    {
                                        foreach (var v in outerList)
                                        {
                                            q.Enqueue(resultSelector(v, element));
                                        }
                                    }
                                }
                            }
                            else
                            {
                                ie.Dispose();
                                ie = null;
                            }
                        }
                    }
                };

                return Create(
                    f,
                    () => current,
                    d.Dispose,
                    ie
                );
            });
        }

        /// <summary>
        /// Inner-joins two sequences using <see cref="EqualityComparer{T}.Default"/> for keys.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector)
        {
            if (outer == null)
                throw new ArgumentNullException(nameof(outer));
            if (inner == null)
                throw new ArgumentNullException(nameof(inner));
            if (outerKeySelector == null)
                throw new ArgumentNullException(nameof(outerKeySelector));
            if (innerKeySelector == null)
                throw new ArgumentNullException(nameof(innerKeySelector));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));

            return outer.Join(inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
        }

        /// <summary>
        /// Concatenates a (synchronous) sequence of async sequences, draining each in turn.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Concat<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
        {
            if (sources == null)
                throw new ArgumentNullException(nameof(sources));

            return sources.Concat_();
        }

        /// <summary>
        /// Concatenates the given async sequences, draining each in turn.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Concat<TSource>(params IAsyncEnumerable<TSource>[] sources)
        {
            if (sources == null)
                throw new ArgumentNullException(nameof(sources));

            return sources.Concat_();
        }

        private static IAsyncEnumerable<TSource> Concat_<TSource>(this IEnumerable<IAsyncEnumerable<TSource>> sources)
        {
            return Create(() =>
            {
                var se = sources.GetEnumerator();
                // Enumerator of the inner sequence currently being drained; null between sources.
                var e = default(IAsyncEnumerator<TSource>);

                var cts = new CancellationTokenDisposable();
                var a = new AssignableDisposable();
                var d = Disposable.Create(cts, se, a);

                Func<CancellationToken, Task<bool>> f = async ct =>
                {
                    // Loop (rather than recurse) so long runs of empty sources cannot grow the stack.
                    while (true)
                    {
                        if (e == null)
                        {
                            if (!se.MoveNext())
                            {
                                return false;
                            }

                            e = se.Current.GetEnumerator();
                            a.Disposable = e;
                        }

                        if (await e.MoveNext(ct).ConfigureAwait(false))
                        {
                            return true;
                        }

                        e.Dispose();
                        e = null;
                    }
                };

                return Create(
                    f,
                    () => e.Current,
                    d.Dispose,
                    a
                );
            });
        }

        /// <summary>
        /// Projects every element of <paramref name="source"/> to the same sequence
        /// <paramref name="other"/> and flattens the results.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="other"/> is null.</exception>
        public static IAsyncEnumerable<TOther> SelectMany<TSource, TOther>(this IAsyncEnumerable<TSource> source, IAsyncEnumerable<TOther> other)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (other == null)
                throw new ArgumentNullException(nameof(other));

            return source.SelectMany(_ => other);
        }
    }
}