瀏覽代碼

Update tests to work with ValueTask

Oren Novotny 7 年之前
父節點
當前提交
f24aa5d9c4

+ 1 - 1
Ix.NET/Source/System.Interactive.Async.Tests/AsyncTests.Bugs.cs

@@ -63,7 +63,7 @@ namespace Tests
             var ys = xs.Select(x => { if (x == 1) throw ex; return x; });
 
             var e = ys.GetAsyncEnumerator();
-            await Assert.ThrowsAsync<Exception>(() => e.MoveNextAsync());
+            await AssertX.ThrowsAsync<Exception>(() => e.MoveNextAsync());
 
             var result = await disposed.Task;
             Assert.True(result);

+ 5 - 0
Ix.NET/Source/System.Interactive.Async.Tests/System.Interactive.Async.Tests.csproj

@@ -27,6 +27,11 @@
 
   </ItemGroup>
 
+  <ItemGroup>
+    <Compile Include="..\System.Linq.Async.Tests\ValueTaskHelpers.cs" />
+    <Compile Include="..\System.Linq.Async.Tests\TaskExt.cs" />
+  </ItemGroup>
+
   <ItemGroup>
     <None Update="AsyncQueryableExTests.Generated.tt">
       <LastGenOutput>AsyncQueryableExTests.Generated.cs</LastGenOutput>

+ 2 - 2
Ix.NET/Source/System.Linq.Async.Tests/System/Linq/AsyncEnumerableTests.cs

@@ -88,9 +88,9 @@ namespace Tests
 #if NO_TASK_FROMEXCEPTION
             var tcs = new TaskCompletionSource<bool>();
             tcs.TrySetException(exception);
-            var moveNextThrows = tcs.Task;
+            var moveNextThrows = new ValueTask<bool>(tcs.Task);
 #else
-            var moveNextThrows = Task.FromException<bool>(exception);
+            var moveNextThrows = new ValueTask<bool>(Task.FromException<bool>(exception));
 #endif
 
             return AsyncEnumerable.CreateEnumerable(

+ 0 - 21
Ix.NET/Source/System.Linq.Async.Tests/System/Linq/AsyncEnumeratorTests.cs

@@ -1,21 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the Apache 2.0 License.
-// See the LICENSE file in the project root for more information. 
-
-using System;
-using System.Collections.Generic;
-using Xunit;
-
-namespace Tests
-{
-    public class AsyncEnumeratorTests
-    {
-        [Fact]
-        public void MoveNextExtension_Null()
-        {
-            var en = default(IAsyncEnumerator<int>);
-
-            Assert.ThrowsAsync<ArgumentNullException>(() => en.MoveNextAsync());
-        }
-    }
-}

+ 2 - 2
Ix.NET/Source/System.Linq.Async.Tests/System/Linq/Operators/CreateEnumerator.cs

@@ -15,13 +15,13 @@ namespace Tests
         [Fact]
         public void CreateEnumerator_Null()
         {
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.CreateEnumerator<int>(default(Func<Task<bool>>), () => 3, () => Task.FromResult(true)));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.CreateEnumerator<int>(default(Func<ValueTask<bool>>), () => 3, () => TaskExt.CompletedTask));
         }
 
         [Fact]
         public void CreateEnumerator_Throws()
         {
-            var iter = AsyncEnumerable.CreateEnumerator<int>(() => Task.FromResult(true), () => 3, () => Task.FromResult(true));
+            var iter = AsyncEnumerable.CreateEnumerator<int>(() => TaskExt.True, () => 3, () => TaskExt.CompletedTask);
 
             var enu = (IAsyncEnumerable<int>)iter;
 

+ 1 - 1
Ix.NET/Source/System.Linq.Async.Tests/System/Linq/Operators/ToEnumerable.cs

@@ -35,7 +35,7 @@ namespace Tests
         {
             var ex = new Exception("Bang");
             var xs = Throw<int>(ex).ToEnumerable();
-            AssertThrows<Exception>(() => xs.GetEnumerator().MoveNext(), SingleInnerExceptionMatches(ex));
+            Assert.Throws<Exception>(() => xs.GetEnumerator().MoveNext());
         }
     }
 }

+ 4 - 4
Ix.NET/Source/System.Linq.Async.Tests/System/Linq/Operators/ToObservable.cs

@@ -143,9 +143,9 @@ namespace Tests
 
             var ae = AsyncEnumerable.CreateEnumerable(
                 () => AsyncEnumerable.CreateEnumerator<int>(
-                    () => Task.FromResult(false),
+                    () => TaskExt.False,
                     () => { throw new InvalidOperationException(); },
-                    () => { evt.Set(); return Task.FromResult(true); }));
+                    () => { evt.Set(); return TaskExt.CompletedTask; }));
 
             ae
                 .ToObservable()
