Browse Source

Add support for BindAsync without ParameterInfo (#36505)

Brennan 4 years ago
parent
commit
f6efa13e71

+ 5 - 4
src/Http/Http.Extensions/src/RequestDelegateFactory.cs

@@ -840,12 +840,12 @@ namespace Microsoft.AspNetCore.Http
             var isOptional = IsOptionalParameter(parameter, factoryContext);
 
             // Get the BindAsync method for the type.
-            var bindAsyncExpression = ParameterBindingMethodCache.FindBindAsyncMethod(parameter);
+            var bindAsyncMethod = ParameterBindingMethodCache.FindBindAsyncMethod(parameter);
             // We know BindAsync exists because there's no way to opt-in without defining the method on the type.
-            Debug.Assert(bindAsyncExpression is not null);
+            Debug.Assert(bindAsyncMethod.Expression is not null);
 
             // Compile the delegate to the BindAsync method for this parameter index
-            var bindAsyncDelegate = Expression.Lambda<Func<HttpContext, ValueTask<object?>>>(bindAsyncExpression, HttpContextExpr).Compile();
+            var bindAsyncDelegate = Expression.Lambda<Func<HttpContext, ValueTask<object?>>>(bindAsyncMethod.Expression, HttpContextExpr).Compile();
             factoryContext.ParameterBinders.Add(bindAsyncDelegate);
 
             // boundValues[index]
@@ -854,6 +854,7 @@ namespace Microsoft.AspNetCore.Http
             if (!isOptional)
             {
                 var typeName = TypeNameHelper.GetTypeDisplayName(parameter.ParameterType, fullName: false);
+                var message = bindAsyncMethod.ParamCount == 2 ? $"{typeName}.BindAsync(HttpContext, ParameterInfo)" : $"{typeName}.BindAsync(HttpContext)";
                 var checkRequiredBodyBlock = Expression.Block(
                         Expression.IfThen(
                         Expression.Equal(boundValueExpr, Expression.Constant(null)),
@@ -863,7 +864,7 @@ namespace Microsoft.AspNetCore.Http
                                         HttpContextExpr,
                                         Expression.Constant(typeName),
                                         Expression.Constant(parameter.Name),
-                                        Expression.Constant($"{typeName}.BindAsync(HttpContext, ParameterInfo)"),
+                                        Expression.Constant(message),
                                         Expression.Constant(factoryContext.ThrowOnBadRequest))
                             )
                         )

+ 79 - 2
src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs

@@ -173,12 +173,13 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
             var parameter = new MockParameterInfo(type, "bindAsyncRecord");
             var methodFound = cache.FindBindAsyncMethod(parameter);
 
-            Assert.NotNull(methodFound);
+            Assert.NotNull(methodFound.Expression);
+            Assert.Equal(2, methodFound.ParamCount);
 
             var parsedValue = Expression.Variable(type, "parsedValue");
 
             var parseHttpContext = Expression.Lambda<Func<HttpContext, ValueTask<object>>>(
-                Expression.Block(new[] { parsedValue }, methodFound!),
+                Expression.Block(new[] { parsedValue }, methodFound.Expression!),
                 ParameterBindingMethodCache.HttpContextExpr).Compile();
 
             var httpContext = new DefaultHttpContext
@@ -195,6 +196,37 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
             Assert.Equal(new BindAsyncRecord(42), await parseHttpContext(httpContext));
         }
 
+        [Fact]
+        public async Task FindBindAsyncMethod_FindsSingleArgBindAsync()
+        {
+            var type = typeof(BindAsyncSingleArgStruct);
+            var cache = new ParameterBindingMethodCache();
+            var parameter = new MockParameterInfo(type, "bindAsyncSingleArgStruct");
+            var methodFound = cache.FindBindAsyncMethod(parameter);
+
+            Assert.NotNull(methodFound.Expression);
+            Assert.Equal(1, methodFound.ParamCount);
+
+            var parsedValue = Expression.Variable(type, "parsedValue");
+
+            var parseHttpContext = Expression.Lambda<Func<HttpContext, ValueTask<object>>>(
+                Expression.Block(new[] { parsedValue }, methodFound.Expression!),
+                ParameterBindingMethodCache.HttpContextExpr).Compile();
+
+            var httpContext = new DefaultHttpContext
+            {
+                Request =
+                {
+                    Headers =
+                    {
+                        ["ETag"] = "42",
+                    },
+                },
+            };
+
+            Assert.Equal(new BindAsyncSingleArgStruct(42), await parseHttpContext(httpContext));
+        }
+
         public static IEnumerable<object[]> BindAsyncParameterInfoData
         {
             get
@@ -209,6 +241,14 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
                     {
                         GetFirstParameter((BindAsyncStruct arg) => BindAsyncStructMethod(arg)),
                     },
+                    new[]
+                    {
+                        GetFirstParameter((BindAsyncSingleArgRecord arg) => BindAsyncSingleArgRecordMethod(arg)),
+                    },
+                    new[]
+                    {
+                        GetFirstParameter((BindAsyncSingleArgStruct arg) => BindAsyncSingleArgStructMethod(arg)),
+                    }
                 };
             }
         }
