ObservableQuery.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Globalization;
  6. using System.Linq;
  7. using System.Linq.Expressions;
  8. using System.Reactive.Joins;
  9. using System.Reactive.Linq;
  10. using System.Reflection;
  11. namespace System.Reactive
  12. {
  /// <summary>
  /// Query provider for observable queries that execute locally. It wraps expression
  /// trees in <see cref="ObservableQuery{TSource}"/> objects and bridges
  /// <c>ToQueryable</c> calls over to a LINQ to Objects queryable.
  /// </summary>
  internal class ObservableQueryProvider : IQbservableProvider, IQueryProvider
  {
      // Cached open generic definition of Queryable.AsQueryable<TElement>, resolved on first use.
      private static MethodInfo s_AsQueryable;

      private static MethodInfo AsQueryable
      {
          get
          {
              if (s_AsQueryable != null)
              {
                  return s_AsQueryable;
              }

              var definition = Qbservable.InfoOf<object>(() => Queryable.AsQueryable<object>(null)).GetGenericMethodDefinition();
              s_AsQueryable = definition;
              return definition;
          }
      }

      /// <summary>
      /// Creates an observable query object representing the given expression tree.
      /// </summary>
      /// <typeparam name="TResult">Element type of the resulting observable sequence.</typeparam>
      /// <param name="expression">Expression tree whose type must be assignable to <see cref="IObservable{TResult}"/>.</param>
      /// <returns>A query object wrapping <paramref name="expression"/>.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
      /// <exception cref="ArgumentException">The expression's type is not compatible with <see cref="IObservable{TResult}"/>.</exception>
      public IQbservable<TResult> CreateQuery<TResult>(Expression expression)
      {
          if (expression == null)
          {
              throw new ArgumentNullException(nameof(expression));
          }

          var requiredType = typeof(IObservable<TResult>);
          if (!requiredType.IsAssignableFrom(expression.Type))
          {
              throw new ArgumentException(Strings_Providers.INVALID_TREE_TYPE, nameof(expression));
          }

          return new ObservableQuery<TResult>(expression);
      }

      /// <summary>
      /// Handles the hand-over from IQbservable to IQueryable for local execution.
      /// The only accepted input is a call to <c>Qbservable.ToQueryable</c>.
      /// </summary>
      IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
      {
          //
          // We sit on the boundary between the push-based and the pull-based world:
          //
          //    observable.AsQbservable().<operators>.ToQueryable()
          //
          // To run this locally a push-to-pull adapter is inserted, i.e. the tree becomes:
          //
          //    observable.AsQbservable().<operators>.ToEnumerable().AsQueryable()
          //
          var call = expression as MethodCallExpression;
          var isToQueryable =
              call != null &&
              call.Method.DeclaringType == typeof(Qbservable) &&
              call.Method.Name == nameof(Qbservable.ToQueryable);

          if (!isToQueryable)
          {
              throw new ArgumentException(Strings_Providers.EXPECTED_TOQUERYABLE_METHODCALL, nameof(expression));
          }

          //
          // The first argument is the IQbservable<T> the ToQueryable call was applied to.
          // Wrap it in ToEnumerable and AsQueryable to obtain the local queryable.
          //
          var asQueryable = AsQueryable.MakeGenericMethod(typeof(TElement));
          var toEnumerable = typeof(Observable).GetMethod(nameof(Observable.ToEnumerable)).MakeGenericMethod(typeof(TElement));
          var observableSide = call.Arguments[0];

          var pulled = Expression.Call(toEnumerable, observableSide);
          var queryable = Expression.Call(asQueryable, pulled);

          //
          // Any Queryable operators applied afterwards are handled by the LINQ to Objects
          // provider, so compile the tree and hand out the IQueryable<T> it produces.
          //
          var factory = Expression.Lambda<Func<IQueryable<TElement>>>(queryable).Compile();
          return factory();
      }

      // The remaining IQueryProvider members are not implemented by this provider.
      IQueryable IQueryProvider.CreateQuery(Expression expression) => throw new NotImplementedException();

      TResult IQueryProvider.Execute<TResult>(Expression expression) => throw new NotImplementedException();

      object IQueryProvider.Execute(Expression expression) => throw new NotImplementedException();
  }
  /// <summary>
  /// Non-generic base of observable query objects. It exposes the (possibly not yet
  /// available) source object and the expression tree describing the query.
  /// </summary>
  internal class ObservableQuery
  {
      // Backing storage; assigned by derived classes.
      protected object _source;
      protected Expression _expression;

      /// <summary>Gets the stored source object, or null when none has been set.</summary>
      public object Source => _source;

      /// <summary>Gets the stored expression tree.</summary>
      public Expression Expression => _expression;
  }
  95. internal class ObservableQuery<TSource> : ObservableQuery, IQbservable<TSource>
  96. {
  /// <summary>
  /// Creates a query object around an existing observable sequence. The expression
  /// tree of such a query is a constant node referring to the query object itself.
  /// </summary>
  /// <param name="source">Observable sequence to wrap.</param>
  internal ObservableQuery(IObservable<TSource> source)
  {
      _expression = Expression.Constant(this);
      _source = source;
  }

  /// <summary>
  /// Creates a query object described solely by an expression tree; no source object
  /// is stored by this constructor.
  /// </summary>
  /// <param name="expression">Expression tree representing the query.</param>
  internal ObservableQuery(Expression expression)
  {
      _expression = expression;
  }

  /// <summary>Gets the element type of the sequence, i.e. <typeparamref name="TSource"/>.</summary>
  public Type ElementType
  {
      get { return typeof(TSource); }
  }

  /// <summary>Gets the query provider, which is always <c>Qbservable.Provider</c>.</summary>
  public IQbservableProvider Provider
  {
      get { return Qbservable.Provider; }
  }
  /// <summary>
  /// Subscribes the observer to the sequence described by this query. When no source
  /// object is available yet, the expression tree is rewritten, compiled and evaluated
  /// to obtain one, and the result is kept for later subscriptions.
  /// </summary>
  /// <param name="observer">Observer to subscribe.</param>
  /// <returns>The disposable returned by the underlying sequence's Subscribe call.</returns>
  public IDisposable Subscribe(IObserver<TSource> observer)
  {
      if (_source == null)
      {
          // Rewrite the tree with ObservableRewriter, then compile and run it once.
          var localBody = new ObservableRewriter().Visit(_expression);
          var factory = Expression.Lambda<Func<IObservable<TSource>>>(localBody).Compile();
          _source = factory();
      }

      //
      // [OK] Use of unsafe Subscribe: non-pretentious mapping to IObservable<T> behavior equivalent to the expression tree.
      //
      var target = (IObservable<TSource>)_source;
      return target.Subscribe/*Unsafe*/(observer);
  }
  /// <summary>
  /// Returns a textual representation of the query. A query whose expression is just
  /// a constant referring to itself prints its source (or "null" when there is none);
  /// any other query prints its expression tree.
  /// </summary>
  public override string ToString()
  {
      var constant = _expression as ConstantExpression;
      var refersToSelf = constant != null && constant.Value == this;

      if (!refersToSelf)
      {
          return _expression.ToString();
      }

      return _source != null ? _source.ToString() : "null";
  }
  132. class ObservableRewriter : ExpressionVisitor
  133. {
  /// <summary>
  /// Replaces constants holding an <see cref="ObservableQuery"/> by either a constant
  /// for the query's source object or, when no source is set, the rewritten form of
  /// the query's own expression tree. Other constants are returned unchanged.
  /// </summary>
  protected override Expression VisitConstant(ConstantExpression/*!*/ node)
  {
      var query = node.Value as ObservableQuery;
      if (query == null)
      {
          return node;
      }

      var source = query.Source;
      return source != null
          ? Expression.Constant(source)
          : Visit(query.Expression);
  }
  /// <summary>
  /// Rewrites method calls in the query expression tree:
  /// "Then" and "And" calls on types deriving directly from QueryablePattern are
  /// retargeted to the rewritten pattern object, and operators taking an
  /// IQbservableProvider or an IQbservable as their first parameter are mapped onto
  /// the matching method found by FindObservableMethod. All other calls are handled
  /// by the base visitor.
  /// </summary>
  /// <param name="node">Method call to rewrite.</param>
  /// <returns>The rewritten expression, or the original node when it is left to another provider.</returns>
  protected override Expression VisitMethodCall(MethodCallExpression/*!*/ node)
  {
      var method = node.Method;
      var declaringType = method.DeclaringType;

      //
      // DeclaringType is null for global (module-level) methods, so the base type has to
      // be obtained in a null-safe way. Such a method never matches the pattern check below.
      //
  #if (CRIPPLED_REFLECTION && HAS_WINRT)
      var baseType = declaringType?.GetTypeInfo().BaseType;
  #else
      var baseType = declaringType?.BaseType;
  #endif

      if (baseType == typeof(QueryablePattern))
      {
          if (method.Name == "Then")
          {
              //
              // Retarget Then to the corresponding pattern. Recursive visit of the lhs will rewrite
              // the chain of And operators.
              //
              var pattern = Visit(node.Object);
              var arguments = node.Arguments.Select(arg => Unquote(Visit(arg))).ToArray();
              var then = Expression.Call(pattern, "Then", method.GetGenericArguments(), arguments);
              return then;
          }
          else if (method.Name == "And")
          {
              //
              // Retarget And to the corresponding pattern.
              //
              var lhs = Visit(node.Object);
              var arguments = node.Arguments.Select(arg => Visit(arg)).ToArray();
              var and = Expression.Call(lhs, "And", method.GetGenericArguments(), arguments);
              return and;
          }
      }
      else
      {
          var arguments = node.Arguments.AsEnumerable();

          //
          // Checking for an IQbservable operator, being either:
          // - an extension method on IQbservableProvider
          // - an extension method on IQbservable<T>
          //
          var isOperator = false;

          var firstParameter = method.GetParameters().FirstOrDefault();
          if (firstParameter != null)
          {
              var firstParameterType = firstParameter.ParameterType;

              //
              // Operators like Qbservable.Amb have an n-ary form that take in an IQbservableProvider
              // as the first argument. In such a case we need to make sure that the given provider is
              // the one targeting regular Observable. If not, we keep the subtree as-is and let that
              // provider handle the execution.
              //
              if (firstParameterType == typeof(IQbservableProvider))
              {
                  isOperator = true;

                  //
                  // Since we could be inside a lambda expression where one tries to obtain a query
                  // provider, or that provider could be stored in an outer variable, we need to
                  // evaluate the expression to obtain an IQbservableProvider object.
                  //
                  var provider = Expression.Lambda<Func<IQbservableProvider>>(Visit(node.Arguments[0])).Compile()();

                  //
                  // Let's see whether the ObservableQuery provider is targeted. This one always goes
                  // to local execution. E.g.:
                  //
                  //   var xs = Observable.Return(1).AsQbservable();
                  //   var ys = Observable.Return(2).AsQbservable();
                  //   var zs = Observable.Return(3).AsQbservable();
                  //
                  //   var res = Qbservable.Provider.Amb(xs, ys, zs);
                  //             ^^^^^^^^^^^^^^^^^^^
                  //
                  if (provider is ObservableQueryProvider)
                  {
                      //
                      // For further rewrite, simply ignore the query provider argument now to match
                      // up with the Observable signature. E.g.:
                      //
                      //   var res = Qbservable.Provider.Amb(xs, ys, zs);
                      //           = Qbservable.Amb(Qbservable.Provider, xs, ys, zs)
                      //                            ^^^^^^^^^^^^^^^^^^^
                      //   ->
                      //   var res = Observable.Amb(xs, ys, zs);
                      //
                      arguments = arguments.Skip(1);
                  }
                  else
                  {
                      //
                      // We've hit an unknown provider and will defer further execution to it. Upon
                      // calling Subscribe to the node's output, that provider will take care of it.
                      //
                      return node;
                  }
              }
              else if (typeof(IQbservable).IsAssignableFrom(firstParameterType))
              {
                  isOperator = true;
              }
          }

          if (isOperator)
          {
              var args = VisitQbservableOperatorArguments(method, arguments);
              return FindObservableMethod(method, args);
          }
      }

      // Not a join pattern member and not a recognized operator: default traversal.
      return base.VisitMethodCall(node);
  }
  //
  // Lambda expressions are returned as-is: the rewriter does not descend into their
  // bodies. Which override is compiled depends on the NO_VISITLAMBDAOFT symbol.
  //
  #if NO_VISITLAMBDAOFT
  protected override Expression VisitLambda(LambdaExpression node)
  #else
  protected override Expression VisitLambda<T>(Expression<T> node)
  #endif
      => node;
  /// <summary>
  /// Visits the arguments of an operator call and returns them as a new list.
  /// </summary>
  /// <param name="method">Operator method being rewritten.</param>
  /// <param name="arguments">Argument expressions to visit (the provider argument, if any, already removed by the caller).</param>
  /// <returns>The rewritten argument list.</returns>
  private IList<Expression> VisitQbservableOperatorArguments(MethodInfo method, IEnumerable<Expression> arguments)
  {
      //
      // Special case: the Qbservable.When<TResult>(IQbservableProvider, QueryablePlan<TResult>[])
      // overload. Its array initializer has to become a Plan<TResult>[] initializer, which
      // is then the only argument that gets returned.
      //
      var isWhen = method.Name == "When";
      if (isWhen)
      {
          var trailing = arguments.Last();
          if (trailing.NodeType == ExpressionType.NewArrayInit)
          {
              var planType = typeof(Plan<>).MakeGenericType(method.GetGenericArguments()[0]);
              var plans = ((NewArrayExpression)trailing).Expressions.Select(param => Visit(param));

              var rewritten = new List<Expression>();
              rewritten.Add(Expression.NewArrayInit(planType, plans));
              return rewritten;
          }
      }

      // General case: visit every argument in order.
      var visited = new List<Expression>();
      foreach (var argument in arguments)
      {
          visited.Add(Visit(argument));
      }

      return visited;
  }
  /// <summary>
  /// Minimal thread-safe lazy value holder: the factory is invoked at most once, on
  /// the first read of <see cref="Value"/>, and its result is cached.
  /// </summary>
  /// <typeparam name="T">Type of the lazily created value.</typeparam>
  sealed class Lazy<T>
  {
      // Private gate, so initialization never contends with code locking on the delegate.
      private readonly object _gate = new object();
      private readonly Func<T> _factory;
      private T _value;

      // Volatile so the lock-free fast path in Value observes _value once this flag is set.
      private volatile bool _initialized;

      /// <summary>Creates a lazy value backed by the given factory.</summary>
      /// <param name="factory">Function producing the value on first access.</param>
      /// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
      public Lazy(Func<T> factory)
      {
          _factory = factory ?? throw new ArgumentNullException(nameof(factory));
      }

      /// <summary>
      /// Gets the value, running the factory on first access. If the factory throws,
      /// the exception propagates and a later access runs the factory again.
      /// </summary>
      public T Value
      {
          get
          {
              // Fast path: skip the lock once the value has been published.
              if (!_initialized)
              {
                  lock (_gate)
                  {
                      if (!_initialized)
                      {
                          _value = _factory();
                          _initialized = true;
                      }
                  }
              }

              return _value;
          }
      }
  }
  //
  // Lazily built lookup (method name -> overloads) of the methods GetMethods returns for
  // Observable. The field is static on a type nested in the generic ObservableQuery<TSource>
  // class, so every closed generic type has its own instance. It is never reassigned,
  // hence readonly.
  //
  private static readonly Lazy<ILookup<string, MethodInfo>> _observableMethods =
      new Lazy<ILookup<string, MethodInfo>>(() => GetMethods(typeof(Observable)));

  /// <summary>
  /// Finds the local counterpart of a queryable operator and builds a call to it.
  /// </summary>
  /// <param name="method">Operator method found in the query expression tree.</param>
  /// <param name="arguments">
  /// Already visited arguments for the call. The list is updated in place: quotes
  /// around the arguments are stripped before the call is emitted.
  /// </param>
  /// <returns>A static method call expression targeting the matching method.</returns>
  /// <exception cref="InvalidOperationException">No method with a matching name and signature exists on the target type.</exception>
  private static MethodCallExpression FindObservableMethod(MethodInfo method, IList<Expression> arguments)
  {
      //
      // Where to look for the matching operator?
      //
      Type targetType;
      ILookup<string, MethodInfo> methods;

      if (method.DeclaringType == typeof(Qbservable))
      {
          targetType = typeof(Observable);
          methods = _observableMethods.Value;
      }
      else
      {
          //
          // Operators declared elsewhere are looked up on their own declaring type, unless a
          // LocalQueryMethodImplementationTypeAttribute redirects the search to another type.
          //
          targetType = method.DeclaringType;

  #if (CRIPPLED_REFLECTION && HAS_WINRT)
          var typeInfo = targetType.GetTypeInfo();
          if (typeInfo.IsDefined(typeof(LocalQueryMethodImplementationTypeAttribute), false))
          {
              var mapping = (LocalQueryMethodImplementationTypeAttribute)typeInfo.GetCustomAttributes(typeof(LocalQueryMethodImplementationTypeAttribute), false).Single();
              targetType = mapping.TargetType;
          }
  #else
          if (targetType.IsDefined(typeof(LocalQueryMethodImplementationTypeAttribute), false))
          {
              var mapping = (LocalQueryMethodImplementationTypeAttribute)targetType.GetCustomAttributes(typeof(LocalQueryMethodImplementationTypeAttribute), false)[0];
              targetType = mapping.TargetType;
          }
  #endif

          methods = GetMethods(targetType);
      }

      //
      // From all the operators with the method's name, find the one that matches all arguments.
      //
      var typeArgs = method.IsGenericMethod ? method.GetGenericArguments() : null;

      var targetMethod = methods[method.Name].FirstOrDefault(candidateMethod => ArgsMatch(candidateMethod, arguments, typeArgs));
      if (targetMethod == null)
      {
          throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, Strings_Providers.NO_MATCHING_METHOD_FOUND, method.Name, targetType.Name));
      }

      //
      // Restore generic arguments.
      //
      if (typeArgs != null)
      {
          targetMethod = targetMethod.MakeGenericMethod(typeArgs);
      }

      //
      // Finally, we need to deal with mismatches on Expression<Func<...>> versus Func<...>.
      //
      var parameters = targetMethod.GetParameters();
      for (int i = 0, n = parameters.Length; i < n; i++)
      {
          arguments[i] = Unquote(arguments[i]);
      }

      //
      // Emit a new call to the discovered target method.
      //
      return Expression.Call(null, targetMethod, arguments);
  }
  /// <summary>
  /// Builds a lookup from method name to the public static methods of a type.
  /// </summary>
  /// <param name="type">Type whose methods are indexed.</param>
  /// <returns>Lookup keyed by method name; each key maps to all overloads carrying that name.</returns>
  private static ILookup<string, MethodInfo> GetMethods(Type type)
  {
  #if (CRIPPLED_REFLECTION && HAS_WINRT)
      var candidates = type.GetTypeInfo().DeclaredMethods.Where(m => m.IsStatic && m.IsPublic);
  #else
      var candidates = type.GetMethods(BindingFlags.Static | BindingFlags.Public);
  #endif
      return candidates.ToLookup(m => m.Name);
  }
  /// <summary>
  /// Determines whether a candidate method can be called with the given arguments.
  /// </summary>
  /// <param name="method">Candidate method; may be an open generic method definition.</param>
  /// <param name="arguments">Argument expressions for the call.</param>
  /// <param name="typeArgs">Generic type arguments of the original operator call, or null when it was not generic.</param>
  /// <returns>
  /// true when the parameter count, genericity and parameter types are all compatible
  /// with the arguments (where needed after stripping quotes); otherwise false.
  /// </returns>
  private static bool ArgsMatch(MethodInfo method, IList<Expression> arguments, Type[] typeArgs)
  {
      //
      // Number of parameters should match. Notice we've sanitized IQbservableProvider "this"
      // parameters first (see VisitMethodCall).
      //
      var parameters = method.GetParameters();
      if (parameters.Length != arguments.Count)
      {
          return false;
      }

      //
      // Genericity should match too.
      //
      if (!method.IsGenericMethod && typeArgs != null && typeArgs.Length > 0)
      {
          return false;
      }

      //
      // Reconstruct the generic method if needed.
      //
      if (method.IsGenericMethodDefinition)
      {
          if (typeArgs == null)
          {
              return false;
          }

          if (method.GetGenericArguments().Length != typeArgs.Length)
          {
              return false;
          }

          try
          {
              parameters = method.MakeGenericMethod(typeArgs).GetParameters();
          }
          catch (ArgumentException)
          {
              //
              // MakeGenericMethod rejects type arguments that violate the candidate's generic
              // constraints. That means this overload is not applicable; it should not abort
              // the search across the remaining overloads.
              //
              return false;
          }
      }

      //
      // Check compatibility for the parameter types.
      //
      for (int i = 0, n = arguments.Count; i < n; i++)
      {
          var parameterType = parameters[i].ParameterType;
          var argument = arguments[i];

          //
          // For operators that take a function (like Where, Select), we'll be faced
          // with a quoted argument and a discrepancy between Expression<Func<...>>
          // and the underlying Func<...>.
          //
          if (!parameterType.IsAssignableFrom(argument.Type))
          {
              argument = Unquote(argument);
              if (!parameterType.IsAssignableFrom(argument.Type))
              {
                  return false;
              }
          }
      }

      return true;
  }
  /// <summary>
  /// Strips every outer Quote node from an expression.
  /// </summary>
  /// <param name="expression">Expression that may be wrapped in one or more quotes.</param>
  /// <returns>The innermost operand that is not itself a Quote node; the input itself when it is not quoted.</returns>
  private static Expression Unquote(Expression expression)
  {
      // Peel one quote per step until a non-quote node is reached.
      return expression.NodeType == ExpressionType.Quote
          ? Unquote(((UnaryExpression)expression).Operand)
          : expression;
  }
  435. }
  436. }
  437. }