1
0

AsyncEnumerableRewriter.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Collections.ObjectModel;
  6. using System.Globalization;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. namespace System.Linq
  10. {
  /// <summary>
  /// Rewrites an expression tree representation using <see cref="AsyncQueryable"/> methods to the
  /// corresponding <see cref="AsyncEnumerable"/> equivalents.
  /// </summary>
  /// <remarks>
  /// Constants wrapping an <see cref="AsyncEnumerableQuery"/> are replaced by the sequence they hold
  /// (when already evaluated) or by their inlined expression. Method calls whose original signature no
  /// longer accepts the rewritten arguments are then rebound to a matching overload.
  /// </remarks>
  internal class AsyncEnumerableRewriter : ExpressionVisitor
  {
      /// <summary>
      /// Lazily built lookup of the public static methods declared on <see cref="AsyncEnumerable"/>, keyed by name.
      /// </summary>
      /// <remarks>
      /// Initialization is unsynchronized: concurrent first callers may each build the lookup, and the last
      /// write wins. Every build reflects over the same type, so the resulting tables are equivalent.
      /// </remarks>
      private static volatile ILookup<string, MethodInfo>? _methods;

      /// <summary>
      /// Replaces constants that wrap an <see cref="AsyncEnumerableQuery"/> with either the evaluated
      /// sequence they hold or the (recursively rewritten) expression they represent.
      /// </summary>
      /// <param name="node">The constant expression being visited.</param>
      /// <returns>
      /// <paramref name="node"/> itself for plain constants; otherwise the replacement expression.
      /// </returns>
      protected override Expression VisitConstant(ConstantExpression node)
      {
          //
          // Not an expression representation obtained from the async enumerable query provider,
          // so just a plain constant that can be returned as-is.
          //
          if (!(node.Value is AsyncEnumerableQuery enumerableQuery))
          {
              return node;
          }

          //
          // The query wraps a sequence that has been evaluated already: embed that sequence as a
          // constant, typed via GetPublicType so private nested runtime types are not exposed.
          //
          if (enumerableQuery.Enumerable != null)
          {
              var publicType = GetPublicType(enumerableQuery.Enumerable.GetType());
              return Expression.Constant(enumerableQuery.Enumerable, publicType);
          }

          //
          // Not evaluated yet, so inline the expression representation (rewriting it too).
          //
          return Visit(enumerableQuery.Expression);
      }

      /// <summary>
      /// Rewrites a method call whose instance or arguments changed during the visit, rebinding it to
      /// an <see cref="AsyncEnumerable"/> operator (for <see cref="AsyncQueryable"/> calls) or to a
      /// compatible static overload on the declaring type when the original method no longer fits.
      /// </summary>
      /// <param name="node">The method call expression being visited.</param>
      /// <returns>The original node if nothing changed; otherwise a new call expression.</returns>
      /// <exception cref="NotSupportedException">The method has no declaring type and must be rebound.</exception>
      /// <exception cref="InvalidOperationException">No suitable replacement method could be found.</exception>
      protected override Expression VisitMethodCall(MethodCallExpression node)
      {
          var obj = Visit(node.Object);
          var args = Visit(node.Arguments);

          //
          // Nothing changed during the visit; just some unrelated method call that can
          // be returned as-is.
          //
          if (obj == node.Object && args == node.Arguments)
          {
              return node;
          }

          var declType = node.Method.DeclaringType;
          var typeArgs = node.Method.IsGenericMethod ? node.Method.GetGenericArguments() : null;

          //
          // Check whether the method is compatible with the recursively rewritten instance
          // and arguments expressions. If so, create a new call expression. For instance
          // methods node.Object is non-null, so its visited counterpart is non-null as well.
          //
          var instanceFits = node.Method.IsStatic || (declType != null && declType.IsAssignableFrom(obj!.Type));
          if (instanceFits && ArgsMatch(node.Method, args, typeArgs))
          {
              return Expression.Call(obj, node.Method, args);
          }

          //
          // Find a corresponding method in the non-expression world, e.g. rewriting from
          // the AsyncQueryable methods to the ones on AsyncEnumerable.
          //
          MethodInfo method;
          if (declType == typeof(AsyncQueryable))
          {
              method = FindEnumerableMethod(node.Method.Name, args, typeArgs);
          }
          else
          {
              if (declType == null)
              {
                  throw new NotSupportedException(string.Format(CultureInfo.InvariantCulture, "Could not rewrite method with name '{0}' without a DeclaringType.", node.Method.Name));
              }

              method = FindMethod(declType, node.Method.Name, args, typeArgs, BindingFlags.Static | (node.Method.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic));
          }

          //
          // Unquote arguments where the chosen method expects the delegate rather than
          // the quoted expression.
          //
          args = FixupQuotedArgs(method, args);
          return Expression.Call(obj, method, args);
      }

      /// <summary>
      /// Returns lambdas unchanged without visiting their bodies.
      /// </summary>
      /// <typeparam name="T">The delegate type of the lambda.</typeparam>
      /// <param name="node">The lambda expression being visited.</param>
      /// <returns><paramref name="node"/> itself.</returns>
      protected override Expression VisitLambda<T>(Expression<T> node)
      {
          //
          // Don't recurse into lambdas; all the ones returning IAsyncQueryable<T>
          // are compatible with their IAsyncEnumerable<T> counterparts due to the
          // covariant return type.
          //
          return node;
      }

      /// <summary>
      /// Returns parameters unchanged, consistent with lambdas not being rewritten.
      /// </summary>
      /// <param name="node">The parameter expression being visited.</param>
      /// <returns><paramref name="node"/> itself.</returns>
      protected override Expression VisitParameter(ParameterExpression node)
      {
          //
          // Lambdas are not rewritten (see VisitLambda), so their parameters stay as-is.
          //
          return node;
      }

      /// <summary>
      /// Chooses the static type used to embed an evaluated sequence as a constant.
      /// </summary>
      /// <param name="type">The runtime type of the evaluated sequence.</param>
      /// <returns>
      /// <paramref name="type"/> unless it is a private nested type, in which case the first implemented
      /// <see cref="IAsyncEnumerable{T}"/> or <see cref="IAsyncGrouping{TKey, TElement}"/> interface is
      /// returned. If no such interface exists, <paramref name="type"/> is returned.
      /// </returns>
      private static Type GetPublicType(Type type)
      {
          if (!type.GetTypeInfo().IsNestedPrivate)
          {
              return type;
          }

          foreach (var ifType in type.GetInterfaces())
          {
              if (ifType.GetTypeInfo().IsGenericType)
              {
                  var def = ifType.GetGenericTypeDefinition();
                  if (def == typeof(IAsyncEnumerable<>) || def == typeof(IAsyncGrouping<,>))
                  {
                      return ifType;
                  }
              }
          }

          //
          // NB: If a non-generic IAsyncEnumerable interface is ever introduced, probe for it here
          //     (via IsAssignableFrom) before falling back to the private type itself.
          //
          return type;
      }

      /// <summary>
      /// Determines whether <paramref name="method"/> can be invoked with <paramref name="args"/>,
      /// closing it over <paramref name="typeArgs"/> first when it is a generic method definition.
      /// </summary>
      /// <param name="method">The candidate method (non-generic, open generic, or closed generic).</param>
      /// <param name="args">The argument expressions to bind.</param>
      /// <param name="typeArgs">Generic type arguments to close over; <c>null</c> for non-generic calls.</param>
      /// <returns>
      /// <c>true</c> when the parameter count and generic arity match, and each parameter type accepts its
      /// argument directly, after unquoting, or after stripping <c>Expression&lt;T&gt;</c> from an array
      /// element type. Otherwise <c>false</c>.
      /// </returns>
      private static bool ArgsMatch(MethodInfo method, ReadOnlyCollection<Expression> args, Type[]? typeArgs)
      {
          //
          // Number of parameters should match the number of arguments to bind.
          //
          var parameters = method.GetParameters();
          if (parameters.Length != args.Count)
          {
              return false;
          }

          //
          // A non-generic method cannot accept generic type arguments.
          //
          if (!method.IsGenericMethod && typeArgs != null && typeArgs.Length != 0)
          {
              return false;
          }

          //
          // Closed generic methods that still contain generic parameters get converted to
          // their open generic counterpart.
          //
          if (!method.IsGenericMethodDefinition && method.IsGenericMethod && method.ContainsGenericParameters)
          {
              method = method.GetGenericMethodDefinition();
          }

          //
          // For generic methods, close the candidate using the specified type arguments.
          //
          if (method.IsGenericMethodDefinition)
          {
              //
              // We should have at least 1 type argument, and exactly as many as the method declares.
              //
              if (typeArgs == null || typeArgs.Length == 0)
              {
                  return false;
              }

              if (method.GetGenericArguments().Length != typeArgs.Length)
              {
                  return false;
              }

              //
              // Close the generic method and re-obtain the parameters.
              //
              method = method.MakeGenericMethod(typeArgs);
              parameters = method.GetParameters();
          }

          //
          // Check for contravariant assignability of each parameter.
          //
          for (var i = 0; i < args.Count; i++)
          {
              var type = parameters[i].ParameterType;

              //
              // Hardening against reflection quirks.
              //
              if (type == null)
              {
                  return false;
              }

              //
              // Deal with ref or out parameters by using the element type which can
              // match the corresponding expression type (ref passing is not encoded
              // in the type of expression trees).
              //
              if (type.IsByRef)
              {
                  type = type.GetElementType()!;
              }

              var expression = args[i];

              //
              // If the expression is assignable to the parameter, all is good. If not,
              // there may still be a match because of a quote that needs to be unpacked.
              //
              if (!type.IsAssignableFrom(expression.Type))
              {
                  //
                  // Unpack the quote, if any. See AsyncQueryable for examples of operators
                  // that hit this case.
                  //
                  if (expression.NodeType == ExpressionType.Quote)
                  {
                      expression = ((UnaryExpression)expression).Operand;
                  }

                  //
                  // Try assigning the raw expression type or the quote-free expression type
                  // to the parameter. If none of these work, there's no match.
                  //
                  if (!type.IsAssignableFrom(expression.Type) && !type.IsAssignableFrom(StripExpression(expression.Type)))
                  {
                      return false;
                  }
              }
          }

          return true;
      }

      /// <summary>
      /// Unquotes arguments whose expression type is not directly assignable to the corresponding
      /// parameter of <paramref name="method"/>.
      /// </summary>
      /// <param name="method">The method the arguments will be passed to.</param>
      /// <param name="argList">The argument expressions; expected to have one entry per parameter.</param>
      /// <returns>
      /// <paramref name="argList"/> itself when no argument changed; otherwise a new collection
      /// containing the fixed-up arguments.
      /// </returns>
      private static ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo method, ReadOnlyCollection<Expression> argList)
      {
          //
          // Get all of the method parameters. No fix-up needed if empty.
          //
          var parameters = method.GetParameters();
          if (parameters.Length != 0)
          {
              var list = default(List<Expression>);

              //
              // Process all parameters. The list is only allocated once the first
              // argument actually changes; earlier arguments are copied over then.
              //
              for (var i = 0; i < parameters.Length; i++)
              {
                  var expression = FixupQuotedExpression(parameters[i].ParameterType, argList[i]);

                  if (list == null && expression != argList[i])
                  {
                      list = new List<Expression>(argList.Count);
                      for (var j = 0; j < i; j++)
                      {
                          list.Add(argList[j]);
                      }
                  }

                  list?.Add(expression);
              }

              //
              // If any argument was fixed up, return a new argument list.
              //
              if (list != null)
              {
                  argList = new ReadOnlyCollection<Expression>(list);
              }
          }

          return argList;
      }

      /// <summary>
      /// Strips <see cref="ExpressionType.Quote"/> wrappers from <paramref name="expression"/> until its
      /// type is assignable to <paramref name="type"/>, fixing up array initializers element-wise.
      /// </summary>
      /// <param name="type">The parameter type the expression must be assignable to.</param>
      /// <param name="expression">The argument expression to fix up.</param>
      /// <returns>
      /// The unquoted expression when unquoting makes it assignable; a rebuilt array initializer when
      /// element-wise unquoting applies; otherwise the original <paramref name="expression"/>.
      /// </returns>
      private static Expression FixupQuotedExpression(Type type, Expression expression)
      {
          var res = expression;

          //
          // Keep unquoting until assignability checks pass.
          //
          while (!type.IsAssignableFrom(res.Type))
          {
              if (res.NodeType != ExpressionType.Quote)
              {
                  //
                  // No further unquoting is possible. Array initialization expressions need special
                  // care by unquoting the elements individually; in every other case the ORIGINAL
                  // expression (not the partially unquoted one) is returned.
                  //
                  if (type.IsArray && res.NodeType == ExpressionType.NewArrayInit)
                  {
                      var unquotedType = StripExpression(res.Type);
                      if (type.IsAssignableFrom(unquotedType))
                      {
                          var newArrayExpression = (NewArrayExpression)res;
                          var count = newArrayExpression.Expressions.Count;
                          var elementType = type.GetElementType()!;
                          var list = new List<Expression>(count);

                          for (var i = 0; i < count; i++)
                          {
                              list.Add(FixupQuotedExpression(elementType, newArrayExpression.Expressions[i]));
                          }

                          return Expression.NewArrayInit(elementType, list);
                      }
                  }

                  return expression;
              }

              //
              // Unquote and try again.
              //
              res = ((UnaryExpression)res).Operand;
          }

          return res;
      }

      /// <summary>
      /// For an array whose element type is (or derives from) <c>Expression&lt;T&gt;</c>, returns the
      /// array of <c>T</c> with the same rank.
      /// </summary>
      /// <param name="type">The type to strip.</param>
      /// <returns>
      /// The stripped array type; non-array types and arrays of other element types come back with the
      /// same element type (non-array types are returned unchanged).
      /// </returns>
      private static Type StripExpression(Type type)
      {
          //
          // Only array types are reconstructed; any other type is returned as-is.
          //
          if (!type.IsArray)
          {
              return type;
          }

          //
          // Array of quotes need to be stripped, so extract the element type and, when it is
          // an Expression<T>, replace it by T.
          //
          var elemType = type.GetElementType()!;
          var genType = FindGenericType(typeof(Expression<>), elemType);
          if (genType != null)
          {
              elemType = genType.GetGenericArguments()[0];
          }

          //
          // Reconstruct the array type from the stripped element type. Rank 1 maps to a
          // single-dimensional (vector) array type, as MakeArrayType() produces.
          //
          var arrayRank = type.GetArrayRank();
          return arrayRank == 1 ? elemType.MakeArrayType() : elemType.MakeArrayType(arrayRank);
      }

      /// <summary>
      /// Finds the <see cref="AsyncEnumerable"/> operator with the given name that accepts the arguments,
      /// closed over <paramref name="typeArgs"/> when provided.
      /// </summary>
      /// <param name="name">The operator name (taken from the AsyncQueryable method).</param>
      /// <param name="args">The rewritten argument expressions.</param>
      /// <param name="typeArgs">Generic type arguments of the original call, or <c>null</c>.</param>
      /// <returns>The matching method, closed over <paramref name="typeArgs"/> if non-null.</returns>
      /// <exception cref="InvalidOperationException">
      /// No method with that name exists on <see cref="AsyncEnumerable"/>, or none of them matches the arguments.
      /// </exception>
      private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[]? typeArgs)
      {
          //
          // Ensure the cached lookup table for AsyncEnumerable methods is initialized.
          //
          var methods = _methods ??= typeof(AsyncEnumerable).GetMethods(BindingFlags.Static | BindingFlags.Public).ToLookup(m => m.Name);

          //
          // The lookup yields an empty sequence for unknown names; report that distinctly from
          // an overload mismatch (mirroring FindMethod).
          //
          var candidates = methods[name];
          if (!candidates.Any())
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
          }

          //
          // Find a match based on the argument types.
          //
          var method = candidates.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
          if (method == null)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
          }

          //
          // Close the generic method if needed.
          //
          return typeArgs != null ? method.MakeGenericMethod(typeArgs) : method;
      }

      /// <summary>
      /// Finds a method with the given name and binding flags that accepts the arguments, searching the
      /// type named by <see cref="LocalQueryMethodImplementationTypeAttribute"/> on <paramref name="type"/>
      /// when present, and <paramref name="type"/> itself otherwise.
      /// </summary>
      /// <param name="type">The declaring type of the original method.</param>
      /// <param name="name">The method name.</param>
      /// <param name="args">The rewritten argument expressions.</param>
      /// <param name="typeArgs">Generic type arguments of the original call, or <c>null</c>.</param>
      /// <param name="flags">Binding flags used to enumerate candidate methods.</param>
      /// <returns>The matching method, closed over <paramref name="typeArgs"/> if non-null.</returns>
      /// <exception cref="InvalidOperationException">No candidate with that name, or none matches the arguments.</exception>
      private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[]? typeArgs, BindingFlags flags)
      {
          //
          // Support the enumerable methods to be defined on another type.
          //
          var targetType = type.GetTypeInfo().GetCustomAttribute<LocalQueryMethodImplementationTypeAttribute>()?.TargetType ?? type;

          //
          // Get all the candidates based on name and fail if none are found.
          //
          var methods = targetType.GetMethods(flags).Where(m => m.Name == name).ToArray();
          if (methods.Length == 0)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, type));
          }

          //
          // Find a match based on arguments and fail if no match is found.
          //
          var method = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
          if (method == null)
          {
              throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, type));
          }

          //
          // Close the generic method if needed.
          //
          return typeArgs != null ? method.MakeGenericMethod(typeArgs) : method;
      }

      /// <summary>
      /// Searches <paramref name="type"/>, its base types, and (for interface definitions) the interfaces
      /// they implement for a constructed type of the generic <paramref name="definition"/>.
      /// </summary>
      /// <param name="definition">The open generic type definition to look for, e.g. <c>Expression&lt;&gt;</c>.</param>
      /// <param name="type">The type to start searching from; may be <c>null</c>.</param>
      /// <returns>The first matching constructed type, or <c>null</c> if none is found.</returns>
      private static Type? FindGenericType(Type definition, Type? type)
      {
          //
          // Walk up the type hierarchy, stopping at object.
          //
          while (type != null && type != typeof(object))
          {
              //
              // If the current type matches the specified definition, return.
              //
              if (type.GetTypeInfo().IsGenericType && type.GetGenericTypeDefinition() == definition)
              {
                  return type;
              }

              //
              // Probe all interfaces implemented by the current type.
              //
              if (definition.GetTypeInfo().IsInterface)
              {
                  foreach (var ifType in type.GetInterfaces())
                  {
                      var res = FindGenericType(definition, ifType);
                      if (res != null)
                      {
                          return res;
                      }
                  }
              }

              type = type.GetTypeInfo().BaseType;
          }

          return null;
      }
  }
  433. }