// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Globalization;
using System.Linq.Expressions;
using System.Reflection;

namespace System.Linq
{
    /// <summary>
    /// Rewrites an expression tree representation using AsyncQueryable methods to the corresponding AsyncEnumerable equivalents.
    /// </summary>
    internal class AsyncEnumerableRewriter : ExpressionVisitor
    {
        //
        // Lookup of the public static methods on AsyncEnumerable, keyed by method name. It is built on
        // first use in FindEnumerableMethod without taking a lock, so concurrent first callers may each
        // build an equivalent table; the field is volatile so that a stored table is observed by other threads.
        //
        private static volatile ILookup<string, MethodInfo> _methods;

        /// <summary>
        /// Replaces constants wrapping an <see cref="AsyncEnumerableQuery"/>: an already-evaluated query becomes a
        /// constant holding its sequence (typed by a public type, see <see cref="GetPublicType"/>), and a
        /// not-yet-evaluated query is replaced by its own expression, visited recursively. All other constants
        /// are returned unchanged.
        /// </summary>
        /// <param name="node">The constant expression to visit.</param>
        /// <returns>The original node, or its rewritten replacement.</returns>
        protected override Expression VisitConstant(ConstantExpression node)
        {
            //
            // Not an expression representation obtained from the async enumerable query provider,
            // so just a plain constant that can be returned as-is.
            //
            if (!(node.Value is AsyncEnumerableQuery enumerableQuery))
            {
                return node;
            }

            //
            // Expression representation obtained from the async enumerable query provider, so first
            // check whether it wraps an enumerable sequence that has been evaluated already.
            //
            if (enumerableQuery.Enumerable != null)
            {
                var publicType = GetPublicType(enumerableQuery.Enumerable.GetType());
                return Expression.Constant(enumerableQuery.Enumerable, publicType);
            }

            //
            // If not evaluated yet, inline the expression representation.
            //
            return Visit(enumerableQuery.Expression);
        }

        /// <summary>
        /// Rewrites a method call whose instance or arguments changed during the visit. If the original method
        /// still accepts the rewritten operands, it is reused. Otherwise a replacement is located:
        /// AsyncQueryable methods map to the AsyncEnumerable method with the same name and compatible arguments,
        /// and any other method is looked up by name among the static methods of its declaring type (or the
        /// type named by <see cref="LocalQueryMethodImplementationTypeAttribute"/> on that type). Quoted
        /// arguments are then unwrapped as needed to fit the replacement's parameters.
        /// </summary>
        /// <param name="node">The method call expression to visit.</param>
        /// <returns>The original node when nothing changed; otherwise a new call expression.</returns>
        /// <exception cref="InvalidOperationException">No suitable replacement method could be found.</exception>
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            var obj = Visit(node.Object);
            var args = Visit(node.Arguments);

            //
            // Nothing changed during the visit; just some unrelated method call that can
            // be returned as-is.
            //
            if (obj == node.Object && args == node.Arguments)
            {
                return node;
            }

            var typeArgs = node.Method.IsGenericMethod ? node.Method.GetGenericArguments() : null;

            //
            // Check whether the method is compatible with the recursively rewritten instance
            // and arguments expressions. If so, create a new call expression.
            //
            if ((node.Method.IsStatic || node.Method.DeclaringType.IsAssignableFrom(obj.Type)) && ArgsMatch(node.Method, args, typeArgs))
            {
                return Expression.Call(obj, node.Method, args);
            }

            MethodInfo method;

            //
            // Find a corresponding method in the non-expression world, e.g. rewriting from
            // the AsyncQueryable methods to the ones on AsyncEnumerable.
            //
            if (node.Method.DeclaringType == typeof(AsyncQueryable))
            {
                method = FindEnumerableMethod(node.Method.Name, args, typeArgs);
            }
            else
            {
                //
                // Only static methods are searched; visibility follows the original method.
                //
                var flags = BindingFlags.Static | (node.Method.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic);
                method = FindMethod(node.Method.DeclaringType, node.Method.Name, args, typeArgs, flags);
            }

            args = FixupQuotedArgs(method, args);

            return Expression.Call(obj, method, args);
        }

