AsyncEnumerableRewriter.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
  2. using System.Collections.Generic;
  3. using System.Collections.ObjectModel;
  4. using System.Globalization;
  5. using System.Linq.Expressions;
  6. using System.Reflection;
  namespace System.Linq
  {
      /// <summary>
      /// Rewrites an expression tree representation using AsyncQueryable methods to the corresponding AsyncEnumerable equivalents.
      /// </summary>
      internal class AsyncEnumerableRewriter : ExpressionVisitor
      {
          /// <summary>
          /// Lazily built lookup of the public static methods on AsyncEnumerable, keyed by method name.
          /// </summary>
          /// <remarks>
          /// Initialization is not synchronized: concurrent first callers may each build the table,
          /// but every build reflects over the same type and yields an equivalent lookup, so
          /// whichever write wins is harmless.
          /// </remarks>
          private static volatile ILookup<string, MethodInfo> s_methods;

          /// <summary>
          /// Replaces constants that wrap an <see cref="AsyncEnumerableQuery"/>. A query that has already
          /// been evaluated becomes a constant holding the underlying sequence; otherwise, the query's
          /// expression is inlined and rewritten recursively.
          /// </summary>
          /// <param name="node">The constant expression to visit.</param>
          /// <returns>The original node if it does not wrap an async enumerable query; otherwise, its replacement.</returns>
          protected override Expression VisitConstant(ConstantExpression node)
          {
              var enumerableQuery = node.Value as AsyncEnumerableQuery;

              //
              // Not an expression representation obtained from the async enumerable query provider,
              // so just a plain constant that can be returned as-is.
              //
              if (enumerableQuery == null)
              {
                  return node;
              }

              //
              // The query wraps a sequence that has already been evaluated. Embed it as a constant
              // typed via GetPublicType, which maps private nested types to a public interface.
              //
              if (enumerableQuery.Enumerable != null)
              {
                  var publicType = GetPublicType(enumerableQuery.Enumerable.GetType());
                  return Expression.Constant(enumerableQuery.Enumerable, publicType);
              }

              //
              // If not evaluated yet, inline the expression representation.
              //
              return Visit(enumerableQuery.Expression);
          }

          /// <summary>
          /// Rewrites a method call whose instance or arguments changed during the visit. AsyncQueryable
          /// operators are rebound to their AsyncEnumerable counterparts. Other methods are rebound to a
          /// compatible overload on their own declaring type.
          /// </summary>
          /// <param name="node">The method call expression to visit.</param>
          /// <returns>The original node if nothing changed; otherwise, a call to a compatible method.</returns>
          /// <exception cref="InvalidOperationException">No compatible replacement method could be found.</exception>
          protected override Expression VisitMethodCall(MethodCallExpression node)
          {
              var obj = Visit(node.Object);
              var args = Visit(node.Arguments);

              //
              // Nothing changed during the visit; just some unrelated method call that can
              // be returned as-is.
              //
              if (obj == node.Object && args == node.Arguments)
              {
                  return node;
              }

              var typeArgs = node.Method.IsGenericMethod ? node.Method.GetGenericArguments() : null;

              //
              // Check whether the method is compatible with the recursively rewritten instance
              // and arguments expressions. If so, create a new call expression.
              //
              if ((node.Method.IsStatic || node.Method.DeclaringType.IsAssignableFrom(obj.Type)) && ArgsMatch(node.Method, args, typeArgs))
              {
                  return Expression.Call(obj, node.Method, args);
              }

              MethodInfo method;

              //
              // Find a corresponding method in the non-expression world, e.g. rewriting from
              // the AsyncQueryable methods to the ones on AsyncEnumerable.
              //
              if (node.Method.DeclaringType == typeof(AsyncQueryable))
              {
                  method = FindEnumerableMethod(node.Method.Name, args, typeArgs);
              }
              else
              {
                  //
                  // Search overloads with the same staticness and visibility as the original method.
                  // Searching only static methods for an instance call can never produce a valid call,
                  // because Expression.Call rejects a static method paired with a non-null instance.
                  //
                  var flags = (node.Method.IsStatic ? BindingFlags.Static : BindingFlags.Instance)
                            | (node.Method.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic);

                  method = FindMethod(node.Method.DeclaringType, node.Method.Name, args, typeArgs, flags);
              }

              args = FixupQuotedArgs(method, args);

              return Expression.Call(obj, method, args);
          }

          /// <summary>
          /// Leaves lambdas untouched.
          /// </summary>
          /// <remarks>
          /// Don't recurse into lambdas; all the ones returning IAsyncQueryable&lt;T&gt;
          /// are compatible with their IAsyncEnumerable&lt;T&gt; counterparts due to the
          /// covariant return type.
          /// </remarks>
          protected override Expression VisitLambda<T>(Expression<T> node)
          {
              return node;
          }

          /// <summary>
          /// Leaves parameters untouched, for the same reason lambdas are not visited (see <see cref="VisitLambda{T}"/>).
          /// </summary>
          protected override Expression VisitParameter(ParameterExpression node)
          {
              return node;
          }

          /// <summary>
          /// Maps a private nested type to a public interface it implements, so the type can be used
          /// as the static type of a constant expression.
          /// </summary>
          /// <param name="type">The runtime type of an evaluated sequence.</param>
          /// <returns>
          /// The type itself if it is not nested-private. Otherwise, the first implemented
          /// constructed IAsyncEnumerable&lt;&gt; or IAsyncGrouping&lt;,&gt; interface.
          /// Falls back to the type itself if no such interface is found.
          /// </returns>
          private static Type GetPublicType(Type type)
          {
              if (!type.IsNestedPrivate)
              {
                  return type;
              }

              foreach (var ifType in type.GetInterfaces())
              {
                  if (ifType.IsGenericType)
                  {
                      var def = ifType.GetGenericTypeDefinition();
                      if (def == typeof(IAsyncEnumerable<>) || def == typeof(IAsyncGrouping<,>))
                      {
                          return ifType;
                      }
                  }
              }

              //
              // TODO: a fallback to a non-generic async enumerable interface may be needed for
              //       private nested types that implement neither of the interfaces above.
              //
              return type;
          }

          /// <summary>
          /// Determines whether <paramref name="method"/> can be invoked with <paramref name="args"/>.
          /// The method is first closed over <paramref name="typeArgs"/> if needed, and quoted lambdas
          /// are unwrapped where necessary.
          /// </summary>
          /// <param name="method">The candidate method: non-generic, an open generic definition, or a constructed generic method.</param>
          /// <param name="args">The argument expressions to bind.</param>
          /// <param name="typeArgs">The generic type arguments of the call being rewritten, or <c>null</c> for a non-generic call.</param>
          /// <returns><c>true</c> if every argument is assignable to the corresponding parameter; otherwise, <c>false</c>.</returns>
          private static bool ArgsMatch(MethodInfo method, ReadOnlyCollection<Expression> args, Type[] typeArgs)
          {
              //
              // Number of parameters should match the number of arguments to bind.
              //
              var parameters = method.GetParameters();
              if (parameters.Length != args.Count)
              {
                  return false;
              }

              //
              // A non-generic candidate cannot accept type arguments.
              //
              if (!method.IsGenericMethod && typeArgs != null && typeArgs.Length != 0)
              {
                  return false;
              }

              //
              // A constructed generic method that still contains generic parameters is converted
              // back to its definition so it can be closed over typeArgs below.
              //
              if (!method.IsGenericMethodDefinition && method.IsGenericMethod && method.ContainsGenericParameters)
              {
                  method = method.GetGenericMethodDefinition();
              }

              //
              // For generic methods, close the candidate using the specified type arguments.
              //
              if (method.IsGenericMethodDefinition)
              {
                  //
                  // We should have at least 1 type argument.
                  //
                  if (typeArgs == null || typeArgs.Length == 0)
                  {
                      return false;
                  }

                  //
                  // The number of type arguments needed should match the specified type argument count.
                  //
                  if (method.GetGenericArguments().Length != typeArgs.Length)
                  {
                      return false;
                  }

                  //
                  // Close the generic method and re-obtain the parameters.
                  //
                  method = method.MakeGenericMethod(typeArgs);
                  parameters = method.GetParameters();
              }

              //
              // Check that each argument is assignable to its parameter.
              //
              for (var i = 0; i < args.Count; i++)
              {
                  var type = parameters[i].ParameterType;

                  //
                  // Hardening against reflection quirks.
                  //
                  if (type == null)
                  {
                      return false;
                  }

                  //
                  // Deal with ref or out parameters by using the element type. Ref passing is not
                  // encoded in the type of expression trees, so only the element type can match.
                  //
                  if (type.IsByRef)
                  {
                      type = type.GetElementType();
                  }

                  var expression = args[i];

                  //
                  // If the expression is assignable to the parameter, all is good. If not,
                  // there may still be a match when the argument is a quote that needs unpacking.
                  //
                  if (!type.IsAssignableFrom(expression.Type))
                  {
                      //
                      // Unpack the quote, if any. See AsyncQueryable for examples of operators
                      // that hit this case.
                      //
                      if (expression.NodeType == ExpressionType.Quote)
                      {
                          expression = ((UnaryExpression)expression).Operand;
                      }

                      //
                      // Try assigning the raw expression type or the quote-free expression type
                      // to the parameter. If none of these work, there's no match.
                      //
                      if (!type.IsAssignableFrom(expression.Type) && !type.IsAssignableFrom(StripExpression(expression.Type)))
                      {
                          return false;
                      }
                  }
              }

              return true;
          }

          /// <summary>
          /// Adapts each argument to the corresponding parameter of <paramref name="method"/> by
          /// unwrapping quotes where needed (see <see cref="FixupQuotedExpression"/>).
          /// </summary>
          /// <param name="method">The method the arguments will be bound to.</param>
          /// <param name="argList">The argument expressions; expected to have one entry per parameter.</param>
          /// <returns>The same collection instance if no argument changed; otherwise, a new collection.</returns>
          private ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo method, ReadOnlyCollection<Expression> argList)
          {
              var parameters = method.GetParameters();
              if (parameters.Length == 0)
              {
                  return argList;
              }

              //
              // Allocated lazily on the first changed argument, then receives every argument from there on.
              //
              var list = default(List<Expression>);

              for (var i = 0; i < parameters.Length; i++)
              {
                  var expression = FixupQuotedExpression(parameters[i].ParameterType, argList[i]);

                  if (list == null && expression != argList[i])
                  {
                      //
                      // First change: copy the untouched prefix.
                      //
                      list = new List<Expression>(argList.Count);
                      for (var j = 0; j < i; j++)
                      {
                          list.Add(argList[j]);
                      }
                  }

                  if (list != null)
                  {
                      list.Add(expression);
                  }
              }

              return list != null ? new ReadOnlyCollection<Expression>(list) : argList;
          }

          /// <summary>
          /// Unwraps Quote nodes around <paramref name="expression"/> until the result is assignable
          /// to <paramref name="type"/>.
          /// </summary>
          /// <param name="type">The target parameter (or array element) type.</param>
          /// <param name="expression">The argument expression.</param>
          /// <returns>
          /// The first (possibly unquoted) expression assignable to <paramref name="type"/>. If a
          /// non-quote that still does not fit is reached: an array built with NewArrayInit whose
          /// quote-stripped type fits is rebuilt with each element fixed up individually; in every
          /// other case the ORIGINAL <paramref name="expression"/> is returned unchanged.
          /// </returns>
          private Expression FixupQuotedExpression(Type type, Expression expression)
          {
              var res = expression;

              while (!type.IsAssignableFrom(res.Type))
              {
                  if (res.NodeType != ExpressionType.Quote)
                  {
                      //
                      // Array initialization expressions need special care by unquoting the elements.
                      //
                      if (type.IsArray && res.NodeType == ExpressionType.NewArrayInit && type.IsAssignableFrom(StripExpression(res.Type)))
                      {
                          var newArrayExpression = (NewArrayExpression)res;
                          var count = newArrayExpression.Expressions.Count;
                          var elementType = type.GetElementType();

                          var list = new List<Expression>(count);
                          for (var i = 0; i < count; i++)
                          {
                              list.Add(FixupQuotedExpression(elementType, newArrayExpression.Expressions[i]));
                          }

                          return Expression.NewArrayInit(elementType, list);
                      }

                      return expression;
                  }

                  //
                  // Unquote and try again.
                  //
                  res = ((UnaryExpression)res).Operand;
              }

              return res;
          }

          /// <summary>
          /// For an array type whose element type is (or derives from) Expression&lt;T&gt;, returns
          /// the array type of T with the same rank. Non-array types are returned unchanged.
          /// </summary>
          /// <param name="type">The type to strip.</param>
          /// <returns>The stripped array type, or <paramref name="type"/> itself.</returns>
          private static Type StripExpression(Type type)
          {
              //
              // Only arrays of quotes are stripped; any other type is returned as-is.
              //
              if (!type.IsArray)
              {
                  return type;
              }

              var elemType = type.GetElementType();

              //
              // Try to find Expression<T> and obtain T.
              //
              var genType = FindGenericType(typeof(Expression<>), elemType);
              if (genType != null)
              {
                  elemType = genType.GetGenericArguments()[0];
              }

              //
              // Reconstruct the array type from the stripped element type. MakeArrayType() creates a
              // vector (T[]), which differs from a rank-1 multi-dimensional array, so rank 1 is special-cased.
              //
              var arrayRank = type.GetArrayRank();
              return arrayRank == 1 ? elemType.MakeArrayType() : elemType.MakeArrayType(arrayRank);
          }

          /// <summary>
          /// Finds the public static AsyncEnumerable method named <paramref name="name"/> that accepts
          /// <paramref name="args"/>, closed over <paramref name="typeArgs"/> if it is generic.
          /// </summary>
          /// <param name="name">The method name.</param>
          /// <param name="args">The argument expressions to bind.</param>
          /// <param name="typeArgs">The generic type arguments, or <c>null</c> for a non-generic call.</param>
          /// <returns>The matching (and, if generic, closed) method.</returns>
          /// <exception cref="InvalidOperationException">No method with that name exists, or no overload matches the arguments.</exception>
          private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs)
          {
              //
              // Ensure the cached lookup table for AsyncEnumerable methods is initialized. The volatile
              // field is read once into a local so the rest of the method works on a single snapshot.
              //
              var methods = s_methods;
              if (methods == null)
              {
                  methods = typeof(AsyncEnumerable).GetMethods(BindingFlags.Static | BindingFlags.Public).ToLookup(m => m.Name);
                  s_methods = methods;
              }

              if (!methods.Contains(name))
              {
                  throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
              }

              //
              // Find a match based on the method name and the argument types.
              //
              var method = methods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
              if (method == null)
              {
                  throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, typeof(AsyncEnumerable)));
              }

              //
              // Close the generic method if needed. ArgsMatch only accepts a generic definition when
              // typeArgs is non-empty, so a non-generic match is returned as-is.
              //
              if (method.IsGenericMethodDefinition)
              {
                  return method.MakeGenericMethod(typeArgs);
              }

              return method;
          }

          /// <summary>
          /// Finds the method on <paramref name="type"/> named <paramref name="name"/> (restricted by
          /// <paramref name="flags"/>) that accepts <paramref name="args"/>, closed over
          /// <paramref name="typeArgs"/> if it is generic.
          /// </summary>
          /// <param name="type">The type to search.</param>
          /// <param name="name">The method name.</param>
          /// <param name="args">The argument expressions to bind.</param>
          /// <param name="typeArgs">The generic type arguments, or <c>null</c> for a non-generic call.</param>
          /// <param name="flags">The binding flags used to enumerate candidates.</param>
          /// <returns>The matching (and, if generic, closed) method.</returns>
          /// <exception cref="InvalidOperationException">No method with that name exists, or no overload matches the arguments.</exception>
          private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs, BindingFlags flags)
          {
              //
              // Get all the candidates based on name and fail if none are found.
              //
              var methods = type.GetMethods(flags).Where(m => m.Name == name).ToArray();
              if (methods.Length == 0)
              {
                  throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find method with name '{0}' on type '{1}'.", name, type));
              }

              //
              // Find a match based on arguments and fail if no match is found.
              //
              var method = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
              if (method == null)
              {
                  throw new InvalidOperationException(string.Format(CultureInfo.InvariantCulture, "Could not find a matching method with name '{0}' on type '{1}'.", name, type));
              }

              //
              // Close the generic method if needed. ArgsMatch only accepts a generic definition when
              // typeArgs is non-empty, so a non-generic match is returned as-is.
              //
              if (method.IsGenericMethodDefinition)
              {
                  return method.MakeGenericMethod(typeArgs);
              }

              return method;
          }

          /// <summary>
          /// Walks <paramref name="type"/> and its base types (and, when <paramref name="definition"/>
          /// is an interface, their implemented interfaces). It looks for a constructed type whose generic
          /// type definition is <paramref name="definition"/>.
          /// </summary>
          /// <param name="definition">An open generic type definition, e.g. Expression&lt;&gt;.</param>
          /// <param name="type">The type to start searching from; may be <c>null</c>.</param>
          /// <returns>The first matching constructed type, or <c>null</c> if none is found.</returns>
          private static Type FindGenericType(Type definition, Type type)
          {
              while (type != null && type != typeof(object))
              {
                  //
                  // If the current type matches the specified definition, return.
                  //
                  if (type.IsGenericType && type.GetGenericTypeDefinition() == definition)
                  {
                      return type;
                  }

                  //
                  // Probe all interfaces implemented by the current type.
                  //
                  if (definition.IsInterface)
                  {
                      foreach (var ifType in type.GetInterfaces())
                      {
                          var res = FindGenericType(definition, ifType);
                          if (res != null)
                          {
                              return res;
                          }
                      }
                  }

                  //
                  // Continue up the type hierarchy.
                  //
                  type = type.BaseType;
              }

              return null;
          }
      }
  }