AsyncEnumerableRewriter.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Collections.ObjectModel;
  6. using System.Globalization;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. namespace System.Linq
  10. {
  /// <summary>
  /// Rewrites an expression tree representation using AsyncQueryable methods to the corresponding AsyncEnumerable equivalents.
  /// </summary>
  /// <remarks>
  /// Constants holding an <see cref="AsyncEnumerableQuery"/> are replaced by the sequence they wrap (when it has
  /// been evaluated already) or by their rewritten query expression. Method calls whose instance or arguments
  /// changed during the visit are rebound to a compatible method; calls to AsyncQueryable methods are rebound to
  /// the AsyncEnumerable methods with the same name. Lambdas and parameters are left untouched.
  /// </remarks>
  internal class AsyncEnumerableRewriter : ExpressionVisitor
  {
      //
      // Lookup of the public static methods declared on AsyncEnumerable, keyed by method name. Built lazily
      // on first use; concurrent first use may build it more than once, but every build yields an equivalent
      // table, so the race only costs redundant work.
      //
      private static volatile ILookup<string, MethodInfo> s_methods;

      /// <summary>
      /// Replaces constants holding an <see cref="AsyncEnumerableQuery"/> by the sequence the query wraps (if it
      /// has been evaluated already) or by the rewritten expression the query was built from.
      /// </summary>
      /// <param name="node">The constant expression to visit.</param>
      /// <returns>The original node when it is unrelated to the async enumerable query provider; otherwise its replacement.</returns>
      protected override Expression VisitConstant(ConstantExpression node)
      {
          var enumerableQuery = node.Value as AsyncEnumerableQuery;

          //
          // Not an expression representation obtained from the async enumerable query provider,
          // so just a plain constant that can be returned as-is.
          //
          if (enumerableQuery == null)
          {
              return node;
          }

          //
          // Expression representation obtained from the async enumerable query provider, so first
          // check whether it wraps an enumerable sequence that has been evaluated already. If so,
          // embed the sequence as a constant typed with a public type (see GetPublicType).
          //
          if (enumerableQuery.Enumerable != null)
          {
              var publicType = GetPublicType(enumerableQuery.Enumerable.GetType());
              return Expression.Constant(enumerableQuery.Enumerable, publicType);
          }

          //
          // If not evaluated yet, inline the expression representation.
          //
          return Visit(enumerableQuery.Expression);
      }

      /// <summary>
      /// Rewrites method calls whose instance or arguments changed during the visit, rebinding them to a
      /// method that accepts the rewritten operands when the original method no longer does.
      /// </summary>
      /// <param name="node">The method call expression to visit.</param>
      /// <returns>The original node when nothing changed; otherwise a new call expression.</returns>
      /// <exception cref="InvalidOperationException">No method to rebind the call to could be found.</exception>
      protected override Expression VisitMethodCall(MethodCallExpression node)
      {
          var obj = Visit(node.Object);
          var args = Visit(node.Arguments);

          //
          // Nothing changed during the visit; just some unrelated method call that can
          // be returned as-is.
          //
          if (obj == node.Object && args == node.Arguments)
          {
              return node;
          }

          var typeArgs = node.Method.IsGenericMethod ? node.Method.GetGenericArguments() : null;

          //
          // Check whether the method is compatible with the recursively rewritten instance
          // and arguments expressions. If so, create a new call expression.
          //
          if ((node.Method.IsStatic || node.Method.DeclaringType.IsAssignableFrom(obj.Type)) && ArgsMatch(node.Method, args, typeArgs))
          {
              return Expression.Call(obj, node.Method, args);
          }

          //
          // Find a corresponding method in the non-expression world, e.g. rewriting from the
          // AsyncQueryable methods to the ones on AsyncEnumerable. Any other method is rebound to a
          // static method of the same name on its own declaring type (public if the original is
          // public, non-public otherwise).
          //
          var method = node.Method.DeclaringType == typeof(AsyncQueryable)
              ? FindEnumerableMethod(node.Method.Name, args, typeArgs)
              : FindMethod(node.Method.DeclaringType, node.Method.Name, args, typeArgs, BindingFlags.Static | (node.Method.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic));

          //
          // The rebound method may take delegates where the original took quoted expression
          // trees, so unquote the arguments where needed.
          //
          args = FixupQuotedArgs(method, args);

          return Expression.Call(obj, method, args);
      }

      /// <summary>
      /// Returns lambdas unchanged without visiting their bodies.
      /// </summary>
      /// <param name="node">The lambda expression.</param>
      /// <returns>The same node.</returns>
      protected override Expression VisitLambda<T>(Expression<T> node)
      {
          //
          // Don't recurse into lambdas; all the ones returning IAsyncQueryable<T>
          // are compatible with their IAsyncEnumerable<T> counterparts due to the
          // covariant return type.
          //
          return node;
      }

      /// <summary>
      /// Returns parameters unchanged.
      /// </summary>
      /// <param name="node">The parameter expression.</param>
      /// <returns>The same node.</returns>
      protected override Expression VisitParameter(ParameterExpression node)
      {
          //
          // Parameters only occur inside lambdas, which are not rewritten (see VisitLambda).
          //
          return node;
      }

      /// <summary>
      /// Gets the type to use for a constant holding an already-evaluated sequence. Private nested types are
      /// replaced by an IAsyncGrouping&lt;TKey, TElement&gt; or IAsyncEnumerable&lt;T&gt; interface they implement.
      /// </summary>
      /// <param name="type">The runtime type of the sequence.</param>
      /// <returns>
      /// <paramref name="type"/> itself if it is not a private nested type or implements neither interface;
      /// otherwise the implemented grouping interface if any, else the first implemented sequence interface.
      /// </returns>
      private static Type GetPublicType(Type type)
      {
          if (!type.IsNestedPrivate())
          {
              return type;
          }

          //
          // Type.GetInterfaces returns interfaces in no particular order, so don't just take the first
          // match: keep scanning for the grouping interface and prefer it, so that grouping information
          // is not lost depending on enumeration order. Otherwise use the first sequence interface found.
          //
          var enumerableType = default(Type);

          foreach (var ifType in type.GetInterfaces())
          {
              if (!ifType.IsGenericType())
              {
                  continue;
              }

              var def = ifType.GetGenericTypeDefinition();

              if (def == typeof(IAsyncGrouping<,>))
              {
                  return ifType;
              }

              if (def == typeof(IAsyncEnumerable<>) && enumerableType == null)
              {
                  enumerableType = ifType;
              }
          }

          if (enumerableType != null)
          {
              return enumerableType;
          }

          //
          // NB: Add if we ever decide to add the non-generic interface.
          //
          //if (typeof(IAsyncEnumerable).IsAssignableFrom(type))
          //{
          //    return typeof(IAsyncEnumerable);
          //}

          return type;
      }

      /// <summary>
      /// Checks whether <paramref name="method"/>, closed over <paramref name="typeArgs"/> if generic, can be
      /// called with the given argument expressions, allowing quoted lambdas to bind to delegate parameters.
      /// </summary>
      /// <param name="method">The candidate method (generic definition, closed generic, or non-generic).</param>
      /// <param name="args">The argument expressions to bind.</param>
      /// <param name="typeArgs">The generic type arguments to close the candidate with, or null for non-generic calls.</param>
      /// <returns>True if every argument can be bound to the corresponding parameter; otherwise false.</returns>
      private static bool ArgsMatch(MethodInfo method, ReadOnlyCollection<Expression> args, Type[] typeArgs)
      {
          //
          // Number of parameters should match the number of arguments to bind.
          //
          var parameters = method.GetParameters();
          if (parameters.Length != args.Count)
          {
              return false;
          }

          //
          // A non-generic method can't accept type arguments.
          //
          if (!method.IsGenericMethod && typeArgs != null && typeArgs.Length != 0)
          {
              return false;
          }

          //
          // Generic methods that are not definitions but still contain generic parameters (i.e. not
          // fully closed) are converted to their generic method definition so they can be re-closed.
          //
          if (!method.IsGenericMethodDefinition && method.IsGenericMethod && method.ContainsGenericParameters)
          {
              method = method.GetGenericMethodDefinition();
          }

          //
          // For generic methods, close the candidate using the specified type arguments.
          //
          if (method.IsGenericMethodDefinition)
          {
              //
              // We should have at least 1 type argument.
              //
              if (typeArgs == null || typeArgs.Length == 0)
              {
                  return false;
              }

              //
              // The number of type arguments needed should match the specified type argument count.
              //
              if (method.GetGenericArguments().Length != typeArgs.Length)
              {
                  return false;
              }

              //
              // Close the generic method and re-obtain the parameters.
              //
              method = method.MakeGenericMethod(typeArgs);
              parameters = method.GetParameters();
          }

          //
          // Check that each argument is assignable to its parameter.
          //
          for (var i = 0; i < args.Count; i++)
          {
              var type = parameters[i].ParameterType;

              //
              // Hardening against reflection quirks.
              //
              if (type == null)
              {
                  return false;
              }

              //
              // Deal with ref or out parameters by using the element type which can
              // match the corresponding expression type (ref passing is not encoded
              // in the type of expression trees).
              //
              if (type.IsByRef)
              {
                  type = type.GetElementType();
              }

              var expression = args[i];

              //
              // If the expression is assignable to the parameter, all is good. If not,
              // it's possible there's a match because we're dealing with a quote that
              // needs to be unpacked.
              //
              if (!type.IsAssignableFrom(expression.Type))
              {
                  //
                  // Unpack the quote, if any. See AsyncQueryable for examples of operators
                  // that hit this case.
                  //
                  if (expression.NodeType == ExpressionType.Quote)
                  {
                      expression = ((UnaryExpression)expression).Operand;
                  }

                  //
                  // Try assigning the raw expression type or the quote-free expression type
                  // to the parameter. If none of these work, there's no match.
                  //
                  if (!type.IsAssignableFrom(expression.Type) && !type.IsAssignableFrom(StripExpression(expression.Type)))
                  {
                      return false;
                  }
              }
          }

          return true;
      }

      /// <summary>
      /// Unquotes arguments whose quoted form is not assignable to the corresponding parameter of <paramref name="method"/>.
      /// </summary>
      /// <param name="method">The method the arguments will be passed to; its parameter count must not exceed the argument count.</param>
      /// <param name="argList">The argument expressions.</param>
      /// <returns><paramref name="argList"/> itself if no argument changed; otherwise a new collection with the fixed-up arguments.</returns>
      private static ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo method, ReadOnlyCollection<Expression> argList)
      {
          //
          // Get all of the method parameters. No fix-up needed if empty.
          //
          var parameters = method.GetParameters();
          if (parameters.Length == 0)
          {
              return argList;
          }

          //
          // Process all parameters. The list is only allocated once the first argument changes,
          // at which point the unchanged prefix is copied into it.
          //
          var list = default(List<Expression>);

          for (var i = 0; i < parameters.Length; i++)
          {
              var original = argList[i];
              var expression = FixupQuotedExpression(parameters[i].ParameterType, original);

              if (list == null && expression != original)
              {
                  list = new List<Expression>(argList.Count);
                  for (var j = 0; j < i; j++)
                  {
                      list.Add(argList[j]);
                  }
              }

              if (list != null)
              {
                  list.Add(expression);
              }
          }

          //
          // If any argument was fixed up, return a new argument list.
          //
          return list != null ? new ReadOnlyCollection<Expression>(list) : argList;
      }

      /// <summary>
      /// Strips quotes from <paramref name="expression"/> until it is assignable to <paramref name="type"/>.
      /// Array initializers of quoted expressions are rebuilt with unquoted elements.
      /// </summary>
      /// <param name="type">The parameter type the expression needs to be assignable to.</param>
      /// <param name="expression">The argument expression.</param>
      /// <returns>
      /// The unquoted (or rebuilt) expression when a fix-up succeeds; otherwise <paramref name="expression"/>
      /// unchanged, leaving any mismatch to be reported when the call expression is built.
      /// </returns>
      private static Expression FixupQuotedExpression(Type type, Expression expression)
      {
          var res = expression;

          //
          // Keep unquoting until assignability checks pass.
          //
          while (!type.IsAssignableFrom(res.Type))
          {
              //
              // In case this is not a quote, bail out early.
              //
              if (res.NodeType != ExpressionType.Quote)
              {
                  //
                  // Array initialization expressions need special care by unquoting the elements.
                  // (The loop condition already guarantees res is not assignable to type here.)
                  //
                  if (type.IsArray && res.NodeType == ExpressionType.NewArrayInit)
                  {
                      var unquotedType = StripExpression(res.Type);
                      if (type.IsAssignableFrom(unquotedType))
                      {
                          var newArrayExpression = (NewArrayExpression)res;
                          var count = newArrayExpression.Expressions.Count;
                          var elementType = type.GetElementType();
                          var list = new List<Expression>(count);

                          for (var i = 0; i < count; i++)
                          {
                              list.Add(FixupQuotedExpression(elementType, newArrayExpression.Expressions[i]));
                          }

                          expression = Expression.NewArrayInit(elementType, list);
                      }
                  }

                  return expression;
              }

              //
              // Unquote and try again.
              //
              res = ((UnaryExpression)res).Operand;
          }

          return res;
      }

      /// <summary>
      /// For an array type whose element type is (or derives from) Expression&lt;T&gt;, returns the array type of
      /// T with the same rank, e.g. Expression&lt;Func&lt;int, bool&gt;&gt;[] becomes Func&lt;int, bool&gt;[].
      /// </summary>
      /// <param name="type">The type to strip.</param>
      /// <returns>
      /// The stripped array type; <paramref name="type"/> itself for non-array types (quotes on scalar
      /// arguments are unwrapped by the callers instead).
      /// </returns>
      private static Type StripExpression(Type type)
      {
          //
          // Not an array, nothing to do here. Checked up front so the Expression<T> search
          // below only runs when its result is actually used.
          //
          if (!type.IsArray)
          {
              return type;
          }

          //
          // Array of quotes need to be stripped, so extract the element type and, if it is
          // Expression<T>, obtain T.
          //
          var elemType = type.GetElementType();

          var genType = FindGenericType(typeof(Expression<>), elemType);
          if (genType != null)
          {
              elemType = genType.GetGenericArguments()[0];
          }

          //
          // Reconstruct the array type from the stripped element type. Rank 1 uses the
          // parameterless MakeArrayType, which produces a vector (T[]) type.
          //
          var arrayRank = type.GetArrayRank();
          return arrayRank != 1 ? elemType.MakeArrayType(arrayRank) : elemType.MakeArrayType();
      }

      /// <summary>
      /// Finds the public static AsyncEnumerable method with the given name that accepts the given arguments.
      /// </summary>
      /// <param name="name">The method name.</param>
      /// <param name="args">The argument expressions to bind.</param>
      /// <param name="typeArgs">The generic type arguments, or null for non-generic calls.</param>
      /// <returns>The matching method, closed over <paramref name="typeArgs"/> when those are specified.</returns>
      /// <exception cref="InvalidOperationException">No method with the name exists, or none of them accepts the arguments.</exception>
      private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs)
      {
          //
          // Ensure the cached lookup table for AsyncEnumerable methods is initialized.
          //
          var lookup = s_methods;
          if (lookup == null)
          {
              lookup = typeof(AsyncEnumerable).GetMethods(BindingFlags.Static | BindingFlags.Public).ToLookup(m => m.Name);
              s_methods = lookup;
          }

          //
          // Get all the candidates based on name and fail if none are found. Errors name
          // AsyncEnumerable, which is the type that is actually searched.
          //
          var methods = lookup[name];
          if (!methods.Any())
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
          }

          //
          // Find a match based on the argument types and fail if no match is found.
          //
          var method = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
          if (method == null)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
          }

          //
          // Close the generic method if needed.
          //
          if (typeArgs != null)
          {
              return method.MakeGenericMethod(typeArgs);
          }

          return method;
      }

      /// <summary>
      /// Finds a method on <paramref name="type"/> with the given name and binding flags that accepts the given arguments.
      /// </summary>
      /// <param name="type">The type to search.</param>
      /// <param name="name">The method name.</param>
      /// <param name="args">The argument expressions to bind.</param>
      /// <param name="typeArgs">The generic type arguments, or null for non-generic calls.</param>
      /// <param name="flags">The binding flags used to enumerate candidate methods.</param>
      /// <returns>The matching method, closed over <paramref name="typeArgs"/> when those are specified.</returns>
      /// <exception cref="InvalidOperationException">No method with the name exists, or none of them accepts the arguments.</exception>
      private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs, BindingFlags flags)
      {
          //
          // Get all the candidates based on name and fail if none are found.
          //
          var methods = type.GetMethods(flags).Where(m => m.Name == name).ToArray();
          if (methods.Length == 0)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, type));
          }

          //
          // Find a match based on arguments and fail if no match is found.
          //
          var method = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
          if (method == null)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, type));
          }

          //
          // Close the generic method if needed.
          //
          if (typeArgs != null)
          {
              return method.MakeGenericMethod(typeArgs);
          }

          return method;
      }

      /// <summary>
      /// Finds the constructed form of the generic type <paramref name="definition"/> that <paramref name="type"/>
      /// is, derives from, or (when <paramref name="definition"/> is an interface) implements.
      /// </summary>
      /// <param name="definition">The generic type definition to look for, e.g. typeof(Expression&lt;&gt;).</param>
      /// <param name="type">The type to start searching from.</param>
      /// <returns>The first matching constructed type found, or null if there is none.</returns>
      private static Type FindGenericType(Type definition, Type type)
      {
          while (type != null && type != typeof(object))
          {
              //
              // If the current type matches the specified definition, return.
              //
              if (type.IsGenericType() && type.GetGenericTypeDefinition() == definition)
              {
                  return type;
              }

              //
              // Probe all interfaces implemented by the current type.
              //
              if (definition.IsInterface())
              {
                  foreach (var ifType in type.GetInterfaces())
                  {
                      var res = FindGenericType(definition, ifType);
                      if (res != null)
                      {
                          return res;
                      }
                  }
              }

              //
              // Continue up the type hierarchy.
              //
              type = type.GetBaseType();
          }

          return null;
      }
  }
  425. }