// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information. 

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Creates an async sequence whose <c>GetEnumerator</c> calls <paramref name="getEnumerator"/>
        /// to produce a fresh enumerator for every enumeration.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="getEnumerator">Factory invoked on each <c>GetEnumerator</c> call.</param>
        /// <returns>An async sequence backed by <paramref name="getEnumerator"/>.</returns>
        public static IAsyncEnumerable<T> Create<T>(Func<IAsyncEnumerator<T>> getEnumerator)
        {
            return new AnonymousAsyncEnumerable<T>(getEnumerator);
        }

        // Delegate-backed IAsyncEnumerable<T>: GetEnumerator simply forwards to the stored factory.
        private sealed class AnonymousAsyncEnumerable<T> : IAsyncEnumerable<T>
        {
            private readonly Func<IAsyncEnumerator<T>> _getEnumerator;

            public AnonymousAsyncEnumerable(Func<IAsyncEnumerator<T>> getEnumerator)
            {
                _getEnumerator = getEnumerator;
            }

            public IAsyncEnumerator<T> GetEnumerator()
            {
                return _getEnumerator();
            }
        }

        // Builds an enumerator whose MoveNext:
        //  - runs `dispose` if `ct` is cancelled while the move is in flight, and
        //  - disposes `enumerator` (when non-null) once `moveNext` reports the end of the
        //    sequence or throws, so the underlying source is released eagerly.
        private static IAsyncEnumerator<T> Create<T>(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose, IDisposable enumerator)
        {
            return Create(
                async ct =>
                {
                    using (ct.Register(dispose))
                    {
                        try
                        {
                            var result = await moveNext(ct).ConfigureAwait(false);
                            if (!result)
                            {
                                enumerator?.Dispose();
                            }
                            return result;
                        }
                        catch
                        {
                            enumerator?.Dispose();
                            throw;
                        }
                    }
                },
                current,
                dispose
            );
        }

        /// <summary>
        /// Creates an enumerator from delegates for <c>MoveNext</c>, <c>Current</c> and <c>Dispose</c>.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="moveNext">Invoked by <c>MoveNext</c> until the enumerator is disposed.</param>
        /// <param name="current">Invoked by the <c>Current</c> getter.</param>
        /// <param name="dispose">Invoked by the first <c>Dispose</c> call; later calls are no-ops.</param>
        /// <returns>The delegate-backed enumerator.</returns>
        public static IAsyncEnumerator<T> Create<T>(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose)
        {
            return new AnonymousAsyncEnumerator<T>(moveNext, current, dispose);
        }

        // Builds an enumerator whose MoveNext hands `moveNext` a fresh TaskCompletionSource<bool>
        // per call. If the caller's token is cancelled during the move, the enumerator is disposed
        // and that TaskCompletionSource is transitioned to Canceled.
        private static IAsyncEnumerator<T> Create<T>(Func<CancellationToken, TaskCompletionSource<bool>, Task<bool>> moveNext, Func<T> current, Action dispose)
        {
            var self = default(IAsyncEnumerator<T>);
            self = new AnonymousAsyncEnumerator<T>(
                async ct =>
                {
                    var tcs = new TaskCompletionSource<bool>();

                    var stop = new Action(() =>
                    {
                        self.Dispose();
                        tcs.TrySetCanceled();
                    });

                    using (ct.Register(stop))
                    {
                        return await moveNext(ct, tcs).ConfigureAwait(false);
                    }
                },
                current,
                dispose
            );
            return self;
        }

        // Delegate-backed IAsyncEnumerator<T>. After Dispose, MoveNext returns TaskExt.False without
        // invoking the moveNext delegate. The _disposed flag is not synchronized.
        private sealed class AnonymousAsyncEnumerator<T> : IAsyncEnumerator<T>
        {
            private readonly Func<CancellationToken, Task<bool>> _moveNext;
            private readonly Func<T> _current;
            private readonly Action _dispose;
            private bool _disposed;

            public AnonymousAsyncEnumerator(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose)
            {
                _moveNext = moveNext;
                _current = current;
                _dispose = dispose;
            }

            public Task<bool> MoveNext(CancellationToken cancellationToken)
            {
                if (_disposed)
                    return TaskExt.False;

                return _moveNext(cancellationToken);
            }

            public T Current
            {
                get
                {
                    return _current();
                }
            }

            public void Dispose()
            {
                if (!_disposed)
                {
                    _disposed = true;
                    _dispose();
                }
            }
        }

        /// <summary>Returns an async sequence containing exactly one element.</summary>
        /// <typeparam name="TValue">Element type.</typeparam>
        /// <param name="value">The single element.</param>
        public static IAsyncEnumerable<TValue> Return<TValue>(TValue value)
        {
            return new[] { value }.ToAsyncEnumerable();
        }

        /// <summary>
        /// Returns an async sequence whose enumerator's <c>MoveNext</c> always returns
        /// <c>TaskExt.Throw&lt;bool&gt;(exception)</c> (presumably a task faulted with
        /// <paramref name="exception"/>). Reading <c>Current</c> throws <see cref="InvalidOperationException"/>.
        /// </summary>
        /// <typeparam name="TValue">Element type.</typeparam>
        /// <param name="exception">The error to surface.</param>
        /// <exception cref="ArgumentNullException"><paramref name="exception"/> is null.</exception>
        public static IAsyncEnumerable<TValue> Throw<TValue>(Exception exception)
        {
            if (exception == null)
                throw new ArgumentNullException(nameof(exception));

            return Create(() => Create<TValue>(
                ct => TaskExt.Throw<bool>(exception),
                () => { throw new InvalidOperationException(); },
                () => { })
            );
        }

        /// <summary>
        /// Returns an async sequence whose <c>MoveNext</c> never produces a value. The returned task
        /// completes (as canceled) only when the token passed to <c>MoveNext</c> is cancelled; the
        /// enumerator is disposed at that point as well.
        /// </summary>
        /// <typeparam name="TValue">Element type.</typeparam>
        public static IAsyncEnumerable<TValue> Never<TValue>()
        {
            return Create(() => Create<TValue>(
                (ct, tcs) => tcs.Task,
                () => { throw new InvalidOperationException(); },
                () => { })
            );
        }

        /// <summary>
        /// Returns an async sequence with no elements: <c>MoveNext</c> always returns false and
        /// reading <c>Current</c> throws <see cref="InvalidOperationException"/>.
        /// </summary>
        /// <typeparam name="TValue">Element type.</typeparam>
        public static IAsyncEnumerable<TValue> Empty<TValue>()
        {
            return Create(() => Create<TValue>(
                ct => TaskExt.False,
                () => { throw new InvalidOperationException(); },
                () => { })
            );
        }

        /// <summary>
        /// Returns <paramref name="count"/> consecutive integers starting at <paramref name="start"/>.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        /// <remarks>Any further argument validation is delegated to <see cref="Enumerable.Range"/>.</remarks>
        public static IAsyncEnumerable<int> Range(int start, int count)
        {
            if (count < 0)
                throw new ArgumentOutOfRangeException(nameof(count));

            return Enumerable.Range(start, count).ToAsyncEnumerable();
        }

        /// <summary>Returns <paramref name="element"/> repeated <paramref name="count"/> times.</summary>
        /// <typeparam name="TResult">Element type.</typeparam>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        public static IAsyncEnumerable<TResult> Repeat<TResult>(TResult element, int count)
        {
            if (count < 0)
                throw new ArgumentOutOfRangeException(nameof(count));

            return Enumerable.Repeat(element, count).ToAsyncEnumerable();
        }

        /// <summary>
        /// Returns an infinite async sequence that yields <paramref name="element"/> on every
        /// <c>MoveNext</c>.
        /// </summary>
        /// <typeparam name="TResult">Element type.</typeparam>
        public static IAsyncEnumerable<TResult> Repeat<TResult>(TResult element)
        {
            return Create(() =>
            {
                return Create(
                    ct => TaskExt.True,
                    () => element,
                    () => { }
                );
            });
        }

        /// <summary>
        /// Returns an async sequence that invokes <paramref name="factory"/> each time it is
        /// enumerated and forwards to the sequence the factory returns.
        /// </summary>
        /// <typeparam name="TSource">Element type.</typeparam>
        /// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
        public static IAsyncEnumerable<TSource> Defer<TSource>(Func<IAsyncEnumerable<TSource>> factory)
        {
            if (factory == null)
                throw new ArgumentNullException(nameof(factory));

            return Create(() => factory().GetEnumerator());
        }

        /// <summary>
        /// Generates an async sequence by iterating a state: starting from
        /// <paramref name="initialState"/>, yields <paramref name="resultSelector"/>(state) while
        /// <paramref name="condition"/>(state) holds, advancing with <paramref name="iterate"/>.
        /// </summary>
        /// <typeparam name="TState">State type.</typeparam>
        /// <typeparam name="TResult">Element type.</typeparam>
        /// <exception cref="ArgumentNullException">Any delegate argument is null.</exception>
        /// <remarks>
        /// Exceptions thrown by the delegates are surfaced through the task returned by
        /// <c>MoveNext</c>. Once the condition fails or a delegate throws, the enumerator stays
        /// completed: later <c>MoveNext</c> calls return false without invoking any delegate.
        /// </remarks>
        public static IAsyncEnumerable<TResult> Generate<TState, TResult>(TState initialState, Func<TState, bool> condition, Func<TState, TState> iterate, Func<TState, TResult> resultSelector)
        {
            if (condition == null)
                throw new ArgumentNullException(nameof(condition));
            if (iterate == null)
                throw new ArgumentNullException(nameof(iterate));
            if (resultSelector == null)
                throw new ArgumentNullException(nameof(resultSelector));

            return Create(() =>
            {
                var i = initialState;
                var started = false;
                var finished = false;
                var current = default(TResult);

                return Create(
                    ct =>
                    {
                        // Without this latch, a MoveNext after the end would call iterate() again
                        // and could resume the sequence if the condition became true later.
                        if (finished)
                            return TaskExt.False;

                        var b = false;
                        try
                        {
                            if (started)
                                i = iterate(i);

                            b = condition(i);

                            if (b)
                                current = resultSelector(i);
                        }
                        catch (Exception ex)
                        {
                            finished = true;
                            return TaskExt.Throw<bool>(ex);
                        }

                        if (!b)
                        {
                            finished = true;
                            return TaskExt.False;
                        }

                        started = true;
                        return TaskExt.True;
                    },
                    () => current,
                    () => { }
                );
            });
        }

        /// <summary>
        /// Returns an async sequence that, per enumeration, creates a resource with
        /// <paramref name="resourceFactory"/>, enumerates the sequence produced by
        /// <paramref name="enumerableFactory"/>(resource), and disposes the resource when
        /// enumeration ends, faults, or the enumerator is disposed.
        /// </summary>
        /// <typeparam name="TSource">Element type.</typeparam>
        /// <typeparam name="TResource">Resource type.</typeparam>
        /// <exception cref="ArgumentNullException">Either factory is null.</exception>
        public static IAsyncEnumerable<TSource> Using<TSource, TResource>(Func<TResource> resourceFactory, Func<TResource, IAsyncEnumerable<TSource>> enumerableFactory) where TResource : IDisposable
        {
            if (resourceFactory == null)
                throw new ArgumentNullException(nameof(resourceFactory));
            if (enumerableFactory == null)
                throw new ArgumentNullException(nameof(enumerableFactory));

            return Create(() =>
            {
                var resource = resourceFactory();
                var e = default(IAsyncEnumerator<TSource>);

                try
                {
                    e = enumerableFactory(resource).GetEnumerator();
                }
                catch (Exception)
                {
                    // The enumerator never came into existence, so nobody else will release the resource.
                    resource.Dispose();
                    throw;
                }

                var cts = new CancellationTokenDisposable();
                // NOTE(review): if Disposable.Create disposes in argument order, the resource is
                // released before the inner enumerator `e`, which may still touch it — confirm.
                var d = Disposable.Create(cts, resource, e);

                var current = default(TSource);

                return Create(
                    async ct =>
                    {
                        bool res;
                        try
                        {
                            // The inner enumerator observes cts.Token; cancellation of the caller's
                            // token reaches it through the private Create's ct.Register(d.Dispose).
                            res = await e.MoveNext(cts.Token).ConfigureAwait(false);
                        }
                        catch (Exception)
                        {
                            d.Dispose();
                            throw;
                        }

                        if (res)
                        {
                            current = e.Current;
                            return true;
                        }

                        d.Dispose();
                        return false;
                    },
                    () => current,
                    d.Dispose,
                    null
                );
            });
        }
    }
}