QueryableEx.cs 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Globalization;
  6. using System.Linq.Expressions;
  7. using System.Reflection;
  8. namespace System.Linq
  9. {
  10. /// <summary>
  11. /// Provides a set of additional static methods that allow querying enumerable sequences.
  12. /// </summary>
  13. public static partial class QueryableEx
  14. {
  15. /// <summary>
  16. /// Gets the local Queryable provider.
  17. /// </summary>
  18. public static IQueryProvider Provider { get; } = new QueryProviderShim();
  19. private sealed class QueryProviderShim : IQueryProvider
  20. {
  21. public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
  22. {
  23. var provider = Array.Empty<TElement>().AsQueryable().Provider;
  24. var res = Redir(expression);
  25. return provider.CreateQuery<TElement>(res);
  26. }
  27. public IQueryable CreateQuery(Expression expression)
  28. {
  29. return CreateQuery<object>(expression);
  30. }
  31. public TResult Execute<TResult>(Expression expression)
  32. {
  33. var provider = Array.Empty<TResult>().AsQueryable().Provider;
  34. var res = Redir(expression);
  35. return provider.Execute<TResult>(res);
  36. }
  37. public object Execute(Expression expression)
  38. {
  39. return Execute<object>(expression);
  40. }
  41. private static Expression Redir(Expression expression)
  42. {
  43. if (expression is MethodCallExpression mce && mce.Method.DeclaringType == typeof(QueryableEx))
  44. {
  45. if (mce.Arguments.Count >= 1 && typeof(IQueryProvider).IsAssignableFrom(mce.Arguments[0].Type))
  46. {
  47. if (mce.Arguments[0] is ConstantExpression ce)
  48. {
  49. if (ce.Value is QueryProviderShim)
  50. {
  51. var targetType = typeof(QueryableEx);
  52. var method = mce.Method;
  53. var methods = GetMethods(targetType);
  54. var arguments = mce.Arguments.Skip(1).ToList();
  55. //
  56. // From all the operators with the method's name, find the one that matches all arguments.
  57. //
  58. var typeArgs = method.IsGenericMethod ? method.GetGenericArguments() : null;
  59. var targetMethod = methods[method.Name].FirstOrDefault(candidateMethod => ArgsMatch(candidateMethod, arguments, typeArgs));
  60. if (targetMethod == null)
  61. throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "There is no method '{0}' on type '{1}' that matches the specified arguments", method.Name, targetType.Name));
  62. //
  63. // Restore generic arguments.
  64. //
  65. if (typeArgs != null)
  66. targetMethod = targetMethod.MakeGenericMethod(typeArgs);
  67. //
  68. // Finally, we need to deal with mismatches on Expression<Func<...>> versus Func<...>.
  69. //
  70. var parameters = targetMethod.GetParameters();
  71. for (int i = 0, n = parameters.Length; i < n; i++)
  72. {
  73. arguments[i] = Unquote(arguments[i]);
  74. }
  75. //
  76. // Emit a new call to the discovered target method.
  77. //
  78. return Expression.Call(null, targetMethod, arguments);
  79. }
  80. }
  81. }
  82. }
  83. return expression;
  84. }
  85. private static ILookup<string, MethodInfo> GetMethods(Type type)
  86. {
  87. return type.GetMethods(BindingFlags.Static | BindingFlags.Public).ToLookup(m => m.Name);
  88. }
  89. private static bool ArgsMatch(MethodInfo method, IList<Expression> arguments, Type[]? typeArgs)
  90. {
  91. //
  92. // Number of parameters should match. Notice we've sanitized IQueryProvider "this"
  93. // parameters first (see Redir).
  94. //
  95. var parameters = method.GetParameters();
  96. if (parameters.Length != arguments.Count)
  97. return false;
  98. //
  99. // Genericity should match too.
  100. //
  101. if (!method.IsGenericMethod && typeArgs != null && typeArgs.Length > 0)
  102. return false;
  103. //
  104. // Reconstruct the generic method if needed.
  105. //
  106. if (method.IsGenericMethodDefinition)
  107. {
  108. if (typeArgs == null)
  109. return false;
  110. if (method.GetGenericArguments().Length != typeArgs.Length)
  111. return false;
  112. var result = method.MakeGenericMethod(typeArgs);
  113. parameters = result.GetParameters();
  114. }
  115. //
  116. // Check compatibility for the parameter types.
  117. //
  118. for (int i = 0, n = arguments.Count; i < n; i++)
  119. {
  120. var parameterType = parameters[i].ParameterType;
  121. var argument = arguments[i];
  122. //
  123. // For operators that take a function (like Where, Select), we'll be faced
  124. // with a quoted argument and a discrepancy between Expression<Func<...>>
  125. // and the underlying Func<...>.
  126. //
  127. if (!parameterType.IsAssignableFrom(argument.Type))
  128. {
  129. argument = Unquote(argument);
  130. if (!parameterType.IsAssignableFrom(argument.Type))
  131. return false;
  132. }
  133. }
  134. return true;
  135. }
  136. private static Expression Unquote(Expression expression)
  137. {
  138. //
  139. // Get rid of all outer quotes around an expression.
  140. //
  141. while (expression.NodeType == ExpressionType.Quote)
  142. expression = ((UnaryExpression)expression).Operand;
  143. return expression;
  144. }
  145. }
  146. internal static Expression GetSourceExpression<TSource>(IEnumerable<TSource> source)
  147. {
  148. if (source is IQueryable<TSource> q)
  149. return q.Expression;
  150. return Expression.Constant(source, typeof(IEnumerable<TSource>));
  151. }
  152. internal static Expression GetSourceExpression<TSource>(IEnumerable<TSource>[] sources)
  153. {
  154. return Expression.NewArrayInit(
  155. typeof(IEnumerable<TSource>),
  156. sources.Select(source => GetSourceExpression(source))
  157. );
  158. }
  159. internal static MethodInfo InfoOf<R>(Expression<Func<R>> f)
  160. {
  161. return ((MethodCallExpression)f.Body).Method;
  162. }
  163. }
  164. }