// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using FluentAssertions;
using Xunit;

namespace Tests
{
    /// <summary>
    /// Base class for async-enumerable tests: a shared sample sequence plus
    /// assertion helpers for faulted tasks, enumerator stepping and sequence identity.
    /// </summary>
    public class AsyncEnumerableTests
    {
        /// <summary>
        /// An async sequence containing the single element <c>42</c>.
        /// </summary>
        protected static readonly IAsyncEnumerable<int> Return42 = new[] { 42 }.ToAsyncEnumerable();

        /// <summary>
        /// Asserts that awaiting <paramref name="t"/> throws an exception of exactly
        /// type <typeparamref name="TException"/> (xUnit <c>ThrowsAsync</c> semantics:
        /// derived exception types do not match).
        /// </summary>
        /// <typeparam name="TException">The exact exception type expected.</typeparam>
        /// <param name="t">The task expected to fault.</param>
        protected async Task AssertThrowsAsync<TException>(Task t)
            where TException : Exception
        {
            await Assert.ThrowsAsync<TException>(() => t);
        }

        /// <summary>
        /// Asserts that awaiting <paramref name="t"/> throws the very same exception
        /// instance as <paramref name="e"/> (reference equality via <c>Assert.Same</c>).
        /// </summary>
        /// <param name="t">The task expected to fault.</param>
        /// <param name="e">The exact exception instance expected to surface.</param>
        protected async Task AssertThrowsAsync(Task t, Exception e)
        {
            // ThrowsAnyAsync (rather than a bare try/catch) makes the assertion fail when
            // the task completes successfully instead of letting it pass silently.
            var actual = await Assert.ThrowsAnyAsync<Exception>(() => t);

            Assert.Same(e, actual);
        }

        /// <summary>
        /// <see cref="ValueTask{TResult}"/> overload of <see cref="AssertThrowsAsync(Task, Exception)"/>.
        /// The value task is converted with <see cref="ValueTask{TResult}.AsTask"/>, which consumes it,
        /// so callers should not await <paramref name="t"/> again afterwards.
        /// </summary>
        /// <typeparam name="T">Result type of the value task.</typeparam>
        /// <param name="t">The value task expected to fault.</param>
        /// <param name="e">The exact exception instance expected to surface.</param>
        protected Task AssertThrowsAsync<T>(ValueTask<T> t, Exception e)
        {
            return AssertThrowsAsync(t.AsTask(), e);
        }

        /// <summary>
        /// Asserts that advancing <paramref name="e"/> reports no further element.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="e">The enumerator to advance.</param>
        protected async Task NoNextAsync<T>(IAsyncEnumerator<T> e)
        {
            Assert.False(await e.MoveNextAsync());
        }

        /// <summary>
        /// Asserts that advancing <paramref name="e"/> succeeds and that the new
        /// <c>Current</c> equals <paramref name="value"/>.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="e">The enumerator to advance.</param>
        /// <param name="value">The element expected after the move.</param>
        protected async Task HasNextAsync<T>(IAsyncEnumerator<T> e, T value)
        {
            Assert.True(await e.MoveNextAsync());
            Assert.Equal(value, e.Current);
        }

        /// <summary>
        /// Asserts that two enumerators obtained from <paramref name="enumerable"/> have the
        /// same runtime type, and that enumerating the sequence twice via <c>ToList</c>
        /// produces equivalent contents (FluentAssertions <c>ShouldAllBeEquivalentTo</c>).
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="enumerable">The sequence under test.</param>
        protected async Task SequenceIdentity<T>(IAsyncEnumerable<T> enumerable)
        {
            var en1 = enumerable.GetAsyncEnumerator();
            var en2 = enumerable.GetAsyncEnumerator();

            Assert.Equal(en1.GetType(), en2.GetType());

            await en1.DisposeAsync();
            await en2.DisposeAsync();

            var res1 = await enumerable.ToList();
            var res2 = await enumerable.ToList();

            res1.ShouldAllBeEquivalentTo(res2);
        }

        /// <summary>
        /// Creates an async sequence whose enumerators return the same faulted
        /// <see cref="ValueTask{TResult}"/> from every <c>MoveNextAsync</c> call,
        /// surfacing <paramref name="exception"/>.
        /// </summary>
        /// <typeparam name="TValue">Element type of the (never-produced) sequence.</typeparam>
        /// <param name="exception">The exception every <c>MoveNextAsync</c> faults with.</param>
        /// <exception cref="ArgumentNullException"><paramref name="exception"/> is <c>null</c>.</exception>
        protected static IAsyncEnumerable<TValue> Throw<TValue>(Exception exception)
        {
            if (exception == null)
            {
                throw new ArgumentNullException(nameof(exception));
            }

#if NO_TASK_FROMEXCEPTION
            // Presumably defined for targets lacking Task.FromException: fault a task manually.
            var tcs = new TaskCompletionSource<bool>();
            tcs.TrySetException(exception);
            var moveNextThrows = new ValueTask<bool>(tcs.Task);
#else
            var moveNextThrows = new ValueTask<bool>(Task.FromException<bool>(exception));
#endif

            // MoveNextAsync never succeeds, so no Current/dispose delegates are supplied
            // (presumably null means "not provided" to CreateEnumerator — verify there).
            return AsyncEnumerable.CreateEnumerable(
                _ => AsyncEnumerable.CreateEnumerator<TValue>(
                    () => moveNextThrows,
                    current: null,
                    dispose: null)
            );
        }

        /// <summary>
        /// Asserts that invoking <paramref name="a"/> throws <typeparamref name="TException"/>
        /// (or a derived type) and that <paramref name="assert"/> returns <c>true</c> for it.
        /// </summary>
        /// <typeparam name="TException">The expected exception type (derived types accepted).</typeparam>
        /// <param name="a">The action expected to throw.</param>
        /// <param name="assert">Predicate the thrown exception must satisfy.</param>
        private void AssertThrows<TException>(Action a, Func<TException, bool> assert)
            where TException : Exception
        {
            // ThrowsAny (not Throws) accepts derived exception types, and it fails with a
            // descriptive message when nothing is thrown.
            var exception = Assert.ThrowsAny<TException>(a);

            Assert.True(assert(exception));
        }
    }
}