AsyncEnumerableRewriter.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Collections.ObjectModel;
  6. using System.Globalization;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. namespace System.Linq
  10. {
  /// <summary>
  /// Rewrites an expression tree that uses <see cref="AsyncQueryable"/> methods into an equivalent
  /// tree that calls the corresponding <see cref="AsyncEnumerable"/> methods.
  /// </summary>
  /// <remarks>
  /// Constants wrapping an <see cref="AsyncEnumerableQuery"/> are replaced either by the sequence the
  /// query already wraps, or by the query's own (recursively rewritten) expression tree. Method calls
  /// whose rewritten operands no longer fit the original method are re-bound to a compatible method.
  /// </remarks>
  internal class AsyncEnumerableRewriter : ExpressionVisitor
  {
      /// <summary>
      /// Lazily built lookup of the public static methods declared on <see cref="AsyncEnumerable"/>,
      /// keyed by method name.
      /// </summary>
      /// <remarks>
      /// Initialization is unsynchronized: concurrent first callers may each build the lookup and the
      /// last write wins. Every build runs the same reflection query, so any published instance is usable.
      /// </remarks>
      private static volatile ILookup<string, MethodInfo> _methods;

      /// <summary>
      /// Replaces constants that wrap an <see cref="AsyncEnumerableQuery"/> with either the wrapped
      /// sequence or the query's expression tree.
      /// </summary>
      /// <param name="node">The constant expression being visited.</param>
      /// <returns>
      /// <paramref name="node"/> itself when its value is not an <see cref="AsyncEnumerableQuery"/>;
      /// otherwise a constant for the wrapped sequence, or the rewritten query expression.
      /// </returns>
      protected override Expression VisitConstant(ConstantExpression node)
      {
          // A plain constant that did not come from the async enumerable query provider; keep it.
          if (!(node.Value is AsyncEnumerableQuery enumerableQuery))
          {
              return node;
          }

          // The query already wraps a sequence: embed it directly, typed as a public type so the
          // resulting tree does not refer to a private nested type.
          if (enumerableQuery.Enumerable != null)
          {
              var publicType = GetPublicType(enumerableQuery.Enumerable.GetType());
              return Expression.Constant(enumerableQuery.Enumerable, publicType);
          }

          // Not evaluated yet: inline the query's expression tree and rewrite it as well.
          return Visit(enumerableQuery.Expression);
      }

      /// <summary>
      /// Rewrites a method call whose instance or arguments changed during the visit, re-binding it to
      /// an <see cref="AsyncEnumerable"/> method (for <see cref="AsyncQueryable"/> calls) or to a
      /// compatible overload on the declaring type (or its redirected implementation type).
      /// </summary>
      /// <param name="node">The method call being visited.</param>
      /// <returns>The original node if nothing changed; otherwise a new call expression.</returns>
      /// <exception cref="InvalidOperationException">No compatible replacement method could be found.</exception>
      protected override Expression VisitMethodCall(MethodCallExpression node)
      {
          var obj = Visit(node.Object);
          var args = Visit(node.Arguments);

          // Nothing changed during the visit; an unrelated call that can be returned as-is.
          if (obj == node.Object && args == node.Arguments)
          {
              return node;
          }

          var typeArgs = node.Method.IsGenericMethod ? node.Method.GetGenericArguments() : null;

          // The original method still accepts the rewritten instance and arguments; keep using it.
          if ((node.Method.IsStatic || node.Method.DeclaringType.IsAssignableFrom(obj.Type)) && ArgsMatch(node.Method, args, typeArgs))
          {
              return Expression.Call(obj, node.Method, args);
          }

          MethodInfo method;

          if (node.Method.DeclaringType == typeof(AsyncQueryable))
          {
              // Map AsyncQueryable operators onto their AsyncEnumerable counterparts.
              method = FindEnumerableMethod(node.Method.Name, args, typeArgs);
          }
          else
          {
              // Look for a static overload with the same visibility on the declaring type.
              var flags = BindingFlags.Static | (node.Method.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic);
              method = FindMethod(node.Method.DeclaringType, node.Method.Name, args, typeArgs, flags);
          }

          // Unwrap quoted lambdas where the replacement expects delegates rather than expressions.
          args = FixupQuotedArgs(method, args);

          return Expression.Call(obj, method, args);
      }

      /// <summary>
      /// Leaves lambdas untouched.
      /// </summary>
      /// <remarks>
      /// Lambdas returning IAsyncQueryable&lt;T&gt; are compatible with their IAsyncEnumerable&lt;T&gt;
      /// counterparts due to the covariant return type, so there is no need to recurse into them.
      /// </remarks>
      protected override Expression VisitLambda<T>(Expression<T> node)
      {
          return node;
      }

      /// <summary>
      /// Leaves parameters untouched, for the same reason lambdas are not rewritten.
      /// </summary>
      protected override Expression VisitParameter(ParameterExpression node)
      {
          return node;
      }

      /// <summary>
      /// Maps a private nested type to a public interface it implements, preferring
      /// <see cref="IAsyncGrouping{TKey, TElement}"/> over <see cref="IAsyncEnumerable{T}"/>.
      /// </summary>
      /// <param name="type">The runtime type of a sequence being embedded as a constant.</param>
      /// <returns>
      /// <paramref name="type"/> when it is not a private nested type or implements neither interface;
      /// otherwise the implemented interface type.
      /// </returns>
      private static Type GetPublicType(Type type)
      {
          if (!type.IsNestedPrivate())
          {
              return type;
          }

          var enumerableType = default(Type);

          // Type.GetInterfaces returns interfaces in no particular order. A grouping also implements
          // IAsyncEnumerable<TElement>, so keep scanning until IAsyncGrouping<,> is found rather than
          // returning the first match, which would silently drop the grouping key.
          foreach (var ifType in type.GetInterfaces())
          {
              if (!ifType.IsGenericType())
              {
                  continue;
              }

              var def = ifType.GetGenericTypeDefinition();

              if (def == typeof(IAsyncGrouping<,>))
              {
                  return ifType;
              }

              if (enumerableType == null && def == typeof(IAsyncEnumerable<>))
              {
                  enumerableType = ifType;
              }
          }

          // NB: If a non-generic IAsyncEnumerable interface is ever introduced, fall back to it here.
          return enumerableType ?? type;
      }

      /// <summary>
      /// Determines whether <paramref name="method"/> can be called with <paramref name="args"/>, after
      /// closing it over <paramref name="typeArgs"/> when it is generic.
      /// </summary>
      /// <param name="method">The candidate method (open, closed, or non-generic).</param>
      /// <param name="args">The rewritten argument expressions.</param>
      /// <param name="typeArgs">The generic type arguments of the original call, or <c>null</c>.</param>
      /// <returns><c>true</c> if every argument is assignable to its parameter, possibly after unquoting.</returns>
      private static bool ArgsMatch(MethodInfo method, ReadOnlyCollection<Expression> args, Type[] typeArgs)
      {
          // Number of parameters should match the number of arguments to bind.
          var parameters = method.GetParameters();
          if (parameters.Length != args.Count)
          {
              return false;
          }

          // Both should be generic or non-generic.
          if (!method.IsGenericMethod && typeArgs != null && typeArgs.Length != 0)
          {
              return false;
          }

          // Closed generic methods that still contain generic parameters are converted to their open
          // generic definition.
          if (!method.IsGenericMethodDefinition && method.IsGenericMethod && method.ContainsGenericParameters)
          {
              method = method.GetGenericMethodDefinition();
          }

          // For generic method definitions, close the candidate using the specified type arguments.
          if (method.IsGenericMethodDefinition)
          {
              if (typeArgs == null || typeArgs.Length == 0)
              {
                  return false;
              }

              if (method.GetGenericArguments().Length != typeArgs.Length)
              {
                  return false;
              }

              method = method.MakeGenericMethod(typeArgs);
              parameters = method.GetParameters();
          }

          // Check that each argument is assignable to its parameter.
          for (var i = 0; i < args.Count; i++)
          {
              var type = parameters[i].ParameterType;

              // Hardening against reflection quirks.
              if (type == null)
              {
                  return false;
              }

              // ref/out passing is not encoded in expression types, so compare against the element type.
              if (type.IsByRef)
              {
                  type = type.GetElementType();
              }

              var expression = args[i];

              if (!type.IsAssignableFrom(expression.Type))
              {
                  // Unpack the quote, if any. See AsyncQueryable for examples of operators that hit this case.
                  if (expression.NodeType == ExpressionType.Quote)
                  {
                      expression = ((UnaryExpression)expression).Operand;
                  }

                  // Try the (possibly unquoted) type, then the type with Expression<T> stripped from
                  // array elements. If neither works, there's no match.
                  if (!type.IsAssignableFrom(expression.Type) && !type.IsAssignableFrom(StripExpression(expression.Type)))
                  {
                      return false;
                  }
              }
          }

          return true;
      }

      /// <summary>
      /// Unquotes arguments that the resolved <paramref name="method"/> expects as delegates (or arrays
      /// of delegates) rather than expression trees.
      /// </summary>
      /// <param name="method">The resolved target method; assumed to have one parameter per argument.</param>
      /// <param name="argList">The argument expressions.</param>
      /// <returns><paramref name="argList"/> if no argument changed; otherwise a new collection.</returns>
      private static ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo method, ReadOnlyCollection<Expression> argList)
      {
          var parameters = method.GetParameters();
          if (parameters.Length == 0)
          {
              return argList;
          }

          // Allocated lazily on the first argument that actually changes.
          var list = default(List<Expression>);

          for (var i = 0; i < parameters.Length; i++)
          {
              var original = argList[i];
              var fixedUp = FixupQuotedExpression(parameters[i].ParameterType, original);

              if (list == null && fixedUp != original)
              {
                  list = new List<Expression>(argList.Count);
                  for (var j = 0; j < i; j++)
                  {
                      list.Add(argList[j]);
                  }
              }

              list?.Add(fixedUp);
          }

          return list != null ? new ReadOnlyCollection<Expression>(list) : argList;
      }

      /// <summary>
      /// Strips quotes from <paramref name="expression"/> until it is assignable to <paramref name="type"/>,
      /// and unquotes the elements of array initializers that target an array of delegates.
      /// </summary>
      /// <param name="type">The parameter type the expression must be assignable to.</param>
      /// <param name="expression">The argument expression.</param>
      /// <returns>
      /// The first (possibly unquoted) form assignable to <paramref name="type"/>, a rebuilt array
      /// initializer, or <paramref name="expression"/> unchanged if no fix-up applies.
      /// </returns>
      private static Expression FixupQuotedExpression(Type type, Expression expression)
      {
          var res = expression;

          // Keep unquoting until the assignability check passes or there is no quote left to remove.
          while (true)
          {
              if (type.IsAssignableFrom(res.Type))
              {
                  return res;
              }

              if (res.NodeType != ExpressionType.Quote)
              {
                  break;
              }

              res = ((UnaryExpression)res).Operand;
          }

          // Array initialization expressions need special care by unquoting the elements.
          if (type.IsArray && res.NodeType == ExpressionType.NewArrayInit && type.IsAssignableFrom(StripExpression(res.Type)))
          {
              var newArrayExpression = (NewArrayExpression)res;
              var elementType = type.GetElementType();
              var count = newArrayExpression.Expressions.Count;
              var elements = new List<Expression>(count);

              for (var i = 0; i < count; i++)
              {
                  elements.Add(FixupQuotedExpression(elementType, newArrayExpression.Expressions[i]));
              }

              return Expression.NewArrayInit(elementType, elements);
          }

          return expression;
      }

      /// <summary>
      /// For an array type whose element type is (or derives from) Expression&lt;T&gt;, returns the
      /// corresponding array of T with the same rank.
      /// </summary>
      /// <param name="type">The type to strip.</param>
      /// <returns>
      /// <paramref name="type"/> unchanged for non-array types; otherwise the reconstructed array type.
      /// </returns>
      private static Type StripExpression(Type type)
      {
          // Only arrays of quotes are stripped; other types are returned as-is.
          if (!type.IsArray)
          {
              return type;
          }

          var elemType = type.GetElementType();

          // Try to find Expression<T> and obtain T.
          var genType = FindGenericType(typeof(Expression<>), elemType);
          if (genType != null)
          {
              elemType = genType.GetGenericArguments()[0];
          }

          // Reconstruct the array type from the stripped element type.
          var arrayRank = type.GetArrayRank();
          return arrayRank == 1 ? elemType.MakeArrayType() : elemType.MakeArrayType(arrayRank);
      }

      /// <summary>
      /// Finds the public static <see cref="AsyncEnumerable"/> method with the given name that accepts
      /// <paramref name="args"/>, closed over <paramref name="typeArgs"/> when provided.
      /// </summary>
      /// <param name="name">The method name.</param>
      /// <param name="args">The rewritten argument expressions.</param>
      /// <param name="typeArgs">The generic type arguments, or <c>null</c> for non-generic calls.</param>
      /// <returns>The matching method.</returns>
      /// <exception cref="InvalidOperationException">No method with that name exists, or none matches.</exception>
      private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs)
      {
          // Ensure the cached lookup table is initialized; read the volatile field only once.
          var methods = _methods;
          if (methods == null)
          {
              methods = typeof(AsyncEnumerable).GetMethods(BindingFlags.Static | BindingFlags.Public).ToLookup(m => m.Name);
              _methods = methods;
          }

          var candidates = methods[name];
          if (!candidates.Any())
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
          }

          // Find a match based on the argument types.
          var method = candidates.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
          if (method == null)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
          }

          // Close the generic method if needed.
          return typeArgs != null ? method.MakeGenericMethod(typeArgs) : method;
      }

      /// <summary>
      /// Finds a method named <paramref name="name"/> that accepts <paramref name="args"/>. The search
      /// runs on the type named by <see cref="LocalQueryMethodImplementationTypeAttribute"/> on
      /// <paramref name="type"/> if present, otherwise on <paramref name="type"/> itself.
      /// </summary>
      /// <param name="type">The declaring type of the original method.</param>
      /// <param name="name">The method name.</param>
      /// <param name="args">The rewritten argument expressions.</param>
      /// <param name="typeArgs">The generic type arguments, or <c>null</c> for non-generic calls.</param>
      /// <param name="flags">Binding flags used to enumerate candidate methods.</param>
      /// <returns>The matching method, closed over <paramref name="typeArgs"/> when provided.</returns>
      /// <exception cref="InvalidOperationException">No method with that name exists, or none matches.</exception>
      private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs, BindingFlags flags)
      {
          // Support the enumerable methods to be defined on another type.
          var targetType = type.GetTypeInfo().GetCustomAttribute<LocalQueryMethodImplementationTypeAttribute>()?.TargetType ?? type;

          // Get all the candidates based on name and fail if none are found. Errors name the type that
          // was actually searched.
          var methods = targetType.GetMethods(flags).Where(m => m.Name == name).ToArray();
          if (methods.Length == 0)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, targetType));
          }

          // Find a match based on arguments and fail if no match is found.
          var method = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
          if (method == null)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, targetType));
          }

          // Close the generic method if needed.
          return typeArgs != null ? method.MakeGenericMethod(typeArgs) : method;
      }

      /// <summary>
      /// Searches <paramref name="type"/>, its base types and (when <paramref name="definition"/> is an
      /// interface) their implemented interfaces for a constructed form of <paramref name="definition"/>.
      /// </summary>
      /// <param name="definition">An open generic type definition, e.g. Expression&lt;&gt;.</param>
      /// <param name="type">The type to start searching from; may be <c>null</c>.</param>
      /// <returns>The first constructed type found, or <c>null</c>.</returns>
      private static Type FindGenericType(Type definition, Type type)
      {
          while (type != null && type != typeof(object))
          {
              // If the current type matches the specified definition, return.
              if (type.IsGenericType() && type.GetGenericTypeDefinition() == definition)
              {
                  return type;
              }

              // Probe all interfaces implemented by the current type.
              if (definition.IsInterface())
              {
                  foreach (var ifType in type.GetInterfaces())
                  {
                      var res = FindGenericType(definition, ifType);
                      if (res != null)
                      {
                          return res;
                      }
                  }
              }

              // Continue up the type hierarchy.
              type = type.GetBaseType();
          }

          return null;
      }
  }
  428. }