AsyncEnumerableRewriter.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Collections.ObjectModel;
  6. using System.Globalization;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. namespace System.Linq
  10. {
  11. /// <summary>
  12. /// Rewrites an expression tree representation using AsyncQueryable methods to the corresponding AsyncEnumerable equivalents.
  13. /// </summary>
  14. internal class AsyncEnumerableRewriter : ExpressionVisitor
  15. {
  16. private static volatile ILookup<string, MethodInfo>? _methods;
  /// <summary>
  /// Rewrites constants that hold an <see cref="AsyncEnumerableQuery"/>; any other constant is returned as-is.
  /// </summary>
  /// <param name="node">The constant expression to visit.</param>
  /// <returns>
  /// The original node when its value is not an <see cref="AsyncEnumerableQuery"/>; otherwise either a
  /// constant wrapping the query's enumerable sequence, or the result of visiting the query's expression.
  /// </returns>
  protected override Expression VisitConstant(ConstantExpression node)
  {
      if (node.Value is AsyncEnumerableQuery query)
      {
          var sequence = query.Enumerable;

          //
          // No enumerable sequence is available on the query, so inline its
          // expression representation instead.
          //
          if (sequence == null)
          {
              return Visit(query.Expression);
          }

          //
          // The query wraps an enumerable sequence; expose it as a constant typed
          // with the type selected by GetPublicType.
          //
          return Expression.Constant(sequence, GetPublicType(sequence.GetType()));
      }

      //
      // Just a plain constant that did not come from the async enumerable
      // query provider.
      //
      return node;
  }
  /// <summary>
  /// Visits a method call and, when its instance or arguments were rewritten, rebinds the call to a
  /// method that is compatible with the rewritten operands.
  /// </summary>
  /// <param name="node">The method call expression to visit.</param>
  /// <returns>
  /// The original node when nothing changed; otherwise a new call expression targeting either the
  /// original method, the matching <c>AsyncEnumerable</c> method, or a matching static method found
  /// through <c>FindMethod</c>.
  /// </returns>
  /// <exception cref="NotSupportedException">
  /// The method needs to be rebound but has no declaring type.
  /// </exception>
  protected override Expression VisitMethodCall(MethodCallExpression node)
  {
      var instance = Visit(node.Object)!;
      var arguments = Visit(node.Arguments);

      //
      // Neither the instance nor the arguments were rewritten, so the call can
      // be kept exactly as it was.
      //
      var unchanged = instance == node.Object && arguments == node.Arguments;
      if (unchanged)
      {
          return node;
      }

      var original = node.Method;
      var declaringType = original.DeclaringType;
      var genericArgs = original.IsGenericMethod ? original.GetGenericArguments() : null;

      //
      // The original method can still be used if the rewritten instance (for
      // instance methods) and the rewritten arguments remain compatible with it.
      //
      var instanceFits = original.IsStatic || (declaringType != null && declaringType.IsAssignableFrom(instance.Type));
      if (instanceFits && ArgsMatch(original, arguments, genericArgs))
      {
          return Expression.Call(instance, original, arguments);
      }

      //
      // Otherwise, look for a counterpart: AsyncQueryable methods map onto the
      // AsyncEnumerable ones; anything else is searched among the static
      // methods with the same visibility as the original method.
      //
      MethodInfo replacement;

      if (declaringType == typeof(AsyncQueryable))
      {
          replacement = FindEnumerableMethod(original.Name, arguments, genericArgs);
      }
      else if (declaringType == null)
      {
          throw new NotSupportedException(string.Format(CultureInfo.InvariantCulture, "Could not rewrite method with name '{0}' without a DeclaringType.", original.Name));
      }
      else
      {
          var visibility = original.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic;
          replacement = FindMethod(declaringType, original.Name, arguments, genericArgs, BindingFlags.Static | visibility);
      }

      //
      // Quoted arguments may need to be unquoted to fit the replacement's parameters.
      //
      var fixedArguments = FixupQuotedArgs(replacement, arguments);

      return Expression.Call(instance, replacement, fixedArguments);
  }
  /// <summary>
  /// Returns lambda expressions untouched, without visiting their bodies.
  /// </summary>
  /// <remarks>
  /// There is no need to recurse into lambdas: all the ones returning IAsyncQueryable&lt;T&gt; are
  /// compatible with their IAsyncEnumerable&lt;T&gt; counterparts due to the covariant return type.
  /// </remarks>
  /// <typeparam name="T">The delegate type of the lambda expression.</typeparam>
  /// <param name="node">The lambda expression to visit.</param>
  /// <returns>The same <paramref name="node"/> instance.</returns>
  protected override Expression VisitLambda<T>(Expression<T> node) => node;
  /// <summary>
  /// Returns parameter expressions untouched, in line with VisitLambda which leaves lambdas as they are.
  /// </summary>
  /// <param name="node">The parameter expression to visit.</param>
  /// <returns>The same <paramref name="node"/> instance.</returns>
  protected override Expression VisitParameter(ParameterExpression node) => node;
  /// <summary>
  /// Picks the type used to expose an object of the given type as a constant.
  /// </summary>
  /// <param name="type">The runtime type to inspect.</param>
  /// <returns>
  /// <paramref name="type"/> itself unless it is a nested private type; for a nested private type, the
  /// first implemented generic interface constructed from IAsyncEnumerable&lt;T&gt; or
  /// IAsyncGrouping&lt;TKey, TElement&gt;, or <paramref name="type"/> when there is no such interface.
  /// </returns>
  private static Type GetPublicType(Type type)
  {
      if (type.GetTypeInfo().IsNestedPrivate)
      {
          foreach (var candidate in type.GetInterfaces())
          {
              //
              // Only constructed generic interfaces are of interest.
              //
              if (!candidate.GetTypeInfo().IsGenericType)
              {
                  continue;
              }

              var definition = candidate.GetGenericTypeDefinition();

              if (definition == typeof(IAsyncEnumerable<>) || definition == typeof(IAsyncGrouping<,>))
              {
                  return candidate;
              }
          }

          //
          // NB: If a non-generic IAsyncEnumerable interface is ever added, check
          //     here whether the type is assignable to it and return that
          //     interface type instead.
          //
      }

      return type;
  }
  /// <summary>
  /// Checks whether a candidate method can be called with the given argument expressions and
  /// generic type arguments.
  /// </summary>
  /// <param name="method">The candidate method.</param>
  /// <param name="args">The argument expressions to bind to the method's parameters.</param>
  /// <param name="typeArgs">The generic type arguments to close the method with, or null if there are none.</param>
  /// <returns>
  /// true if the parameter count matches and every argument is assignable to its parameter, either
  /// directly, after removing a quote, or after stripping the expression type; otherwise, false.
  /// </returns>
  private static bool ArgsMatch(MethodInfo method, ReadOnlyCollection<Expression> args, Type[]? typeArgs)
  {
      //
      // Every argument needs a parameter to bind to, and vice versa.
      //
      var parameters = method.GetParameters();
      if (parameters.Length != args.Count)
      {
          return false;
      }

      //
      // Type arguments were specified, but the candidate is not generic.
      //
      if (typeArgs is { Length: > 0 } && !method.IsGenericMethod)
      {
          return false;
      }

      //
      // A generic method that is not a definition but still contains open generic
      // parameters is replaced by its generic method definition.
      //
      if (method.IsGenericMethod && !method.IsGenericMethodDefinition && method.ContainsGenericParameters)
      {
          method = method.GetGenericMethodDefinition();
      }

      if (method.IsGenericMethodDefinition)
      {
          //
          // At least one type argument is needed, and the number of type arguments
          // has to agree with the number of generic parameters.
          //
          if (typeArgs is not { Length: > 0 } || method.GetGenericArguments().Length != typeArgs.Length)
          {
              return false;
          }

          //
          // Close the definition so the parameter types become concrete.
          //
          method = method.MakeGenericMethod(typeArgs);
          parameters = method.GetParameters();
      }

      for (var index = 0; index < args.Count; index++)
      {
          var parameterType = parameters[index].ParameterType;

          //
          // Hardening against reflection quirks.
          //
          if (parameterType == null)
          {
              return false;
          }

          //
          // Ref and out parameters are compared using their element type, because
          // ref passing is not encoded in the type of expression trees.
          //
          if (parameterType.IsByRef)
          {
              parameterType = parameterType.GetElementType()!;
          }

          var argument = args[index];

          //
          // Directly assignable; move on to the next argument.
          //
          if (parameterType.IsAssignableFrom(argument.Type))
          {
              continue;
          }

          //
          // Possibly a quote that needs unpacking. See AsyncQueryable for examples
          // of operators that hit this case.
          //
          if (argument.NodeType == ExpressionType.Quote)
          {
              argument = ((UnaryExpression)argument).Operand;
          }

          //
          // Accept either the (unquoted) expression type or its stripped form;
          // if neither is assignable, the candidate does not match.
          //
          var argumentType = argument.Type;
          if (!parameterType.IsAssignableFrom(argumentType) && !parameterType.IsAssignableFrom(StripExpression(argumentType)))
          {
              return false;
          }
      }

      return true;
  }
  /// <summary>
  /// Applies <c>FixupQuotedExpression</c> to every argument, using the type of the corresponding
  /// parameter of the given method.
  /// </summary>
  /// <param name="method">The method whose parameter types drive the fix-up.</param>
  /// <param name="argList">The argument expressions, one per parameter of <paramref name="method"/>.</param>
  /// <returns>
  /// The original <paramref name="argList"/> instance if no argument changed; otherwise a new
  /// read-only collection holding the fixed-up arguments.
  /// </returns>
  private static ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo method, ReadOnlyCollection<Expression> argList)
  {
      var parameters = method.GetParameters();

      //
      // Without parameters there is nothing to fix up.
      //
      if (parameters.Length == 0)
      {
          return argList;
      }

      //
      // Stays null for as long as every argument comes back unchanged.
      //
      List<Expression>? rewritten = null;

      for (var index = 0; index < parameters.Length; index++)
      {
          var original = argList[index];
          var fixedUp = FixupQuotedExpression(parameters[index].ParameterType, original);

          //
          // First change detected: allocate the list and copy over the arguments
          // that were processed so far.
          //
          if (rewritten == null && fixedUp != original)
          {
              rewritten = new List<Expression>(argList.Count);

              for (var copied = 0; copied < index; copied++)
              {
                  rewritten.Add(argList[copied]);
              }
          }

          rewritten?.Add(fixedUp);
      }

      return rewritten == null ? argList : new ReadOnlyCollection<Expression>(rewritten);
  }
  /// <summary>
  /// Removes quotes from an expression until its type is assignable to the given type.
  /// </summary>
  /// <param name="type">The type the expression has to be assignable to.</param>
  /// <param name="expression">The expression to fix up.</param>
  /// <returns>
  /// The first unquoted expression that is assignable to <paramref name="type"/>; a new array
  /// initialization expression with fixed-up elements when a non-quote array initializer is reached
  /// whose stripped type is assignable to <paramref name="type"/>; otherwise the original
  /// <paramref name="expression"/>.
  /// </returns>
  private static Expression FixupQuotedExpression(Type type, Expression expression)
  {
      var current = expression;

      while (!type.IsAssignableFrom(current.Type))
      {
          //
          // Peel off one quote and check assignability again.
          //
          if (current.NodeType == ExpressionType.Quote)
          {
              current = ((UnaryExpression)current).Operand;
              continue;
          }

          //
          // Not a quote. The only remaining case that can be handled is an array
          // initialization whose elements need to be unquoted one by one.
          //
          if (type.IsArray && current.NodeType == ExpressionType.NewArrayInit && type.IsAssignableFrom(StripExpression(current.Type)))
          {
              var elementType = type.GetElementType()!;
              var sourceElements = ((NewArrayExpression)current).Expressions;
              var fixedElements = new List<Expression>(sourceElements.Count);

              foreach (var element in sourceElements)
              {
                  fixedElements.Add(FixupQuotedExpression(elementType, element));
              }

              return Expression.NewArrayInit(elementType, fixedElements);
          }

          //
          // Nothing can be done; hand back the expression that was passed in.
          //
          return expression;
      }

      return current;
  }
  /// <summary>
  /// For array types, replaces an element type derived from Expression&lt;T&gt; by T.
  /// </summary>
  /// <param name="type">The type to strip.</param>
  /// <returns>
  /// <paramref name="type"/> itself when it is not an array; otherwise an array type of the same rank
  /// whose element type is T if the original element type is (or derives from) Expression&lt;T&gt;,
  /// or the original element type if it is not.
  /// </returns>
  private static Type StripExpression(Type type)
  {
      //
      // Only arrays are rewritten; any other type is handed back unchanged.
      //
      if (!type.IsArray)
      {
          return type;
      }

      //
      // Look for Expression<T> on the element type and, if found, continue with T.
      //
      var elementType = type.GetElementType()!;
      var expressionType = FindGenericType(typeof(Expression<>), elementType);
      if (expressionType != null)
      {
          elementType = expressionType.GetGenericArguments()[0];
      }

      //
      // Rebuild the array type around the element type, using the parameterless
      // overload for rank 1 and the rank-taking overload otherwise.
      //
      var rank = type.GetArrayRank();

      return rank == 1 ? elementType.MakeArrayType() : elementType.MakeArrayType(rank);
  }
  /// <summary>
  /// Finds the public static method on <c>AsyncEnumerable</c> with the given name that matches the
  /// given arguments and generic type arguments.
  /// </summary>
  /// <param name="name">The name of the method to look for.</param>
  /// <param name="args">The argument expressions the method has to accept.</param>
  /// <param name="typeArgs">The generic type arguments, or null (or empty) for a non-generic method.</param>
  /// <returns>The matching method, closed over <paramref name="typeArgs"/> when type arguments were given.</returns>
  /// <exception cref="InvalidOperationException">No method with a matching name and signature exists.</exception>
  private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[]? typeArgs)
  {
      //
      // Build the name-based lookup table of AsyncEnumerable methods on first use.
      //
      _methods ??= typeof(AsyncEnumerable).GetMethods(BindingFlags.Static | BindingFlags.Public).ToLookup(m => m.Name);

      //
      // Find a match based on the method name and the argument types.
      //
      var method = _methods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
      if (method == null)
      {
          //
          // Report the type that was actually searched, which is AsyncEnumerable.
          //
          throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
      }

      //
      // Close the generic method if type arguments were given. An empty array (which
      // the params signature allows) means there is nothing to close over; calling
      // MakeGenericMethod on the non-generic match would throw.
      //
      if (typeArgs is { Length: > 0 })
      {
          return method.MakeGenericMethod(typeArgs);
      }

      return method;
  }
  /// <summary>
  /// Finds a method with the given name, binding flags, and compatible signature on the given type,
  /// or on the target type named by its <c>LocalQueryMethodImplementationTypeAttribute</c> if present.
  /// </summary>
  /// <param name="type">The type declaring the original method.</param>
  /// <param name="name">The name of the method to look for.</param>
  /// <param name="args">The argument expressions the method has to accept.</param>
  /// <param name="typeArgs">The generic type arguments, or null for a non-generic method.</param>
  /// <param name="flags">The binding flags used to enumerate candidate methods.</param>
  /// <returns>The matching method, closed over <paramref name="typeArgs"/> when they are not null.</returns>
  /// <exception cref="InvalidOperationException">
  /// No method with the given name exists, or none of the methods with that name matches the arguments.
  /// </exception>
  private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[]? typeArgs, BindingFlags flags)
  {
      //
      // The implementation may live on another type, as indicated by the attribute.
      //
      var redirection = type.GetTypeInfo().GetCustomAttribute<LocalQueryMethodImplementationTypeAttribute>();
      var targetType = redirection?.TargetType ?? type;

      //
      // Narrow down to the methods carrying the requested name.
      //
      var candidates = Array.FindAll(targetType.GetMethods(flags), m => m.Name == name);
      if (candidates.Length == 0)
      {
          throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, type));
      }

      //
      // Take the first candidate whose signature fits the arguments.
      //
      var match = Array.Find(candidates, m => ArgsMatch(m, args, typeArgs));
      if (match == null)
      {
          throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, type));
      }

      //
      // Close the generic method if type arguments were supplied.
      //
      return typeArgs == null ? match : match.MakeGenericMethod(typeArgs);
  }
  /// <summary>
  /// Searches a type, its base types and, for interface definitions, its implemented interfaces
  /// for a constructed generic type built from the given generic type definition.
  /// </summary>
  /// <param name="definition">The generic type definition to look for.</param>
  /// <param name="type">The type at which to start the search; may be null.</param>
  /// <returns>
  /// The first constructed generic type found whose generic type definition equals
  /// <paramref name="definition"/>, or null if there is none.
  /// </returns>
  private static Type? FindGenericType(Type definition, Type? type)
  {
      //
      // Walk up the type hierarchy, stopping at System.Object.
      //
      for (var current = type; current != null && current != typeof(object); current = current.GetTypeInfo().BaseType)
      {
          //
          // The current type itself is constructed from the definition.
          //
          if (current.GetTypeInfo().IsGenericType && current.GetGenericTypeDefinition() == definition)
          {
              return current;
          }

          //
          // Interfaces are only worth probing when looking for an interface definition.
          //
          if (!definition.GetTypeInfo().IsInterface)
          {
              continue;
          }

          foreach (var implemented in current.GetInterfaces())
          {
              var found = FindGenericType(definition, implemented);
              if (found != null)
              {
                  return found;
              }
          }
      }

      return null;
  }
  426. }
  427. }