@@ -250,6 +290,11 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
         private static void BindAsyncNullableStructMethod(BindAsyncStruct? arg) { }
         private static void NullableReturningBindAsyncStructMethod(NullableReturningBindAsyncStruct arg) { }
 
+        private static void BindAsyncSingleArgRecordMethod(BindAsyncSingleArgRecord arg) { }
+        private static void BindAsyncSingleArgStructMethod(BindAsyncSingleArgStruct arg) { }
+        private static void BindAsyncNullableSingleArgStructMethod(BindAsyncSingleArgStruct? arg) { }
+        private static void NullableReturningBindAsyncSingleArgStructMethod(NullableReturningBindAsyncSingleArgStruct arg) { }
+
         private static ParameterInfo GetFirstParameter<T>(Expression<Action<T>> expr)
         {
             var mc = (MethodCallExpression)expr.Body;
@@ -324,6 +369,38 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
                 throw new NotImplementedException();
         }
 
+        private record BindAsyncSingleArgRecord(int Value)
+        {
+            public static ValueTask<BindAsyncSingleArgRecord?> BindAsync(HttpContext context)
+            {
+                if (!int.TryParse(context.Request.Headers.ETag, out var val))
+                {
+                    return new(result: null);
+                }
+
+                return new(result: new(val));
+            }
+        }
+
+        private record struct BindAsyncSingleArgStruct(int Value)
+        {
+            public static ValueTask<BindAsyncSingleArgStruct> BindAsync(HttpContext context)
+            {
+                if (!int.TryParse(context.Request.Headers.ETag, out var val))
+                {
+                    throw new BadHttpRequestException("The request is missing the required ETag header.");
+                }
+
+                return new(result: new(val));
+            }
+        }
+
+        private record struct NullableReturningBindAsyncSingleArgStruct(int Value)
+        {
+            public static ValueTask<NullableReturningBindAsyncStruct?> BindAsync(HttpContext context, ParameterInfo parameter) =>
+                throw new NotImplementedException();
+        }
+
         private class MockParameterInfo : ParameterInfo
         {
             public MockParameterInfo(Type type, string name)

+ 141 - 3
src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs

@@ -612,6 +612,54 @@ namespace Microsoft.AspNetCore.Routing.Internal
             }
         }
 
+        private record struct MyBothBindAsyncStruct(Uri Uri)
+        {
+            public static ValueTask<MyBothBindAsyncStruct> BindAsync(HttpContext context, ParameterInfo parameter)
+            {
+                Assert.True(parameter.ParameterType == typeof(MyBothBindAsyncStruct) || parameter.ParameterType == typeof(MyBothBindAsyncStruct?));
+                Assert.Equal("myBothBindAsyncStruct", parameter.Name);
+
+                if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri))
+                {
+                    throw new BadHttpRequestException("The request is missing the required Referer header.");
+                }
+
+                return new(result: new(uri));
+            }
+
+            // BindAsync with ParameterInfo is preferred
+            public static ValueTask<MyBothBindAsyncStruct> BindAsync(HttpContext context)
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+        private record struct MySimpleBindAsyncStruct(Uri Uri)
+        {
+            public static ValueTask<MySimpleBindAsyncStruct> BindAsync(HttpContext context)
+            {
+                if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri))
+                {
+                    throw new BadHttpRequestException("The request is missing the required Referer header.");
+                }
+
+                return new(result: new(uri));
+            }
+        }
+
+        private record MySimpleBindAsyncRecord(Uri Uri)
+        {
+            public static ValueTask<MySimpleBindAsyncRecord?> BindAsync(HttpContext context)
+            {
+                if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri))
+                {
+                    return new(result: null);
+                }
+
+                return new(result: new(uri));
+            }
+        }
+
         [Theory]
         [MemberData(nameof(TryParsableParameters))]
         public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromRouteValue(Delegate action, string? routeValue, object? expectedParameterValue)
