Program.cs 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. using System.Runtime.CompilerServices;
  10. using System.Threading;
  11. using System.Threading.Tasks;
  12. namespace ApiCompare
  13. {
  14. class Program
  15. {
  // Open generic sequence interfaces of the two operator families being compared.
  private static readonly Type syncInterfaceType = typeof(IEnumerable<>);
  private static readonly Type syncOrderedInterfaceType = typeof(IOrderedEnumerable<>);
  private static readonly Type asyncInterfaceType = typeof(IAsyncEnumerable<>);
  private static readonly Type asyncOrderedInterfaceType = typeof(IOrderedAsyncEnumerable<>);

  // Operator names excluded from the comparison on both sides.
  private static readonly string[] exceptions =
  {
      "SkipLast",         // In .NET Core 2.0
      "TakeLast",         // In .NET Core 2.0
      "Cast",             // Non-generic methods
      "OfType",           // Non-generic methods
      "AsEnumerable",     // Trivially renamed
      "AsAsyncEnumerable" // Trivially renamed
  };

  // Maps async interface definitions onto their sync counterparts so async signatures can be
  // compared with sync ones. Must stay below the Type fields it reads: static field
  // initializers run in textual order.
  private static readonly TypeSubstitutor subst = new TypeSubstitutor(new Dictionary<Type, Type>
  {
      [asyncInterfaceType] = syncInterfaceType,
      [asyncOrderedInterfaceType] = syncOrderedInterfaceType,
      [typeof(IAsyncGrouping<,>)] = typeof(IGrouping<,>),
  });
  static void Main()
  {
      // LINQ to Objects is the reference surface; the async operators are checked against it.
      Compare(syncOperatorsType: typeof(Enumerable), asyncOperatorsType: typeof(AsyncEnumerable));
  }
  /// <summary>
  /// Classifies the operators of both types and reports differences per category:
  /// factories first, then query operators, then aggregates.
  /// </summary>
  static void Compare(Type syncOperatorsType, Type asyncOperatorsType)
  {
      // Each side is classified against its own plain and ordered sequence interfaces.
      Type[] syncSequenceTypes = { syncInterfaceType, syncOrderedInterfaceType };
      Type[] asyncSequenceTypes = { asyncInterfaceType, asyncOrderedInterfaceType };

      var sync = GetQueryOperators(syncSequenceTypes, syncOperatorsType, exceptions);
      var asyncOps = GetQueryOperators(asyncSequenceTypes, asyncOperatorsType, exceptions);

      CompareFactories(sync.Factories, asyncOps.Factories);
      CompareQueryOperators(sync.QueryOperators, asyncOps.QueryOperators);
      CompareAggregates(sync.Aggregates, asyncOps.Aggregates);
  }
  /// <summary>Compares factory methods by name, then overload-by-overload for shared names.</summary>
  static void CompareFactories(ILookup<string, MethodInfo> syncFactories, ILookup<string, MethodInfo> asyncFactories)
      => CompareSets(syncFactories, asyncFactories, compareCore: CompareFactoryOverloads);
  /// <summary>
  /// Compares the overloads of one factory method name.
  /// Prints "MISSING" for each sync overload without an async counterpart.
  /// Prints "EXCESS" for each async overload without a sync counterpart.
  /// </summary>
  /// <param name="name">Operator name; unused here, but required by the compareCore delegate shape of CompareSets.</param>
  /// <param name="syncMethods">Overloads from the sync operator type.</param>
  /// <param name="asyncMethods">Overloads from the async operator type; their signatures are rewritten to sync interface types before comparison.</param>
  static void CompareFactoryOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      var syncSignatures = GetSignatures(syncMethods).ToArray();
      var asyncSignatures = GetRewrittenSignatures(asyncMethods).ToArray();

      // Async must be a superset of sync. Each Except query is enumerated exactly once;
      // guarding the loop with Any() would evaluate the set difference twice.
      foreach (var signature in syncSignatures.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + ToString(signature.Method));
      }

      // Async overloads with no sync counterpart are excess API surface.
      foreach (var signature in asyncSignatures.Except(syncSignatures))
      {
          Console.WriteLine("EXCESS " + ToString(signature.Method));
      }
  }
  /// <summary>Compares query operators by name, then overload-by-overload for shared names.</summary>
  static void CompareQueryOperators(ILookup<string, MethodInfo> syncOperators, ILookup<string, MethodInfo> asyncOperators)
      => CompareSets(syncOperators, asyncOperators, compareCore: CompareQueryOperatorsOverloads);
  /// <summary>
  /// Compares the overloads of one query operator name. It reports, in order:
  /// (1) sync overloads with no identical async overload;
  /// (2) expected Task-based variants (delegate parameters returning Task/Task&lt;T&gt;) missing from async;
  /// (3) async overloads that are neither identical to a sync overload nor one of its Task-based variants.
  /// </summary>
  /// <param name="name">Operator name, used when reporting synthesized Task-based signatures (which have no MethodInfo).</param>
  /// <param name="syncMethods">Overloads from the sync operator type.</param>
  /// <param name="asyncMethods">Overloads from the async operator type; their signatures are rewritten to sync interface types before comparison.</param>
  static void CompareQueryOperatorsOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      var syncSignatures = GetSignatures(syncMethods).ToArray();
      var asyncSignatures = GetRewrittenSignatures(asyncMethods).ToArray();

      // (1) Async must be a superset of sync. Each Except query below is enumerated once.
      foreach (var signature in syncSignatures.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + ToString(signature.Method));
      }

      // (2) Every sync overload taking a Func/Action should have a Task-based counterpart.
      // An empty list simply yields nothing, so no Count guard is needed.
      var taskBasedSignatures = syncSignatures
          .Where(s => s.ParameterTypes.Any(IsFuncOrActionType))
          .Select(s => GetAsyncVariant(s))
          .ToList();

      foreach (var signature in taskBasedSignatures.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + name + " :: " + signature);
      }

      // (3) Anything left in async is neither a carbon copy of sync nor a Task-based variant.
      foreach (var signature in asyncSignatures.Except(syncSignatures.Union(taskBasedSignatures)))
      {
          Console.WriteLine("EXCESS " + ToString(signature.Method));
      }
  }
  /// <summary>Compares aggregate operators by name, then overload-by-overload for shared names.</summary>
  static void CompareAggregates(ILookup<string, MethodInfo> syncAggregates, ILookup<string, MethodInfo> asyncAggregates)
      => CompareSets(syncAggregates, asyncAggregates, compareCore: CompareAggregateOverloads);
  /// <summary>
  /// Compares the overloads of one aggregate name. Each sync aggregate's return type is first
  /// wrapped in Task/Task&lt;T&gt; to form the expected async shape. It then reports, in order:
  /// (1) expected async shapes with no matching async overload;
  /// (2) missing Task-based delegate variants;
  /// (3) missing variants with a trailing CancellationToken (of both the plain and Task-based shapes);
  /// (4) async overloads matching none of the above.
  /// </summary>
  /// <param name="name">Aggregate name, used when reporting synthesized signatures (which have no MethodInfo).</param>
  /// <param name="syncMethods">Overloads from the sync operator type.</param>
  /// <param name="asyncMethods">Overloads from the async operator type; their signatures are rewritten to sync interface types before comparison.</param>
  static void CompareAggregateOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      var syncSignatures = GetSignatures(syncMethods).Select(GetAsyncAggregateSignature).ToArray();
      var asyncSignatures = GetRewrittenSignatures(asyncMethods).ToArray();

      // (1) Async must be a superset of (Task-wrapped) sync. Each Except query is enumerated once.
      foreach (var signature in syncSignatures.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + ToString(signature.Method));
      }

      // (2) Overloads taking a Func/Action should have a Task-returning delegate counterpart.
      var taskBasedSignatures = syncSignatures
          .Where(s => s.ParameterTypes.Any(IsFuncOrActionType))
          .Select(s => GetAsyncVariant(s))
          .ToList();

      foreach (var signature in taskBasedSignatures.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + name + " :: " + signature);
      }

      // (3) Both the plain and Task-based shapes should also exist with a trailing CancellationToken.
      var withCancellationToken = syncSignatures
          .Concat(taskBasedSignatures)
          .Select(s => AppendCancellationToken(s))
          .ToList();

      foreach (var signature in withCancellationToken.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + name + " :: " + signature);
      }

      // (4) Excess: async overloads that are none of the following —
      // a Task-wrapped copy of sync, a Task-based variant, or a CancellationToken variant.
      var expected = syncSignatures.Union(taskBasedSignatures).Union(withCancellationToken);
      foreach (var signature in asyncSignatures.Except(expected))
      {
          Console.WriteLine("EXCESS " + ToString(signature.Method));
      }
  }
  /// <summary>
  /// Returns true for the non-generic <see cref="Action"/> and for constructed generic types
  /// whose definition is named <c>Func`N</c> or <c>Action`N</c>; false otherwise
  /// (including open definitions such as <c>Func&lt;&gt;</c>).
  /// </summary>
  /// <remarks>Matching is by definition name only; the namespace is not checked.</remarks>
  private static bool IsFuncOrActionType(Type type)
  {
      if (type == typeof(Action))
      {
          return true;
      }

      if (!type.IsConstructedGenericType)
      {
          return false;
      }

      // Metadata names are identifiers, not linguistic text: compare ordinally rather than
      // with the current culture (CA1310).
      var defName = type.GetGenericTypeDefinition().Name;
      return defName.StartsWith("Func`", StringComparison.Ordinal)
          || defName.StartsWith("Action`", StringComparison.Ordinal);
  }
  /// <summary>
  /// Builds the Task-based variant of <paramref name="signature"/>. Each Func/Action parameter
  /// is replaced by its Task-returning counterpart, and the return type is kept as-is.
  /// The result carries no <c>Method</c>.
  /// </summary>
  private static Signature GetAsyncVariant(Signature signature)
  {
      var source = signature.ParameterTypes;
      var asyncParameters = new Type[source.Length];
      for (var i = 0; i < source.Length; i++)
      {
          asyncParameters[i] = GetAsyncVariant(source[i]);
      }

      return new Signature { ParameterTypes = asyncParameters, ReturnType = signature.ReturnType };
  }
  /// <summary>
  /// Returns a copy of <paramref name="signature"/> with a trailing <see cref="CancellationToken"/>
  /// parameter; the return type is kept and no <c>Method</c> is attached.
  /// </summary>
  private static Signature AppendCancellationToken(Signature signature)
  {
      var original = signature.ParameterTypes;
      var extended = new Type[original.Length + 1];
      Array.Copy(original, extended, original.Length);
      extended[original.Length] = typeof(CancellationToken);

      return new Signature { ParameterTypes = extended, ReturnType = signature.ReturnType };
  }
  /// <summary>
  /// Maps a Func/Action delegate type to its Task-returning counterpart:
  /// <see cref="Action"/> becomes <c>Func&lt;Task&gt;</c>;
  /// <c>Action&lt;T1..Tn&gt;</c> becomes <c>Func&lt;T1..Tn, Task&gt;</c>;
  /// <c>Func&lt;T1..Tn, TResult&gt;</c> becomes <c>Func&lt;T1..Tn, Task&lt;TResult&gt;&gt;</c>.
  /// Any other type is returned unchanged.
  /// </summary>
  private static Type GetAsyncVariant(Type type)
  {
      if (!IsFuncOrActionType(type))
      {
          return type;
      }

      if (type == typeof(Action))
      {
          return typeof(Func<Task>);
      }

      var args = type.GetGenericArguments();
      var defName = type.GetGenericTypeDefinition().Name;

      // Ordinal: metadata names must not be compared with culture-sensitive rules (CA1310).
      if (defName.StartsWith("Func`", StringComparison.Ordinal))
      {
          // The last type argument of Func is the result; wrap it in Task<>.
          var ret = typeof(Task<>).MakeGenericType(args.Last());
          return Expression.GetFuncType(args.SkipLast(1).Append(ret).ToArray());
      }

      // Remaining case is Action<...>: append Task as the delegate's result.
      return Expression.GetFuncType(args.Append(typeof(Task)).ToArray());
  }
  /// <summary>
  /// Compares two name-keyed groups of methods:
  /// - every overload of a name present only in <paramref name="sync"/> is printed as MISSING;
  /// - names present in both are handed to <paramref name="compareCore"/>;
  /// - every overload of a name present only in <paramref name="async"/> is printed as EXCESS.
  /// </summary>
  static void CompareSets(ILookup<string, MethodInfo> sync, ILookup<string, MethodInfo> async, Action<string, IEnumerable<MethodInfo>, IEnumerable<MethodInfo>> compareCore)
  {
      string[] syncNames = sync.Select(group => group.Key).ToArray();
      string[] asyncNames = async.Select(group => group.Key).ToArray();

      // Names the async side lacks entirely.
      var missing = syncNames.Except(asyncNames).SelectMany(name => sync[name]);
      foreach (var method in missing)
      {
          Console.WriteLine("MISSING " + ToString(method));
      }

      // Names on both sides: compare overload sets.
      foreach (var name in syncNames.Intersect(asyncNames))
      {
          compareCore(name, sync[name], async[name]);
      }

      // Names only the async side has.
      var excess = asyncNames.Except(syncNames).SelectMany(name => async[name]);
      foreach (var method in excess)
      {
          Console.WriteLine("EXCESS " + ToString(method));
      }
  }
  /// <summary>
  /// Splits the public static methods of <paramref name="operatorsType"/> into three groups,
  /// each grouped by method name:
  /// - factories: non-extension methods returning a sequence;
  /// - query operators: extension methods returning a sequence;
  /// - aggregates: extension methods returning anything else.
  /// </summary>
  /// <param name="interfaceTypes">Open generic interfaces (e.g. IEnumerable&lt;&gt;) whose constructed forms count as "returning a sequence".</param>
  /// <param name="operatorsType">The static class that declares the operators.</param>
  /// <param name="exclude">Method names to leave out entirely.</param>
  static Operators GetQueryOperators(Type[] interfaceTypes, Type operatorsType, string[] exclude)
  {
      // Materialize once so GetMethods and the exclusion filter run a single time, rather than
      // on every enumeration of a deferred query.
      var methods = operatorsType
          .GetMethods(BindingFlags.Public | BindingFlags.Static)
          .Where(m => !exclude.Contains(m.Name))
          .ToArray();

      // True when the method returns a constructed form of one of the sequence interfaces.
      bool ReturnsSequence(MethodInfo m) =>
          m.ReturnType.IsConstructedGenericType
          && interfaceTypes.Contains(m.ReturnType.GetGenericTypeDefinition());

      // Extension methods are either query operators or aggregates.
      var extensionMethods = methods.Where(m => m.IsDefined(typeof(ExtensionAttribute))).ToArray();

      // Static methods that aren't extension methods can be factories.
      var factories = methods.Except(extensionMethods).Where(ReturnsSequence).ToArray();

      // Extension methods that return a sequence interface are query operators.
      var queryOperators = extensionMethods.Where(ReturnsSequence).ToArray();

      // Extension methods that return another type are aggregates.
      var aggregates = extensionMethods.Except(queryOperators).ToArray();

      return new Operators
      {
          Factories = factories.ToLookup(m => m.Name),
          QueryOperators = queryOperators.ToLookup(m => m.Name),
          Aggregates = aggregates.ToLookup(m => m.Name),
      };
  }
  /// <summary>Lazily projects each method to its comparable <see cref="Signature"/>.</summary>
  static IEnumerable<Signature> GetSignatures(IEnumerable<MethodInfo> methods)
      => from method in methods
         select GetSignature(method);
  /// <summary>
  /// Lazily projects each (async) method to a signature whose async interface types have been
  /// rewritten to their sync counterparts.
  /// </summary>
  static IEnumerable<Signature> GetRewrittenSignatures(IEnumerable<MethodInfo> methods)
      => from signature in GetSignatures(methods)
         select RewriteSignature(signature);
  /// <summary>
  /// Builds the comparable signature of <paramref name="method"/>. A generic method definition is
  /// first closed over the shared <c>Wildcards</c> types, with type parameter i bound to
  /// Wildcards[i]. This lets positionally-equivalent type parameters of different methods compare equal.
  /// </summary>
  /// <exception cref="NotSupportedException">
  /// The method declares more type parameters than there are wildcard types.
  /// </exception>
  static Signature GetSignature(MethodInfo method)
  {
      if (method.IsGenericMethodDefinition)
      {
          var typeParameterCount = method.GetGenericArguments().Length;

          // Fail with a descriptive message instead of an IndexOutOfRangeException.
          if (typeParameterCount > Wildcards.Length)
          {
              throw new NotSupportedException(
                  "Cannot compare " + method + ": it declares " + typeParameterCount +
                  " type parameters but only " + Wildcards.Length + " wildcard types are available.");
          }

          var newArgs = new Type[typeParameterCount];
          Array.Copy(Wildcards, newArgs, typeParameterCount);
          method = method.MakeGenericMethod(newArgs);
      }

      return new Signature
      {
          Method = method,
          ReturnType = method.ReturnType,
          ParameterTypes = method.GetParameters().Select(p => p.ParameterType).ToArray()
      };
  }
  /// <summary>
  /// Returns a copy of <paramref name="signature"/> whose return and parameter types have been
  /// passed through the shared <c>subst</c> substitutor. The original Method is kept for reporting.
  /// </summary>
  static Signature RewriteSignature(Signature signature)
  {
      var returnType = subst.Visit(signature.ReturnType);
      var parameterTypes = subst.Visit(signature.ParameterTypes);

      return new Signature
      {
          Method = signature.Method,
          ReturnType = returnType,
          ParameterTypes = parameterTypes
      };
  }
  /// <summary>
  /// Converts a sync aggregate signature into its expected async shape.
  /// A void return becomes <see cref="Task"/>; any other return type T becomes <c>Task&lt;T&gt;</c>.
  /// Method and parameters are unchanged.
  /// </summary>
  static Signature GetAsyncAggregateSignature(Signature signature)
  {
      Type retType;
      if (signature.ReturnType != typeof(void))
      {
          retType = typeof(Task<>).MakeGenericType(signature.ReturnType);
      }
      else
      {
          retType = typeof(Task);
      }

      return new Signature
      {
          Method = signature.Method,
          ReturnType = retType,
          ParameterTypes = signature.ParameterTypes
      };
  }
  /// <summary>
  /// Formats a method for report output. Null yields "UNKNOWN". A closed generic method is
  /// printed as its generic definition, so the wildcard types it was closed over do not appear.
  /// </summary>
  static string ToString(MethodInfo method)
  {
      if (method is null)
      {
          return "UNKNOWN";
      }

      var display = method.IsGenericMethod && !method.IsGenericMethodDefinition
          ? method.GetGenericMethodDefinition()
          : method;

      return display.ToString();
  }
  /// <summary>
  /// Static operator methods of one operator type, grouped by name into three categories.
  /// Factories are non-extension methods returning a sequence; QueryOperators are extension
  /// methods returning a sequence; Aggregates are all other extension methods.
  /// </summary>
  class Operators
  {
      public ILookup<string, MethodInfo> Factories, QueryOperators, Aggregates;
  }
  /// <summary>
  /// A comparable method shape: return type plus ordered parameter types. Equality, hashing and
  /// the == / != operators ignore <see cref="Method"/>, which is kept only for reporting.
  /// </summary>
  class Signature : IEquatable<Signature>
  {
      public MethodInfo Method;
      public Type ReturnType;
      public Type[] ParameterTypes;

      // Two nulls are equal; otherwise defer to Equals(Signature), which rejects a null argument.
      public static bool operator ==(Signature s1, Signature s2)
          => s1 is null ? s2 is null : s1.Equals(s2);

      public static bool operator !=(Signature s1, Signature s2) => !(s1 == s2);

      public bool Equals(Signature s)
      {
          if (s is null)
          {
              return false;
          }

          return ReturnType.Equals(s.ReturnType) && ParameterTypes.SequenceEqual(s.ParameterTypes);
      }

      public override bool Equals(object obj) => obj is Signature other && Equals(other);

      // Polynomial hash over the parameter types, then the return type (same fields as Equals).
      public override int GetHashCode()
      {
          var hash = 0;
          foreach (var parameterType in ParameterTypes)
          {
              hash = hash * 17 + parameterType.GetHashCode();
          }

          return hash * 17 + ReturnType.GetHashCode();
      }

      // Renders as "(P1, P2) -> R" using C#-like type names.
      public override string ToString()
      {
          var joined = string.Join(", ", ParameterTypes.Select(t => t.ToCSharp()));
          return $"({joined}) -> {ReturnType.ToCSharp()}";
      }
  }
  /// <summary>
  /// Recursive rewriter for <see cref="Type"/> objects. <see cref="Visit(Type)"/> classifies the
  /// type and dispatches to a protected hook:
  /// - composite hooks (arrays, constructed generics, by-ref, pointer) visit their components and rebuild the type;
  /// - leaf hooks (generic definitions, simple types) return the type unchanged.
  /// Subclasses override the hooks, or Visit itself, to substitute types.
  /// </summary>
  class TypeVisitor
  {
      public virtual Type Visit(Type type)
      {
          if (type.IsArray)
          {
              // A single-dimensional zero-based vector (T[]) is exactly what MakeArrayType()
              // produces for its element type; any other array goes to the rank-n hook.
              var isVector = type.GetElementType().MakeArrayType() == type;
              return isVector ? VisitArray(type) : VisitMultidimensionalArray(type);
          }

          if (type.GetTypeInfo().IsGenericTypeDefinition)
          {
              return VisitGenericTypeDefinition(type);
          }

          if (type.IsConstructedGenericType)
          {
              return VisitGeneric(type);
          }

          if (type.IsByRef)
          {
              return VisitByRef(type);
          }

          return type.IsPointer ? VisitPointer(type) : VisitSimple(type);
      }

      protected virtual Type VisitArray(Type type)
          => Visit(type.GetElementType()).MakeArrayType();

      protected virtual Type VisitMultidimensionalArray(Type type)
          => Visit(type.GetElementType()).MakeArrayType(type.GetArrayRank());

      protected virtual Type VisitGenericTypeDefinition(Type type) => type;

      // The definition is visited too, so a subclass can swap e.g. one open interface for another.
      protected virtual Type VisitGeneric(Type type)
      {
          var definition = Visit(type.GetGenericTypeDefinition());
          var arguments = Visit(type.GenericTypeArguments);
          return definition.MakeGenericType(arguments);
      }

      protected virtual Type VisitByRef(Type type)
          => Visit(type.GetElementType()).MakeByRefType();

      protected virtual Type VisitPointer(Type type)
          => Visit(type.GetElementType()).MakePointerType();

      protected virtual Type VisitSimple(Type type) => type;

      /// <summary>Visits each element, returning a new array in the same order.</summary>
      public Type[] Visit(Type[] types) => Array.ConvertAll<Type, Type>(types, Visit);
  }
  /// <summary>
  /// A <see cref="TypeVisitor"/> that replaces any type found as a key in its map. The lookup
  /// happens before decomposition, so an open generic definition used as a key is also replaced
  /// inside constructed types that use it.
  /// </summary>
  class TypeSubstitutor : TypeVisitor
  {
      private readonly Dictionary<Type, Type> map;

      public TypeSubstitutor(Dictionary<Type, Type> map) => this.map = map;

      public override Type Visit(Type type)
          => map.TryGetValue(type, out var replacement) ? replacement : base.Visit(type);
  }
  /// <summary>
  /// Placeholder types used to close generic method definitions (type parameter i is bound to
  /// Wildcards[i]) so that type parameters of different methods compare by position.
  /// The array length caps how many method type parameters can be compared; it is extended to
  /// eight so operators declaring more than four type parameters can be handled.
  /// </summary>
  private static readonly Type[] Wildcards = new[]
  {
      typeof(T1), typeof(T2), typeof(T3), typeof(T4),
      typeof(T5), typeof(T6), typeof(T7), typeof(T8),
  };
  class T1 { }
  class T2 { }
  class T3 { }
  class T4 { }
  class T5 { }
  class T6 { }
  class T7 { }
  class T8 { }
  520. }
  /// <summary>Renders <see cref="Type"/> objects in a compact C#-like form for report output.</summary>
  static class TypeExtensions
  {
      /// <summary>
      /// Formats <paramref name="type"/> with C#-style array syntax (<c>T[]</c>, <c>T[,]</c>) and
      /// generic syntax (<c>List&lt;Int32&gt;</c>). Any other type is rendered by its simple
      /// <see cref="System.Reflection.MemberInfo.Name"/>, without a namespace.
      /// </summary>
      public static string ToCSharp(this Type type)
      {
          if (type.IsArray)
          {
              var element = type.GetElementType().ToCSharp();

              // A vector (T[]) round-trips through MakeArrayType(); anything else is a rank-n array.
              return type.GetElementType().MakeArrayType() == type
                  ? element + "[]"
                  : element + "[" + new string(',', type.GetArrayRank() - 1) + "]";
          }

          if (type.IsConstructedGenericType)
          {
              var defName = type.GetGenericTypeDefinition().Name;

              // Generic types normally carry an arity suffix ("List`1"). A type nested inside a
              // generic type, such as Dictionary<K,V>.KeyCollection, has none: IndexOf returns -1,
              // and Substring(0, -1) would throw. Keep the full name in that case.
              var tick = defName.IndexOf('`');
              if (tick >= 0)
              {
                  defName = defName.Substring(0, tick);
              }

              return defName + "<" + string.Join(", ", type.GetGenericArguments().Select(ToCSharp)) + ">";
          }

          return type.Name;
      }
  }
  548. }