Browse Source

Find inherited TryParse and BindAsync (#36688)

Brennan 4 years ago
parent
commit
f2c7b53b32

+ 127 - 0
src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs

@@ -77,6 +77,7 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
         [Theory]
         [InlineData(typeof(TryParseStringRecord))]
         [InlineData(typeof(TryParseStringStruct))]
+        [InlineData(typeof(TryParseInheritClassWithFormatProvider))]
         public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureCustomType(Type type)
         {
             var methodFound = new ParameterBindingMethodCache().FindTryParseMethod(@type);
@@ -94,6 +95,24 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
             Assert.True(((call.Arguments[1] as ConstantExpression)!.Value as CultureInfo)!.Equals(CultureInfo.InvariantCulture));
         }
 
+        [Theory]
+        [InlineData(typeof(TryParseNoFormatProviderRecord))]
+        [InlineData(typeof(TryParseNoFormatProviderStruct))]
+        [InlineData(typeof(TryParseInheritClass))]
+        public void FindTryParseMethod_WithNoFormatProvider(Type type)
+        {
+            var methodFound = new ParameterBindingMethodCache().FindTryParseMethod(@type);
+            Assert.NotNull(methodFound);
+
+            var call = methodFound!(Expression.Variable(type, "parsedValue")) as MethodCallExpression;
+            Assert.NotNull(call);
+            var parameters = call!.Method.GetParameters();
+
+            Assert.Equal(2, parameters.Length);
+            Assert.Equal(typeof(string), parameters[0].ParameterType);
+            Assert.True(parameters[1].IsOut);
+        }
+
         public static IEnumerable<object[]> TryParseStringParameterInfoData
         {
             get
@@ -249,6 +268,14 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
                     new[]
                     {
                         GetFirstParameter((BindAsyncSingleArgStruct arg) => BindAsyncSingleArgStructMethod(arg)),
+                    },
+                    new[]
+                    {
+                        GetFirstParameter((InheritBindAsync arg) => InheritBindAsyncMethod(arg))
+                    },
+                    new[]
+                    {
+                        GetFirstParameter((InheritBindAsyncWithParameterInfo arg) => InheritBindAsyncWithParameterInfoMethod(arg))
                     }
                 };
             }
@@ -285,6 +312,7 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
         [InlineData(typeof(InvalidTooFewArgsTryParseClass))]
         [InlineData(typeof(InvalidNonStaticTryParseStruct))]
         [InlineData(typeof(InvalidNonStaticTryParseClass))]
