123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- // Licensed to the .NET Foundation under one or more agreements.
- // The .NET Foundation licenses this file to you under the MIT License.
- // See the LICENSE file in the project root for more information.
- using System.Collections.Generic;
- using System.Globalization;
- using System.Linq.Expressions;
- using System.Reflection;
namespace System.Linq
{
    /// <summary>
    /// Provides a set of additional static methods that allow querying enumerable sequences.
    /// </summary>
    public static partial class QueryableEx
    {
        /// <summary>
        /// Gets the local Queryable provider.
        /// </summary>
        public static IQueryProvider Provider { get; } = new QueryProviderShim();

        /// <summary>
        /// Query provider that rewrites calls to <see cref="QueryableEx"/> operators whose first argument
        /// is this provider into calls to the matching provider-less operator on <see cref="QueryableEx"/>,
        /// and then forwards the resulting expression to the provider of an array-backed queryable
        /// (<c>Array.Empty&lt;T&gt;().AsQueryable().Provider</c>).
        /// </summary>
        private sealed class QueryProviderShim : IQueryProvider
        {
            //
            // Public static operators declared on QueryableEx, grouped by name. The set of methods on a
            // type does not change at run time, so the reflection work is done once instead of on
            // every call to Redir.
            //
            private static readonly Lazy<ILookup<string, MethodInfo>> s_operators =
                new Lazy<ILookup<string, MethodInfo>>(() => GetMethods(typeof(QueryableEx)));

            /// <summary>
            /// Creates a query over elements of type <typeparamref name="TElement"/> for the given
            /// expression, after redirecting provider-based operator calls (see <see cref="Redir"/>).
            /// </summary>
            /// <typeparam name="TElement">Element type of the resulting query.</typeparam>
            /// <param name="expression">Expression tree representing the query.</param>
            /// <returns>A query created by the inner provider for the redirected expression.</returns>
            public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
            {
                var inner = Array.Empty<TElement>().AsQueryable().Provider;
                return inner.CreateQuery<TElement>(Redir(expression));
            }

            /// <summary>
            /// Creates a query for the given expression without statically knowing its element type.
            /// </summary>
            /// <param name="expression">Expression tree representing the query.</param>
            /// <returns>A query created by the inner provider for the redirected expression.</returns>
            public IQueryable CreateQuery(Expression expression)
            {
                //
                // Forward to the inner provider's non-generic overload so the element type is taken from
                // the expression itself. Going through CreateQuery<object> rejects sequences of value
                // types (e.g. IQueryable<int>), which are not convertible to IQueryable<object>.
                //
                var inner = Array.Empty<object>().AsQueryable().Provider;
                return inner.CreateQuery(Redir(expression));
            }

            /// <summary>
            /// Executes the given expression and returns its result as <typeparamref name="TResult"/>,
            /// after redirecting provider-based operator calls (see <see cref="Redir"/>).
            /// </summary>
            /// <typeparam name="TResult">Type of the value produced by the expression.</typeparam>
            /// <param name="expression">Expression tree to execute.</param>
            /// <returns>The value produced by the inner provider.</returns>
            public TResult Execute<TResult>(Expression expression)
            {
                var inner = Array.Empty<TResult>().AsQueryable().Provider;
                return inner.Execute<TResult>(Redir(expression));
            }

            /// <summary>
            /// Executes the given expression without statically knowing its result type.
            /// </summary>
            /// <param name="expression">Expression tree to execute.</param>
            /// <returns>The value produced by the inner provider; value-typed results are boxed.</returns>
            public object? Execute(Expression expression)
            {
                //
                // Forward to the inner provider's non-generic overload, which works from the expression's
                // own type. Going through Execute<object> fails for value-typed results (e.g. an int
                // produced by an aggregate) because no Func<object> lambda can wrap an unboxed value.
                //
                var inner = Array.Empty<object>().AsQueryable().Provider;
                return inner.Execute(Redir(expression));
            }

            /// <summary>
            /// Rewrites a call of the form <c>QueryableEx.Op(shim, args...)</c> into a call to the
            /// provider-less <c>QueryableEx.Op(args...)</c> overload. Any other expression is returned
            /// unchanged.
            /// </summary>
            /// <param name="expression">Expression to inspect.</param>
            /// <returns>The redirected call, or <paramref name="expression"/> itself when no redirection applies.</returns>
            /// <exception cref="InvalidOperationException">
            /// No public static method on <see cref="QueryableEx"/> with the same name matches the remaining arguments.
            /// </exception>
            private static Expression Redir(Expression expression)
            {
                //
                // Only calls to operators declared on QueryableEx are candidates.
                //
                if (!(expression is MethodCallExpression mce) || mce.Method.DeclaringType != typeof(QueryableEx))
                {
                    return expression;
                }

                //
                // The first argument has to be a provider...
                //
                if (mce.Arguments.Count < 1 || !typeof(IQueryProvider).IsAssignableFrom(mce.Arguments[0].Type))
                {
                    return expression;
                }

                //
                // ... and more specifically a constant holding an instance of this shim.
                //
                if (!(mce.Arguments[0] is ConstantExpression ce) || !(ce.Value is QueryProviderShim))
                {
                    return expression;
                }

                var targetType = typeof(QueryableEx);
                var method = mce.Method;
                var arguments = mce.Arguments.Skip(1).ToList();

                //
                // From all the operators with the method's name, find the one that matches all arguments.
                //
                var typeArgs = method.IsGenericMethod ? method.GetGenericArguments() : null;
                var targetMethod = s_operators.Value[method.Name].FirstOrDefault(candidate => ArgsMatch(candidate, arguments, typeArgs));

                if (targetMethod == null)
                {
                    throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "There is no method '{0}' on type '{1}' that matches the specified arguments", method.Name, targetType.Name));
                }