        /// <summary>
        /// Returns lambdas unchanged without visiting their bodies.
        /// </summary>
        /// <typeparam name="T">The delegate type of the lambda.</typeparam>
        /// <param name="node">The lambda expression.</param>
        /// <returns>The same <paramref name="node"/>.</returns>
        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            //
            // Don't recurse into lambdas; all the ones returning IAsyncQueryable
            // are compatible with their IAsyncEnumerable counterparts due to the
            // covariant return type.
            //
            return node;
        }

        /// <summary>
        /// Returns parameters unchanged; lambdas are not rewritten, so their parameters need no rewriting either.
        /// </summary>
        /// <param name="node">The parameter expression.</param>
        /// <returns>The same <paramref name="node"/>.</returns>
        protected override Expression VisitParameter(ParameterExpression node)
        {
            return node;
        }

        /// <summary>
        /// Chooses the static type used for a constant wrapping an evaluated sequence. A type that is not a
        /// private nested type is returned as-is. For a private nested type, the first implemented
        /// <c>IAsyncEnumerable&lt;T&gt;</c> or <c>IAsyncGrouping&lt;TKey, TElement&gt;</c> interface (in
        /// <see cref="Type.GetInterfaces"/> order) is returned instead; if there is none, the type itself.
        /// </summary>
        /// <param name="type">The runtime type of the sequence.</param>
        /// <returns>The type to use for the constant expression.</returns>
        private static Type GetPublicType(Type type)
        {
            if (!type.GetTypeInfo().IsNestedPrivate)
            {
                return type;
            }

            foreach (var ifType in type.GetInterfaces())
            {
                if (ifType.GetTypeInfo().IsGenericType)
                {
                    var def = ifType.GetGenericTypeDefinition();
                    if (def == typeof(IAsyncEnumerable<>) || def == typeof(IAsyncGrouping<,>))
                    {
                        return ifType;
                    }
                }
            }

            //
            // NB: If a non-generic IAsyncEnumerable interface is ever added, map assignable
            //     types to it here, before falling back to the private type.
            //

            return type;
        }

        /// <summary>
        /// Determines whether <paramref name="method"/> can be called with <paramref name="args"/>. Generic method
        /// definitions are first closed over <paramref name="typeArgs"/>. An argument matches its parameter (by-ref
        /// parameters are compared via their element type) if its type is assignable to the parameter type, or,
        /// after removing one level of quoting, the unquoted operand's type or its <see cref="StripExpression"/>
        /// form is assignable.
        /// </summary>
        /// <param name="method">The candidate method.</param>
        /// <param name="args">The argument expressions to bind.</param>
        /// <param name="typeArgs">The generic type arguments, or <c>null</c> for non-generic calls.</param>
        /// <returns><c>true</c> if the arguments can be bound to the method.</returns>
        private static bool ArgsMatch(MethodInfo method, ReadOnlyCollection<Expression> args, Type[] typeArgs)
        {
            //
            // Number of parameters should match the number of arguments to bind.
            //
            var parameters = method.GetParameters();
            if (parameters.Length != args.Count)
            {
                return false;
            }

            //
            // A non-generic method cannot accept type arguments.
            //
            if (!method.IsGenericMethod && typeArgs != null && typeArgs.Length != 0)
            {
                return false;
            }

            //
            // Generic methods that still contain open generic parameters are converted to their
            // generic method definition so they can be closed below.
            //
            if (!method.IsGenericMethodDefinition && method.IsGenericMethod && method.ContainsGenericParameters)
            {
                method = method.GetGenericMethodDefinition();
            }

            //
            // For generic methods, close the candidate using the specified type arguments.
            //
            if (method.IsGenericMethodDefinition)
            {
                //
                // We should have at least 1 type argument.
                //
                if (typeArgs == null || typeArgs.Length == 0)
                {
                    return false;
                }

                //
                // The number of type arguments needed should match the specified type argument count.
                //
                if (method.GetGenericArguments().Length != typeArgs.Length)
                {
                    return false;
                }

                //
                // Close the generic method and re-obtain the parameters.
                //
                method = method.MakeGenericMethod(typeArgs);
                parameters = method.GetParameters();
            }

            //
            // Check that each argument is assignable to its parameter.
            //
            for (var i = 0; i < args.Count; i++)
            {
                var type = parameters[i].ParameterType;

                //
                // Hardening against reflection quirks.
                //
                if (type == null)
                {
                    return false;
                }

                //
                // Deal with ref or out parameters by using the element type which can
                // match the corresponding expression type (ref passing is not encoded
                // in the type of expression trees).
                //
                if (type.IsByRef)
                {
                    type = type.GetElementType();
                }

                var expression = args[i];

                //
                // If the expression is assignable to the parameter, all is good. If not,
                // it's possible there's a match because we're dealing with a quote that
                // needs to be unpacked.
                //
                if (!type.IsAssignableFrom(expression.Type))
                {
                    //
                    // Unpack the quote, if any. See AsyncQueryable for examples of operators
                    // that hit this case.
                    //
                    if (expression.NodeType == ExpressionType.Quote)
                    {
                        expression = ((UnaryExpression)expression).Operand;
                    }

                    //
                    // Try assigning the raw expression type or the quote-free expression type
                    // to the parameter. If none of these work, there's no match.
                    //
                    if (!type.IsAssignableFrom(expression.Type) && !type.IsAssignableFrom(StripExpression(expression.Type)))
                    {
                        return false;
                    }
                }
            }

            return true;
        }

        /// <summary>
        /// Adapts each argument to the corresponding parameter of <paramref name="method"/> using
        /// <see cref="FixupQuotedExpression"/>. Assumes <paramref name="argList"/> has (at least) one entry per
        /// parameter, which <see cref="ArgsMatch"/> establishes for the methods selected by this class.
        /// </summary>
        /// <param name="method">The method whose parameters the arguments must fit.</param>
        /// <param name="argList">The argument expressions.</param>
        /// <returns><paramref name="argList"/> itself when nothing changed; otherwise a new collection.</returns>
        private static ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo method, ReadOnlyCollection<Expression> argList)
        {
            //
            // Get all of the method parameters. No fix-up needed if empty.
            //
            var parameters = method.GetParameters();

            if (parameters.Length != 0)
            {
                //
                // Lazily allocated the first time an argument is actually changed; until then the
                // original collection can be returned as-is.
                //
                var list = default(List<Expression>);

                for (var i = 0; i < parameters.Length; i++)
                {
                    var original = argList[i];
                    var fixedUp = FixupQuotedExpression(parameters[i].ParameterType, original);

                    if (list == null && fixedUp != original)
                    {
                        //
                        // First change: copy the unchanged prefix of arguments.
                        //
                        list = new List<Expression>(argList.Count);
                        for (var j = 0; j < i; j++)
                        {
                            list.Add(argList[j]);
                        }
                    }

                    if (list != null)
                    {
                        list.Add(fixedUp);
                    }
                }

                //
                // If any argument was fixed up, return a new argument list.
                //
                if (list != null)
                {
                    argList = new ReadOnlyCollection<Expression>(list);
                }
            }

            return argList;
        }

        /// <summary>
        /// Makes <paramref name="expression"/> fit a parameter of type <paramref name="type"/>. Quotes are removed
        /// until the result is assignable. An array-initializer of quoted lambdas is rebuilt element by element
        /// when the stripped array type fits. When neither works, the original expression is returned unchanged.
        /// </summary>
        /// <param name="type">The parameter type to fit.</param>
        /// <param name="expression">The argument expression.</param>
        /// <returns>The adapted expression, or <paramref name="expression"/> if no adaptation applies.</returns>
        private static Expression FixupQuotedExpression(Type type, Expression expression)
        {
            var res = expression;

            //
            // Keep unquoting until assignability checks pass.
            //
            while (!type.IsAssignableFrom(res.Type))
            {
                //
                // In case this is not a quote, bail out early.
                //
                if (res.NodeType != ExpressionType.Quote)
                {
                    //
                    // Array initialization expressions need special care by unquoting the elements.
                    // (res.Type is known not to be assignable to type here, per the loop condition.)
                    //
                    if (type.IsArray && res.NodeType == ExpressionType.NewArrayInit)
                    {
                        var unquotedType = StripExpression(res.Type);
                        if (type.IsAssignableFrom(unquotedType))
                        {
                            var newArrayExpression = (NewArrayExpression)res;
                            var count = newArrayExpression.Expressions.Count;
                            var elementType = type.GetElementType();

                            var list = new List<Expression>(count);
                            for (var i = 0; i < count; i++)
                            {
                                list.Add(FixupQuotedExpression(elementType, newArrayExpression.Expressions[i]));
                            }

                            expression = Expression.NewArrayInit(elementType, list);
                        }
                    }

                    return expression;
                }

                //
                // Unquote and try again.
                //
                res = ((UnaryExpression)res).Operand;
            }

            return res;
        }

        /// <summary>
        /// For an array type whose element type is (or derives from) <c>Expression&lt;T&gt;</c>, returns the
        /// corresponding array of <c>T</c> with the same rank. Any other type, including non-array
        /// <c>Expression&lt;T&gt;</c> types, is returned unchanged.
        /// </summary>
        /// <param name="type">The type to strip.</param>
        /// <returns>The stripped array type, or <paramref name="type"/>.</returns>
        private static Type StripExpression(Type type)
        {
            //
            // Not an array, nothing to do here.
            //
            if (!type.IsArray)
            {
                return type;
            }

            //
            // Arrays of quotes need to be stripped, so extract the element type and try to find
            // Expression<T> to obtain T.
            //
            var elemType = type.GetElementType();

            var genType = FindGenericType(typeof(Expression<>), elemType);
            if (genType != null)
            {
                elemType = genType.GetGenericArguments()[0];
            }

            //
            // Reconstruct the array type from the stripped element type. MakeArrayType() yields a
            // vector (T[]), whereas MakeArrayType(1) would yield a rank-1 multi-dimensional array.
            //
            var arrayRank = type.GetArrayRank();

            return arrayRank != 1 ? elemType.MakeArrayType(arrayRank) : elemType.MakeArrayType();
        }

        /// <summary>
        /// Finds the public static AsyncEnumerable method named <paramref name="name"/> that accepts
        /// <paramref name="args"/>, closed over <paramref name="typeArgs"/> when those are given.
        /// </summary>
        /// <param name="name">The method name.</param>
        /// <param name="args">The argument expressions to bind.</param>
        /// <param name="typeArgs">The generic type arguments, or <c>null</c>.</param>
        /// <returns>The matching method.</returns>
        /// <exception cref="InvalidOperationException">No overload accepts the arguments.</exception>
        private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs)
        {
            //
            // Ensure the cached lookup table for AsyncEnumerable methods is initialized. The volatile
            // field is read once into a local.
            //
            var methods = _methods;
            if (methods == null)
            {
                methods = typeof(AsyncEnumerable).GetMethods(BindingFlags.Static | BindingFlags.Public).ToLookup(m => m.Name);
                _methods = methods;
            }

            //
            // Find a match based on the method name and the argument types.
            //
            var method = methods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));

            if (method == null)
            {
                //
                // Report the type that was actually searched.
                //
                throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
            }

            //
            // Close the generic method if needed.
            //
            if (typeArgs != null)
            {
                return method.MakeGenericMethod(typeArgs);
            }

            return method;
        }

        /// <summary>
        /// Finds a method named <paramref name="name"/> that accepts <paramref name="args"/>. The search runs
        /// on the type given by <see cref="LocalQueryMethodImplementationTypeAttribute"/> on
        /// <paramref name="type"/>, if present, and otherwise on <paramref name="type"/> itself.
        /// </summary>
        /// <param name="type">The declaring type of the original method.</param>
        /// <param name="name">The method name.</param>
        /// <param name="args">The argument expressions to bind.</param>
        /// <param name="typeArgs">The generic type arguments, or <c>null</c>.</param>
        /// <param name="flags">The binding flags used to enumerate candidate methods.</param>
        /// <returns>The matching method, closed over <paramref name="typeArgs"/> when those are given.</returns>
        /// <exception cref="InvalidOperationException">No method with that name, or none accepting the arguments, exists.</exception>
        private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs, BindingFlags flags)
        {
            //
            // Support the enumerable methods to be defined on another type.
            //
            var targetType = type.GetTypeInfo().GetCustomAttribute<LocalQueryMethodImplementationTypeAttribute>()?.TargetType ?? type;

            //
            // Get all the candidates based on name and fail if none are found.
            //
            var methods = targetType.GetMethods(flags).Where(m => m.Name == name).ToArray();
            if (methods.Length == 0)
            {
                throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, type));
            }

            //
            // Find a match based on arguments and fail if no match is found.
            //
            var method = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
            if (method == null)
            {
                throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, type));
            }

            //
            // Close the generic method if needed.
            //
            if (typeArgs != null)
            {
                return method.MakeGenericMethod(typeArgs);
            }

            return method;
        }

        /// <summary>
        /// Walks <paramref name="type"/> and its base types (stopping before <see cref="object"/>) looking for a
        /// constructed form of the generic type <paramref name="definition"/>. When the definition is an
        /// interface, implemented interfaces are probed as well.
        /// </summary>
        /// <param name="definition">The open generic type definition to look for.</param>
        /// <param name="type">The type to start searching from; may be <c>null</c>.</param>
        /// <returns>The matching constructed type, or <c>null</c> if none is found.</returns>
        private static Type FindGenericType(Type definition, Type type)
        {
            while (type != null && type != typeof(object))
            {
                //
                // If the current type matches the specified definition, return.
                //
                if (type.GetTypeInfo().IsGenericType && type.GetGenericTypeDefinition() == definition)
                {
                    return type;
                }

                //
                // Probe all interfaces implemented by the current type.
                //
                if (definition.GetTypeInfo().IsInterface)
                {
                    foreach (var ifType in type.GetInterfaces())
                    {
                        var res = FindGenericType(definition, ifType);
                        if (res != null)
                        {
                            return res;
                        }
                    }
                }

                //
                // Continue up the type hierarchy.
                //
                type = type.GetTypeInfo().BaseType;
            }

            return null;
        }
    }
}