@@ -186,7 +186,7 @@ namespace Tests
                     () =>
                     {
                         evt.Set();
-                        return Task.FromResult(true);
+                        return TaskExt.CompletedTask;
                     }));
 
             subscription = ae
@@ -234,7 +234,7 @@ namespace Tests
                     () =>
                     {
                         evt.Set();
-                        return Task.FromResult(true);
+                        return TaskExt.CompletedTask;
                     }));
 
             subscription = ae

+ 14 - 0
Ix.NET/Source/System.Linq.Async.Tests/TaskExt.cs

@@ -0,0 +1,14 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the Apache 2.0 License.
+// See the LICENSE file in the project root for more information. 
+
+namespace System.Threading.Tasks
+{
+    internal static class TaskExt
+    {
+        public static readonly ValueTask<bool> True = new ValueTask<bool>(true);
+        public static readonly ValueTask<bool> False = new ValueTask<bool>(false);
+        public static readonly ValueTask CompletedTask = new ValueTask(Task.FromResult(true));
+        public static readonly ValueTask<bool> Never = new ValueTask<bool>(new TaskCompletionSource<bool>().Task);
+    }
+}

+ 111 - 0
Ix.NET/Source/System.Linq.Async.Tests/ValueTaskHelpers.cs

@@ -0,0 +1,111 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.Threading.Tasks;
+using Xunit.Sdk;
+
+namespace System.Linq
+{
+    static class ValueTaskHelpers
+    {
+        public static void Wait<T>(this ValueTask<T> task, int timeOut)
+        {
+            task.AsTask().Wait(timeOut);
+        }
+    }
+}
+
+namespace Xunit
+{
+    static class AssertX
+    {
+        /// <summary>
+        /// Verifies that the exact exception is thrown (and not a derived exception type).
+        /// </summary>
+        /// <typeparam name="T">The type of the exception expected to be thrown</typeparam>
+        /// <param name="testCode">A delegate to the task to be tested</param>
+        /// <returns>The exception that was thrown, when successful</returns>
+        /// <exception cref="ThrowsException">Thrown when an exception was not thrown, or when an exception of the incorrect type is thrown</exception>
+        public static async Task<T> ThrowsAsync<T>(Func<ValueTask> testCode)
+            where T : Exception
+        {
+            return (T)Throws(typeof(T), await RecordExceptionAsync(testCode));
+        }
+
+        /// <summary>
+        /// Verifies that the exact exception is thrown (and not a derived exception type).
+        /// </summary>
+        /// <typeparam name="T">The type of the exception expected to be thrown</typeparam>
+        /// <param name="testCode">A delegate to the task to be tested</param>
+        /// <returns>The exception that was thrown, when successful</returns>
+        /// <exception cref="ThrowsException">Thrown when an exception was not thrown, or when an exception of the incorrect type is thrown</exception>
+        public static async Task<T> ThrowsAsync<T>(Func<ValueTask<bool>> testCode)
+            where T : Exception
+        {
+            return (T)Throws(typeof(T), await RecordExceptionAsync(testCode));
+        }
+
+        /// <summary>
+        /// Records any exception which is thrown by the given task.
+        /// </summary>
+        /// <param name="testCode">The task which may thrown an exception.</param>
+        /// <returns>Returns the exception that was thrown by the code; null, otherwise.</returns>
+        static async Task<Exception> RecordExceptionAsync(Func<ValueTask> testCode)
+        {
+            if (testCode == null)
+            {
+                throw new ArgumentNullException(nameof(testCode));
+            }
+
+            try
+            {
+                await testCode();
+                return null;
+            }
+            catch (Exception ex)
+            {
+                return ex;
+            }
+        }
+
+        /// <summary>
+        /// Records any exception which is thrown by the given task.
+        /// </summary>
+        /// <param name="testCode">The task which may thrown an exception.</param>
+        /// <returns>Returns the exception that was thrown by the code; null, otherwise.</returns>
+        static async Task<Exception> RecordExceptionAsync<T>(Func<ValueTask<T>> testCode)
+        {
+            if (testCode == null)
+            {
+                throw new ArgumentNullException(nameof(testCode));
+            }
+
+            try
+            {
+                await testCode();
+                return null;
+            }
+            catch (Exception ex)
+            {
+                return ex;
+            }
+        }
+
+        static Exception Throws(Type exceptionType, Exception exception)
+        {
+            if (exceptionType == null)
+            {
+                throw new ArgumentNullException(nameof(exceptionType));
+            }
+
+            if (exception == null)
+                throw new ThrowsException(exceptionType);
+
+            if (!exceptionType.Equals(exception.GetType()))
+                throw new ThrowsException(exceptionType, exception);
+
+            return exception;
+        }
+    }
+
+}