ValueTaskHelpers.cs 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. using System;
  2. using System.Threading.Tasks;
  3. using Xunit.Sdk;
  namespace System.Linq
  {
      /// <summary>
      /// Test helpers that give <see cref="ValueTask{TResult}"/> the blocking, timeout-bounded
      /// wait available on <see cref="Task"/>.
      /// </summary>
      internal static class ValueTaskHelpers
      {
          /// <summary>
          /// Blocks the calling thread until <paramref name="task"/> completes or
          /// <paramref name="timeOut"/> milliseconds elapse.
          /// </summary>
          /// <remarks>
          /// The value task is converted with <c>AsTask()</c>. A value task should only be consumed
          /// once, so callers should not await <paramref name="task"/> again after calling this.
          /// </remarks>
          /// <typeparam name="T">The result type of the value task.</typeparam>
          /// <param name="task">The value task to wait on.</param>
          /// <param name="timeOut">Maximum time to wait, in milliseconds (-1 waits indefinitely).</param>
          /// <returns>
          /// <c>true</c> if the task completed within the timeout; <c>false</c> if the wait timed out.
          /// </returns>
          /// <exception cref="AggregateException">The task faulted or was canceled.</exception>
          /// <exception cref="ArgumentOutOfRangeException">
          /// <paramref name="timeOut"/> is negative and not -1.
          /// </exception>
          public static bool Wait<T>(this ValueTask<T> task, int timeOut)
          {
              // Propagate Task.Wait's completion flag instead of discarding it; otherwise a
              // wait that timed out is indistinguishable from one that completed.
              return task.AsTask().Wait(timeOut);
          }
      }
  }
  namespace Xunit
  {
      /// <summary>
      /// Exact-type exception assertions for code that returns <see cref="ValueTask"/> or
      /// <see cref="ValueTask{TResult}"/>.
      /// </summary>
      internal static class AssertX
      {
          /// <summary>
          /// Verifies that the exact exception is thrown (and not a derived exception type).
          /// </summary>
          /// <typeparam name="T">The type of the exception expected to be thrown.</typeparam>
          /// <param name="testCode">A delegate that produces the value task to be tested.</param>
          /// <returns>The exception that was thrown, when successful.</returns>
          /// <exception cref="ArgumentNullException">
          /// The returned task faults with this when <paramref name="testCode"/> is <c>null</c>.
          /// </exception>
          /// <exception cref="ThrowsException">
          /// Thrown when no exception was thrown, or when an exception of a different type was thrown.
          /// </exception>
          public static async Task<T> ThrowsAsync<T>(Func<ValueTask> testCode)
              where T : Exception
          {
              return (T)Throws(typeof(T), await RecordExceptionAsync(testCode));
          }

          /// <summary>
          /// Verifies that the exact exception is thrown (and not a derived exception type).
          /// </summary>
          /// <typeparam name="T">The type of the exception expected to be thrown.</typeparam>
          /// <param name="testCode">A delegate that produces the value task to be tested.</param>
          /// <returns>The exception that was thrown, when successful.</returns>
          /// <exception cref="ArgumentNullException">
          /// The returned task faults with this when <paramref name="testCode"/> is <c>null</c>.
          /// </exception>
          /// <exception cref="ThrowsException">
          /// Thrown when no exception was thrown, or when an exception of a different type was thrown.
          /// </exception>
          public static async Task<T> ThrowsAsync<T>(Func<ValueTask<bool>> testCode)
              where T : Exception
          {
              return (T)Throws(typeof(T), await RecordExceptionAsync(testCode));
          }

          /// <summary>
          /// Records any exception thrown by the given delegate or by the value task it returns.
          /// </summary>
          /// <param name="testCode">The delegate which may throw an exception.</param>
          /// <returns>The exception that was thrown by the code; <c>null</c> otherwise.</returns>
          private static async Task<Exception> RecordExceptionAsync(Func<ValueTask> testCode)
          {
              if (testCode == null)
              {
                  throw new ArgumentNullException(nameof(testCode));
              }

              try
              {
                  // Invoking the delegate inside the try also captures exceptions thrown
                  // synchronously before a value task is returned.
                  await testCode();
                  return null;
              }
              catch (Exception ex)
              {
                  return ex;
              }
          }

          /// <summary>
          /// Records any exception thrown by the given delegate or by the value task it returns.
          /// </summary>
          /// <typeparam name="T">The result type of the value task; the result itself is ignored.</typeparam>
          /// <param name="testCode">The delegate which may throw an exception.</param>
          /// <returns>The exception that was thrown by the code; <c>null</c> otherwise.</returns>
          private static async Task<Exception> RecordExceptionAsync<T>(Func<ValueTask<T>> testCode)
          {
              if (testCode == null)
              {
                  throw new ArgumentNullException(nameof(testCode));
              }

              try
              {
                  // Invoking the delegate inside the try also captures exceptions thrown
                  // synchronously before a value task is returned.
                  await testCode();
                  return null;
              }
              catch (Exception ex)
              {
                  return ex;
              }
          }

          /// <summary>
          /// Checks that <paramref name="exception"/> was recorded and is exactly of type
          /// <paramref name="exceptionType"/>.
          /// </summary>
          /// <param name="exceptionType">The exact exception type that is expected.</param>
          /// <param name="exception">The recorded exception, or <c>null</c> if none was thrown.</param>
          /// <returns><paramref name="exception"/>, when it matches.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="exceptionType"/> is <c>null</c>.</exception>
          /// <exception cref="ThrowsException">No exception was recorded, or its type differs.</exception>
          private static Exception Throws(Type exceptionType, Exception exception)
          {
              if (exceptionType == null)
              {
                  throw new ArgumentNullException(nameof(exceptionType));
              }

              if (exception == null)
              {
                  throw ThrowsException.ForNoException(exceptionType);
              }

              // Exact match only: an exception of a derived type is reported as incorrect.
              if (!exceptionType.Equals(exception.GetType()))
              {
                  throw ThrowsException.ForIncorrectExceptionType(exceptionType, exception);
              }

              return exception;
          }
      }
  }