+        [InlineData(typeof(TryParseWrongTypeInheritClass))]
         public void FindTryParseMethod_ThrowsIfInvalidTryParseOnType(Type type)
         {
             var ex = Assert.Throws<InvalidOperationException>(
@@ -308,6 +336,8 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
         [InlineData(typeof(InvalidWrongReturnBindAsyncClass))]
         [InlineData(typeof(InvalidWrongParamBindAsyncStruct))]
         [InlineData(typeof(InvalidWrongParamBindAsyncClass))]
+        [InlineData(typeof(BindAsyncWrongTypeInherit))]
+        [InlineData(typeof(BindAsyncWithParameterInfoWrongTypeInherit))]
         public void FindBindAsyncMethod_ThrowsIfInvalidBindAsyncOnType(Type type)
         {
             var cache = new ParameterBindingMethodCache();
@@ -350,6 +380,8 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
 
         private static void BindAsyncSingleArgRecordMethod(BindAsyncSingleArgRecord arg) { }
         private static void BindAsyncSingleArgStructMethod(BindAsyncSingleArgStruct arg) { }
+        private static void InheritBindAsyncMethod(InheritBindAsync arg) { }
+        private static void InheritBindAsyncWithParameterInfoMethod(InheritBindAsyncWithParameterInfo args) { }
 
         private static ParameterInfo GetFirstParameter<T>(Expression<Action<T>> expr)
         {
@@ -538,6 +570,67 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
             }
         }
 
+        private record TryParseNoFormatProviderRecord(int Value)
+        {
+            public static bool TryParse(string? value, out TryParseNoFormatProviderRecord? result)
+            {
+                if (!int.TryParse(value, out var val))
+                {
+                    result = null;
+                    return false;
+                }
+
+                result = new TryParseNoFormatProviderRecord(val);
+                return true;
+            }
+        }
+
+        private record struct TryParseNoFormatProviderStruct(int Value)
+        {
+            public static bool TryParse(string? value, out TryParseNoFormatProviderStruct result)
+            {
+                if (!int.TryParse(value, out var val))
+                {
+                    result = default;
+                    return false;
+                }
+
+                result = new TryParseNoFormatProviderStruct(val);
+                return true;
+            }
+        }
+
+        private class BaseTryParseClass<T>
+        {
+            public static bool TryParse(string? value, out T? result)
+            {
+                result = default(T);
+                return false;
+            }
+        }
+
+        private class TryParseInheritClass : BaseTryParseClass<TryParseInheritClass>
+        {
+        }
+
+        // using wrong T on purpose
+        private class TryParseWrongTypeInheritClass : BaseTryParseClass<TryParseInheritClass>
+        {
+        }
+
+        private class BaseTryParseClassWithFormatProvider<T>
+        {
+            public static bool TryParse(string? value, IFormatProvider formatProvider, out T? result)
+            {
+                result = default(T);
+                return false;
+            }
+        }
+
+        private class TryParseInheritClassWithFormatProvider : BaseTryParseClassWithFormatProvider<TryParseInheritClassWithFormatProvider>
+        {
+        }
+
         private record BindAsyncRecord(int Value)
         {
             public static ValueTask<BindAsyncRecord?> BindAsync(HttpContext context, ParameterInfo parameter)
@@ -644,6 +737,40 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
                 throw new NotImplementedException();
         }
 
+        private class BaseBindAsync<T>
+        {
+            public static ValueTask<T?> BindAsync(HttpContext context)
+            {
+                return new(default(T));
+            }
+        }
+
+        private class InheritBindAsync : BaseBindAsync<InheritBindAsync>
+        {
+        }
+
+        // Using wrong T on purpose
+        private class BindAsyncWrongTypeInherit : BaseBindAsync<InheritBindAsync>
+        {
+        }
+
+        private class BaseBindAsyncWithParameterInfo<T>
+        {
+            public static ValueTask<T?> BindAsync(HttpContext context, ParameterInfo parameter)
+            {
+                return new(default(T));
+            }
+        }
+
+        private class InheritBindAsyncWithParameterInfo : BaseBindAsyncWithParameterInfo<InheritBindAsyncWithParameterInfo>
+        {
+        }
+
+        // Using wrong T on purpose
+        private class BindAsyncWithParameterInfoWrongTypeInherit : BaseBindAsyncWithParameterInfo<InheritBindAsync>
+        {
+        }
+
         private class MockParameterInfo : ParameterInfo
         {
             public MockParameterInfo(Type type, string name)

+ 6 - 6
src/Shared/ParameterBindingMethodCache.cs

@@ -106,7 +106,7 @@ namespace Microsoft.AspNetCore.Http
                         expression);
                 }
 
-                methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static, new[] { typeof(string), typeof(IFormatProvider), type.MakeByRefType() });
+                methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(string), typeof(IFormatProvider), type.MakeByRefType() });
 
                 if (methodInfo is not null && methodInfo.ReturnType == typeof(bool))
                 {
@@ -117,14 +117,14 @@ namespace Microsoft.AspNetCore.Http
                         expression);
                 }
 
-                methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static, new[] { typeof(string), type.MakeByRefType() });
+                methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(string), type.MakeByRefType() });
 
                 if (methodInfo is not null && methodInfo.ReturnType == typeof(bool))
                 {
                     return (expression) => Expression.Call(methodInfo, TempSourceStringExpr, expression);
                 }
 
-                if (type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance) is MethodInfo invalidMethod)
+                if (type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance | BindingFlags.FlattenHierarchy) is MethodInfo invalidMethod)
                 {
                     var stringBuilder = new StringBuilder();
                     stringBuilder.AppendLine(CultureInfo.InvariantCulture, $"TryParse method found on {TypeNameHelper.GetTypeDisplayName(type, fullName: false)} with incorrect format. Must be a static method with format");
@@ -149,11 +149,11 @@ namespace Microsoft.AspNetCore.Http
             {
                 var hasParameterInfo = true;
                 // There should only be one BindAsync method with these parameters since C# does not allow overloading on return type.
-                var methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext), typeof(ParameterInfo) });
+                var methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(HttpContext), typeof(ParameterInfo) });
                 if (methodInfo is null)
                 {
                     hasParameterInfo = false;
-                    methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext) });
+                    methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(HttpContext) });
                 }
 
                 // We're looking for a method with the following signatures:
@@ -207,7 +207,7 @@ namespace Microsoft.AspNetCore.Http
                     }
                 }
 
-                if (nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance) is MethodInfo invalidBindMethod)
+                if (nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance | BindingFlags.FlattenHierarchy) is MethodInfo invalidBindMethod)
                 {
                     var stringBuilder = new StringBuilder();
                     stringBuilder.AppendLine(CultureInfo.InvariantCulture, $"BindAsync method found on {TypeNameHelper.GetTypeDisplayName(nonNullableParameterType, fullName: false)} with incorrect format. Must be a static method with format");