                //
                // Restore generic arguments.
                //
                if (typeArgs != null)
                {
                    targetMethod = targetMethod.MakeGenericMethod(typeArgs);
                }

                //
                // Finally, we need to deal with mismatches on Expression<Func<...>> versus Func<...>.
                // ArgsMatch has already established that the target method has exactly as many
                // parameters as there are arguments left.
                //
                for (var i = 0; i < arguments.Count; i++)
                {
                    arguments[i] = Unquote(arguments[i]);
                }

                //
                // Emit a new call to the discovered target method.
                //
                return Expression.Call(null, targetMethod, arguments);
            }

            /// <summary>
            /// Gets the public static methods of <paramref name="type"/>, grouped by method name.
            /// </summary>
            /// <param name="type">Type to reflect over.</param>
            /// <returns>A lookup from method name to the methods carrying that name.</returns>
            private static ILookup<string, MethodInfo> GetMethods(Type type)
            {
                return type.GetMethods(BindingFlags.Static | BindingFlags.Public).ToLookup(m => m.Name);
            }

            /// <summary>
            /// Determines whether <paramref name="method"/> can be called with the given arguments and
            /// generic type arguments.
            /// </summary>
            /// <param name="method">Candidate method, possibly a generic method definition.</param>
            /// <param name="arguments">Argument expressions, with the provider argument already removed.</param>
            /// <param name="typeArgs">Generic type arguments of the original call, or <c>null</c> if it was not generic.</param>
            /// <returns><c>true</c> if the candidate accepts the arguments, possibly after unquoting them; otherwise <c>false</c>.</returns>
            private static bool ArgsMatch(MethodInfo method, IList<Expression> arguments, Type[]? typeArgs)
            {
                //
                // Number of parameters should match. Notice we've sanitized IQueryProvider "this"
                // parameters first (see Redir).
                //
                var parameters = method.GetParameters();
                if (parameters.Length != arguments.Count)
                {
                    return false;
                }

                //
                // Genericity should match too.
                //
                if (!method.IsGenericMethod && typeArgs != null && typeArgs.Length > 0)
                {
                    return false;
                }

                //
                // Reconstruct the generic method if needed.
                //
                if (method.IsGenericMethodDefinition)
                {
                    if (typeArgs == null)
                    {
                        return false;
                    }

                    if (method.GetGenericArguments().Length != typeArgs.Length)
                    {
                        return false;
                    }

                    parameters = method.MakeGenericMethod(typeArgs).GetParameters();
                }

                //
                // Check compatibility for the parameter types.
                //
                for (var i = 0; i < arguments.Count; i++)
                {
                    var parameterType = parameters[i].ParameterType;
                    var argument = arguments[i];

                    //
                    // For operators that take a function (like Where, Select), we'll be faced
                    // with a quoted argument and a discrepancy between Expression<Func<...>>
                    // and the underlying Func<...>.
                    //
                    if (!parameterType.IsAssignableFrom(argument.Type))
                    {
                        argument = Unquote(argument);

                        if (!parameterType.IsAssignableFrom(argument.Type))
                        {
                            return false;
                        }
                    }
                }

                return true;
            }

            /// <summary>
            /// Strips all outer <see cref="ExpressionType.Quote"/> nodes from an expression.
            /// </summary>
            /// <param name="expression">Expression that may be wrapped in one or more quotes.</param>
            /// <returns>The innermost operand that is not itself a quote.</returns>
            private static Expression Unquote(Expression expression)
            {
                //
                // Get rid of all outer quotes around an expression.
                //
                while (expression.NodeType == ExpressionType.Quote)
                {
                    expression = ((UnaryExpression)expression).Operand;
                }

                return expression;
            }
        }

        /// <summary>
        /// Gets an expression representing the given source sequence: the sequence's own query
        /// expression if it is an <see cref="IQueryable{T}"/>, or a constant typed as
        /// <see cref="IEnumerable{T}"/> otherwise.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Source sequence.</param>
        /// <returns>Expression standing in for <paramref name="source"/>.</returns>
        internal static Expression GetSourceExpression<TSource>(IEnumerable<TSource> source)
        {
            if (source is IQueryable<TSource> q)
            {
                return q.Expression;
            }

            return Expression.Constant(source, typeof(IEnumerable<TSource>));
        }

        /// <summary>
        /// Gets an array-initialization expression whose elements are the source expressions of the
        /// given sequences.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequences.</typeparam>
        /// <param name="sources">Source sequences.</param>
        /// <returns>An expression creating an <c>IEnumerable&lt;TSource&gt;[]</c> from the individual source expressions.</returns>
        internal static Expression GetSourceExpression<TSource>(IEnumerable<TSource>[] sources)
        {
            return Expression.NewArrayInit(
                typeof(IEnumerable<TSource>),
                sources.Select(source => GetSourceExpression(source))
            );
        }

        /// <summary>
        /// Gets the method called by the body of the given lambda expression.
        /// </summary>
        /// <typeparam name="R">Return type of the lambda.</typeparam>
        /// <param name="f">Lambda whose body is a method call; any other body makes the cast fail.</param>
        /// <returns>The <see cref="MethodInfo"/> of the called method.</returns>
        internal static MethodInfo InfoOf<R>(Expression<Func<R>> f)
        {
            return ((MethodCallExpression)f.Body).Method;
        }
    }
}
|