@@ -724,6 +772,24 @@ namespace Microsoft.AspNetCore.Routing.Internal
             Assert.Equal(new MyBindAsyncStruct(new Uri("https://example.org")), httpContext.Items["myBindAsyncStruct"]);
         }
 
+        [Fact]
+        public async Task RequestDelegateUsesParameterInfoBindAsyncOverOtherBindAsync()
+        {
+            var httpContext = CreateHttpContext();
+
+            httpContext.Request.Headers.Referer = "https://example.org";
+
+            var resultFactory = RequestDelegateFactory.Create((HttpContext httpContext, MyBothBindAsyncStruct? myBothBindAsyncStruct) =>
+            {
+                httpContext.Items["myBothBindAsyncStruct"] = myBothBindAsyncStruct;
+            });
+
+            var requestDelegate = resultFactory.RequestDelegate;
+            await requestDelegate(httpContext);
+
+            Assert.Equal(new MyBothBindAsyncStruct(new Uri("https://example.org")), httpContext.Items["myBothBindAsyncStruct"]);
+        }
+
         [Fact]
         public async Task RequestDelegateUsesTryParseOverBindAsyncGivenExplicitAttribute()
         {
@@ -873,7 +939,7 @@ namespace Microsoft.AspNetCore.Routing.Internal
         [Fact]
         public async Task RequestDelegateLogsBindAsyncFailuresAndSets400Response()
         {
-            // Not supplying any headers will cause the HttpContext TryParse overload to fail.
+            // Not supplying any headers will cause the HttpContext BindAsync overload to return null.
             var httpContext = CreateHttpContext();
             var invoked = false;
 
@@ -905,7 +971,7 @@ namespace Microsoft.AspNetCore.Routing.Internal
         [Fact]
         public async Task RequestDelegateLogsBindAsyncFailuresAndThrowsIfThrowOnBadRequest()
         {
-            // Not supplying any headers will cause the HttpContext TryParse overload to fail.
+            // Not supplying any headers will cause the HttpContext BindAsync overload to return null.
             var httpContext = CreateHttpContext();
             var invoked = false;
 
@@ -931,10 +997,72 @@ namespace Microsoft.AspNetCore.Routing.Internal
             Assert.Equal(400, badHttpRequestException.StatusCode);
         }
 
+        [Fact]
+        public async Task RequestDelegateLogsSingleArgBindAsyncFailuresAndSets400Response()
+        {
+            // Not supplying any headers will cause the HttpContext BindAsync overload to return null.
+            var httpContext = CreateHttpContext();
+            var invoked = false;
+
+            var factoryResult = RequestDelegateFactory.Create((MySimpleBindAsyncRecord mySimpleBindAsyncRecord1,
+                MySimpleBindAsyncRecord mySimpleBindAsyncRecord2) =>
+            {
+                invoked = true;
+            });
+
+            var requestDelegate = factoryResult.RequestDelegate;
+            await requestDelegate(httpContext);
+
+            Assert.False(invoked);
+            Assert.False(httpContext.RequestAborted.IsCancellationRequested);
+            Assert.Equal(400, httpContext.Response.StatusCode);
+
+            var logs = TestSink.Writes.ToArray();
+
+            Assert.Equal(2, logs.Length);
+
+            Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), logs[0].EventId);
+            Assert.Equal(LogLevel.Debug, logs[0].LogLevel);
+            Assert.Equal(@"Required parameter ""MySimpleBindAsyncRecord mySimpleBindAsyncRecord1"" was not provided from MySimpleBindAsyncRecord.BindAsync(HttpContext).", logs[0].Message);
+
+            Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), logs[1].EventId);
+            Assert.Equal(LogLevel.Debug, logs[1].LogLevel);
+            Assert.Equal(@"Required parameter ""MySimpleBindAsyncRecord mySimpleBindAsyncRecord2"" was not provided from MySimpleBindAsyncRecord.BindAsync(HttpContext).", logs[1].Message);
+        }
+
+        [Fact]
+        public async Task RequestDelegateLogsSingleArgBindAsyncFailuresAndThrowsIfThrowOnBadRequest()
+        {
+            // Not supplying any headers will cause the HttpContext BindAsync overload to return null.
+            var httpContext = CreateHttpContext();
+            var invoked = false;
+
+            var factoryResult = RequestDelegateFactory.Create((MySimpleBindAsyncRecord mySimpleBindAsyncRecord1,
+                MySimpleBindAsyncRecord mySimpleBindAsyncRecord2) =>
+            {
+                invoked = true;
+            }, new() { ThrowOnBadRequest = true });
+
+            var requestDelegate = factoryResult.RequestDelegate;
+            var badHttpRequestException = await Assert.ThrowsAsync<BadHttpRequestException>(() => requestDelegate(httpContext));
+
+            Assert.False(invoked);
+
+            // The httpContext should be untouched.
+            Assert.False(httpContext.RequestAborted.IsCancellationRequested);
+            Assert.Equal(200, httpContext.Response.StatusCode);
+            Assert.False(httpContext.Response.HasStarted);
+
+            // We don't log bad requests when we throw.
+            Assert.Empty(TestSink.Writes);
+
+            Assert.Equal(@"Required parameter ""MySimpleBindAsyncRecord mySimpleBindAsyncRecord1"" was not provided from MySimpleBindAsyncRecord.BindAsync(HttpContext).", badHttpRequestException.Message);
+            Assert.Equal(400, badHttpRequestException.StatusCode);
+        }
+
         [Fact]
         public async Task BindAsyncExceptionsAreUncaught()
         {
-            // Not supplying any headers will cause the HttpContext BindAsync overload to fail.
             var httpContext = CreateHttpContext();
 
             var factoryResult = RequestDelegateFactory.Create((MyBindAsyncTypeThatThrows arg1) => { });
@@ -2239,6 +2367,10 @@ namespace Microsoft.AspNetCore.Routing.Internal
                 {
                     context.Items["uri"] = myBindAsyncRecord?.Uri;
                 }
+                void requiredReferenceTypeSimple(HttpContext context, MySimpleBindAsyncRecord mySimpleBindAsyncRecord)
+                {
+                    context.Items["uri"] = mySimpleBindAsyncRecord.Uri;
+                }
 
 
                 void requiredValueType(HttpContext context, MyNullableBindAsyncStruct myNullableBindAsyncStruct)
@@ -2253,11 +2385,16 @@ namespace Microsoft.AspNetCore.Routing.Internal
                 {
                     context.Items["uri"] = myNullableBindAsyncStruct?.Uri;
                 }
+                void requiredValueTypeSimple(HttpContext context, MySimpleBindAsyncStruct mySimpleBindAsyncStruct)
+                {
+                    context.Items["uri"] = mySimpleBindAsyncStruct.Uri;
+                }
 
                 return new object?[][]
                 {
                     new object?[] { (Action<HttpContext, MyBindAsyncRecord>)requiredReferenceType, false, true, false },
                     new object?[] { (Action<HttpContext, MyBindAsyncRecord>)requiredReferenceType, true, false, false, },
+                    new object?[] { (Action<HttpContext, MySimpleBindAsyncRecord>)requiredReferenceTypeSimple, true, false, false },
 
                     new object?[] { (Action<HttpContext, MyBindAsyncRecord?>)defaultReferenceType, false, false, false, },
                     new object?[] { (Action<HttpContext, MyBindAsyncRecord?>)defaultReferenceType, true, false, false },
@@ -2267,6 +2404,7 @@ namespace Microsoft.AspNetCore.Routing.Internal
 
                     new object?[] { (Action<HttpContext, MyNullableBindAsyncStruct>)requiredValueType, false, true, true },
                     new object?[] { (Action<HttpContext, MyNullableBindAsyncStruct>)requiredValueType, true, false, true },
+                    new object?[] { (Action<HttpContext, MySimpleBindAsyncStruct>)requiredValueTypeSimple, true, false, true },
 
                     new object?[] { (Action<HttpContext, MyNullableBindAsyncStruct?>)defaultValueType, false, false, true },
                     new object?[] { (Action<HttpContext, MyNullableBindAsyncStruct?>)defaultValueType, true, false, true },

+ 39 - 16
src/Shared/ParameterBindingMethodCache.cs

@@ -26,7 +26,7 @@ namespace Microsoft.AspNetCore.Http
 
         // Since this is shared source, the cache won't be shared between RequestDelegateFactory and the ApiDescriptionProvider sadly :(
         private readonly ConcurrentDictionary<Type, Func<ParameterExpression, Expression>?> _stringMethodCallCache = new();
-        private readonly ConcurrentDictionary<Type, Func<ParameterInfo, Expression>?> _bindAsyncMethodCallCache = new();
+        private readonly ConcurrentDictionary<Type, (Func<ParameterInfo, Expression>?, int)> _bindAsyncMethodCallCache = new();
 
         // If IsDynamicCodeSupported is false, we can't use the static Enum.TryParse<T> since there's no easy way for
         // this code to generate the specific instantiation for any enums used
@@ -47,7 +47,7 @@ namespace Microsoft.AspNetCore.Http
         }
 
         public bool HasBindAsyncMethod(ParameterInfo parameter) =>
-            FindBindAsyncMethod(parameter) is not null;
+            FindBindAsyncMethod(parameter).Expression is not null;
 
         public Func<ParameterExpression, Expression>? FindTryParseMethod(Type type)
         {
@@ -128,12 +128,18 @@ namespace Microsoft.AspNetCore.Http
             return _stringMethodCallCache.GetOrAdd(type, Finder);
         }
 
-        public Expression? FindBindAsyncMethod(ParameterInfo parameter)
+        public (Expression? Expression, int ParamCount) FindBindAsyncMethod(ParameterInfo parameter)
         {
-            static Func<ParameterInfo, Expression>? Finder(Type nonNullableParameterType)
+            static (Func<ParameterInfo, Expression>?, int) Finder(Type nonNullableParameterType)
             {
+                var hasParameterInfo = true;
                 // There should only be one BindAsync method with these parameters since C# does not allow overloading on return type.
                 var methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext), typeof(ParameterInfo) });
+                if (methodInfo is null)
+                {
+                    hasParameterInfo = false;
+                    methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext) });
+                }
 
                 // We're looking for a method with the following signatures:
                 // public static ValueTask<{type}> BindAsync(HttpContext context, ParameterInfo parameter)
@@ -147,34 +153,51 @@ namespace Microsoft.AspNetCore.Http
                     // ValueTask<{type}>?
                     if (valueTaskResultType == nonNullableParameterType)
                     {
-                        return (parameter) =>
+                        return ((parameter) =>
                         {
-                            // parameter is being intentionally shadowed. We never want to use the outer ParameterInfo inside
-                            // this Func because the ParameterInfo varies after it's been cached for a given parameter type.
-                            var typedCall = Expression.Call(methodInfo, HttpContextExpr, Expression.Constant(parameter));
+                            MethodCallExpression typedCall;
+                            if (hasParameterInfo)
+                            {
+                                // parameter is being intentionally shadowed. We never want to use the outer ParameterInfo inside
+                                // this Func because the ParameterInfo varies after it's been cached for a given parameter type.
+                                typedCall = Expression.Call(methodInfo, HttpContextExpr, Expression.Constant(parameter));
+                            }
+                            else
+                            {
+                                typedCall = Expression.Call(methodInfo, HttpContextExpr);
+                            }
                             return Expression.Call(ConvertValueTaskMethod.MakeGenericMethod(nonNullableParameterType), typedCall);
-                        };
+                        }, hasParameterInfo ? 2 : 1);
                     }
                     // ValueTask<Nullable<{type}>>?
                     else if (valueTaskResultType.IsGenericType &&
                              valueTaskResultType.GetGenericTypeDefinition() == typeof(Nullable<>) &&
                              valueTaskResultType.GetGenericArguments()[0] == nonNullableParameterType)
                     {
-                        return (parameter) =>
+                        return ((parameter) =>
                         {
-                            // parameter is being intentionally shadowed. We never want to use the outer ParameterInfo inside
-                            // this Func because the ParameterInfo varies after it's been cached for a given parameter type.
-                            var typedCall = Expression.Call(methodInfo, HttpContextExpr, Expression.Constant(parameter));
+                            MethodCallExpression typedCall;
+                            if (hasParameterInfo)
+                            {
+                                // parameter is being intentionally shadowed. We never want to use the outer ParameterInfo inside
+                                // this Func because the ParameterInfo varies after it's been cached for a given parameter type.
+                                typedCall = Expression.Call(methodInfo, HttpContextExpr, Expression.Constant(parameter));
+                            }
+                            else
+                            {
+                                typedCall = Expression.Call(methodInfo, HttpContextExpr);
+                            }
                             return Expression.Call(ConvertValueTaskOfNullableResultMethod.MakeGenericMethod(nonNullableParameterType), typedCall);
-                        };
+                        }, hasParameterInfo ? 2 : 1);
                     }
                 }
 
-                return null;
+                return (null, 0);
             }
 
             var nonNullableParameterType = Nullable.GetUnderlyingType(parameter.ParameterType) ?? parameter.ParameterType;
-            return _bindAsyncMethodCallCache.GetOrAdd(nonNullableParameterType, Finder)?.Invoke(parameter);
+            var (method, paramCount) = _bindAsyncMethodCallCache.GetOrAdd(nonNullableParameterType, Finder);
+            return (method?.Invoke(parameter), paramCount);
         }
 
         private static MethodInfo GetEnumTryParseMethod(bool preferNonGenericEnumParseOverload)