Program.cs 21 KB


  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. using System.Runtime.CompilerServices;
  10. using System.Threading;
  11. using System.Threading.Tasks;
  12. namespace ApiCompare
  13. {
  /// <summary>
  /// Console tool that compares the operator surface of <see cref="Enumerable"/> (synchronous LINQ)
  /// with <see cref="AsyncEnumerable"/> (asynchronous LINQ). It prints one line per discrepancy:
  /// "MISSING ..." for sync overloads that have no async counterpart, and "EXCESS ..." for async
  /// overloads that cannot be explained by any sync overload.
  /// </summary>
  class Program
  {
      // Open interface definitions used to classify methods and to map async signatures onto sync ones.
      private static readonly Type asyncInterfaceType = typeof(IAsyncEnumerable<>);
      private static readonly Type syncInterfaceType = typeof(IEnumerable<>);
      private static readonly Type asyncOrderedInterfaceType = typeof(IOrderedAsyncEnumerable<>);
      private static readonly Type syncOrderedInterfaceType = typeof(IOrderedEnumerable<>);

      // Method names excluded from the comparison on both the sync and the async side.
      private static readonly string[] exceptions = new[]
      {
          "SkipLast", // In .NET Core 2.0
          "TakeLast", // In .NET Core 2.0
          "Cast", // Non-generic methods
          "OfType", // Non-generic methods
          "AsEnumerable", // Trivially renamed
          "AsAsyncEnumerable", // Trivially renamed
          "ForEachAsync", // "foreach await" language substitute for the time being
      };

      // Rewrites async interface types to their sync equivalents. Rewritten async signatures can
      // then be compared structurally with sync signatures.
      private static readonly TypeSubstitutor subst = new TypeSubstitutor(new Dictionary<Type, Type>
      {
          { asyncInterfaceType, syncInterfaceType },
          { asyncOrderedInterfaceType, syncOrderedInterfaceType },
          { typeof(IAsyncGrouping<,>), typeof(IGrouping<,>) },
      });

      static void Main()
      {
          var asyncOperatorsType = typeof(AsyncEnumerable);
          var syncOperatorsType = typeof(Enumerable);
          Compare(syncOperatorsType, asyncOperatorsType);
      }

      /// <summary>
      /// Splits both operator classes into factories, query operators and aggregates, then compares
      /// each category separately.
      /// </summary>
      static void Compare(Type syncOperatorsType, Type asyncOperatorsType)
      {
          var syncOperators = GetQueryOperators(new[] { syncInterfaceType, syncOrderedInterfaceType }, syncOperatorsType, exceptions);
          var asyncOperators = GetQueryOperators(new[] { asyncInterfaceType, asyncOrderedInterfaceType }, asyncOperatorsType, exceptions);
          CompareFactories(syncOperators.Factories, asyncOperators.Factories);
          CompareQueryOperators(syncOperators.QueryOperators, asyncOperators.QueryOperators);
          CompareAggregates(syncOperators.Aggregates, asyncOperators.Aggregates);
      }

      static void CompareFactories(ILookup<string, MethodInfo> syncFactories, ILookup<string, MethodInfo> asyncFactories)
      {
          CompareSets(syncFactories, asyncFactories, CompareFactoryOverloads);
      }

      /// <summary>
      /// Factories (e.g. Range, Repeat, Empty) must match one-to-one after async interface types
      /// are rewritten to their sync equivalents.
      /// </summary>
      static void CompareFactoryOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
      {
          var syncSignatures = GetSignatures(syncMethods).ToArray();
          var asyncSignatures = GetRewrittenSignatures(asyncMethods).ToArray();

          // Ensure that async is a superset of sync.
          ReportMethods("MISSING ", syncSignatures.Except(asyncSignatures));

          // Check for excess overloads.
          ReportMethods("EXCESS ", asyncSignatures.Except(syncSignatures));
      }

      static void CompareQueryOperators(ILookup<string, MethodInfo> syncOperators, ILookup<string, MethodInfo> asyncOperators)
      {
          CompareSets(syncOperators, asyncOperators, CompareQueryOperatorsOverloads);
      }

      /// <summary>
      /// Query operators must have every sync overload. They must also have a Task-based variant
      /// of every sync overload that takes a delegate. Anything else on the async side is excess.
      /// </summary>
      static void CompareQueryOperatorsOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
      {
          var syncSignatures = GetSignatures(syncMethods).ToArray();
          var asyncSignatures = GetRewrittenSignatures(asyncMethods).ToArray();

          // Ensure that async is a superset of sync.
          ReportMethods("MISSING ", syncSignatures.Except(asyncSignatures));

          // Task-based overloads: synthesized signatures have no MethodInfo, so print the signature.
          var taskBasedSignatures = GetTaskBasedSignatures(syncSignatures);
          ReportSignatures("MISSING ", name, taskBasedSignatures.Except(asyncSignatures));

          // Excess overloads that are neither carbon copies of sync nor Task-based variants of sync.
          ReportMethods("EXCESS ", asyncSignatures.Except(syncSignatures.Union(taskBasedSignatures)));
      }

      static void CompareAggregates(ILookup<string, MethodInfo> syncAggregates, ILookup<string, MethodInfo> asyncAggregates)
      {
          CompareSets(syncAggregates, asyncAggregates, CompareAggregateOverloads);
      }

      /// <summary>
      /// Aggregates: each sync overload's return type T (or void) is expected as Task&lt;T&gt; (or
      /// Task). The async side must have those signatures, their Task-based delegate variants, and
      /// versions of both with a trailing <see cref="CancellationToken"/>.
      /// </summary>
      static void CompareAggregateOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
      {
          var syncSignatures = GetSignatures(syncMethods).Select(GetAsyncAggregateSignature).ToArray();
          var asyncSignatures = GetRewrittenSignatures(asyncMethods).ToArray();

          // Ensure that async is a superset of sync.
          ReportMethods("MISSING ", syncSignatures.Except(asyncSignatures));

          // Task-based overloads.
          var taskBasedSignatures = GetTaskBasedSignatures(syncSignatures);
          ReportSignatures("MISSING ", name, taskBasedSignatures.Except(asyncSignatures));

          // Overloads with a trailing CancellationToken: plain sync-shaped ones first, then Task-based ones.
          var withCancellationToken = syncSignatures
              .Concat(taskBasedSignatures)
              .Select(AppendCancellationToken)
              .ToList();
          ReportSignatures("MISSING ", name, withCancellationToken.Except(asyncSignatures));

          // Excess overloads that are not carbon copies of sync, Task-based variants of sync, or
          // CancellationToken-taking versions of either.
          ReportMethods("EXCESS ", asyncSignatures.Except(syncSignatures.Union(taskBasedSignatures).Union(withCancellationToken)));
      }

      /// <summary>
      /// Returns the Task-based variant (see <see cref="GetAsyncVariant(Signature)"/>) of each
      /// signature that has at least one Func/Action parameter. Input order is preserved.
      /// </summary>
      private static List<Signature> GetTaskBasedSignatures(IEnumerable<Signature> signatures)
      {
          return signatures
              .Where(signature => signature.ParameterTypes.Any(IsFuncOrActionType))
              .Select(signature => GetAsyncVariant(signature))
              .ToList();
      }

      /// <summary>
      /// Prints <paramref name="prefix"/> followed by the originating method of each signature.
      /// </summary>
      private static void ReportMethods(string prefix, IEnumerable<Signature> signatures)
      {
          foreach (var signature in signatures)
          {
              Console.WriteLine(prefix + ToString(signature.Method));
          }
      }

      /// <summary>
      /// Prints "prefix name :: signature" for each signature. This is used for synthesized
      /// signatures, which carry no originating <see cref="MethodInfo"/>.
      /// </summary>
      private static void ReportSignatures(string prefix, string name, IEnumerable<Signature> signatures)
      {
          foreach (var signature in signatures)
          {
              Console.WriteLine(prefix + name + " :: " + signature);
          }
      }

      /// <summary>
      /// True for the non-generic <see cref="Action"/> and for any constructed generic type whose
      /// definition is named Func`N or Action`N. The names are matched ordinally.
      /// </summary>
      private static bool IsFuncOrActionType(Type type)
      {
          if (type == typeof(Action))
          {
              return true;
          }
          if (!type.IsConstructedGenericType)
          {
              return false;
          }
          var defName = type.GetGenericTypeDefinition().Name;
          return defName.StartsWith("Func`", StringComparison.Ordinal) || defName.StartsWith("Action`", StringComparison.Ordinal);
      }

      /// <summary>
      /// Replaces every delegate parameter with its Task-returning variant and keeps the return type.
      /// The result has no <see cref="Signature.Method"/>.
      /// </summary>
      private static Signature GetAsyncVariant(Signature signature)
      {
          return new Signature
          {
              ParameterTypes = signature.ParameterTypes.Select(GetAsyncVariant).ToArray(),
              ReturnType = signature.ReturnType
          };
      }

      /// <summary>
      /// Appends a <see cref="CancellationToken"/> parameter. The result has no <see cref="Signature.Method"/>.
      /// </summary>
      private static Signature AppendCancellationToken(Signature signature)
      {
          return new Signature
          {
              ParameterTypes = signature.ParameterTypes.Concat(new[] { typeof(CancellationToken) }).ToArray(),
              ReturnType = signature.ReturnType
          };
      }

      /// <summary>
      /// Maps a delegate type to its Task-returning counterpart; any other type is returned unchanged.
      /// <list type="bullet">
      /// <item>Action becomes Func&lt;Task&gt;.</item>
      /// <item>Func&lt;..., R&gt; becomes Func&lt;..., Task&lt;R&gt;&gt;.</item>
      /// <item>Action&lt;...&gt; becomes Func&lt;..., Task&gt;.</item>
      /// </list>
      /// </summary>
      private static Type GetAsyncVariant(Type type)
      {
          if (!IsFuncOrActionType(type))
          {
              return type;
          }
          if (type == typeof(Action))
          {
              return typeof(Func<Task>);
          }
          var args = type.GetGenericArguments();
          var defName = type.GetGenericTypeDefinition().Name;
          if (defName.StartsWith("Func`", StringComparison.Ordinal))
          {
              var ret = typeof(Task<>).MakeGenericType(args.Last());
              return Expression.GetFuncType(args.SkipLast(1).Append(ret).ToArray());
          }
          return Expression.GetFuncType(args.Append(typeof(Task)).ToArray());
      }

      /// <summary>
      /// Compares two method groups keyed by name.
      /// <list type="bullet">
      /// <item>Names only on the sync side are reported MISSING.</item>
      /// <item>Names only on the async side are reported EXCESS.</item>
      /// <item>Names on both sides are handed to <paramref name="compareCore"/> for overload-level comparison.</item>
      /// </list>
      /// </summary>
      static void CompareSets(ILookup<string, MethodInfo> sync, ILookup<string, MethodInfo> async, Action<string, IEnumerable<MethodInfo>, IEnumerable<MethodInfo>> compareCore)
      {
          var syncNames = sync.Select(g => g.Key).ToArray();
          var asyncNames = async.Select(g => g.Key).ToArray();

          // Analyze that async is a superset of sync.
          foreach (var n in syncNames.Except(asyncNames))
          {
              foreach (var o in sync[n])
              {
                  Console.WriteLine("MISSING " + ToString(o));
              }
          }

          // Need to find the same overloads.
          foreach (var n in syncNames.Intersect(asyncNames))
          {
              compareCore(n, sync[n], async[n]);
          }

          // Report excessive API surface.
          foreach (var n in asyncNames.Except(syncNames))
          {
              foreach (var o in async[n])
              {
                  Console.WriteLine("EXCESS " + ToString(o));
              }
          }
      }

      /// <summary>
      /// Classifies the public static methods of <paramref name="operatorsType"/>, skipping names in
      /// <paramref name="exclude"/>, into three groups:
      /// <list type="bullet">
      /// <item>Factories: non-extension methods that return one of <paramref name="interfaceTypes"/>.</item>
      /// <item>Query operators: extension methods that return one of <paramref name="interfaceTypes"/>.</item>
      /// <item>Aggregates: all other extension methods.</item>
      /// </list>
      /// Non-extension methods that return anything else are ignored.
      /// </summary>
      static Operators GetQueryOperators(Type[] interfaceTypes, Type operatorsType, string[] exclude)
      {
          // True when the method returns a constructed form of one of the given open interfaces.
          bool ReturnsInterface(MethodInfo m) =>
              m.ReturnType.IsConstructedGenericType && interfaceTypes.Contains(m.ReturnType.GetGenericTypeDefinition());

          var methods = operatorsType.GetMethods(BindingFlags.Public | BindingFlags.Static)
              .Where(m => !exclude.Contains(m.Name))
              .ToArray();
          var extensionMethods = methods.Where(m => m.IsDefined(typeof(ExtensionAttribute))).ToArray();
          var factories = methods.Except(extensionMethods).Where(ReturnsInterface).ToArray();
          var queryOperators = extensionMethods.Where(ReturnsInterface).ToArray();
          var aggregates = extensionMethods.Except(queryOperators).ToArray();

          return new Operators
          {
              Factories = factories.ToLookup(m => m.Name),
              QueryOperators = queryOperators.ToLookup(m => m.Name),
              Aggregates = aggregates.ToLookup(m => m.Name),
          };
      }

      static IEnumerable<Signature> GetSignatures(IEnumerable<MethodInfo> methods)
      {
          return methods.Select(m => GetSignature(m));
      }

      /// <summary>Signatures with async interface types rewritten to their sync equivalents.</summary>
      static IEnumerable<Signature> GetRewrittenSignatures(IEnumerable<MethodInfo> methods)
      {
          return GetSignatures(methods).Select(s => RewriteSignature(s));
      }

      /// <summary>
      /// Builds the structural signature of <paramref name="method"/>. A generic method definition is
      /// first closed over positional placeholder types (<see cref="Wildcards"/>). This lets methods
      /// from different classes compare equal when their type parameters line up by position.
      /// </summary>
      /// <exception cref="NotSupportedException">
      /// The method has more type parameters than there are placeholder types.
      /// </exception>
      static Signature GetSignature(MethodInfo method)
      {
          if (method.IsGenericMethodDefinition)
          {
              var typeParameterCount = method.GetGenericArguments().Length;
              if (typeParameterCount > Wildcards.Length)
              {
                  throw new NotSupportedException(
                      "Method '" + method + "' has " + typeParameterCount + " type parameters; at most " + Wildcards.Length + " are supported.");
              }
              method = method.MakeGenericMethod(Wildcards.Take(typeParameterCount).ToArray());
          }
          return new Signature
          {
              Method = method,
              ReturnType = method.ReturnType,
              ParameterTypes = method.GetParameters().Select(p => p.ParameterType).ToArray()
          };
      }

      static Signature RewriteSignature(Signature signature)
      {
          return new Signature
          {
              Method = signature.Method,
              ReturnType = subst.Visit(signature.ReturnType),
              ParameterTypes = subst.Visit(signature.ParameterTypes)
          };
      }

      /// <summary>Wraps the return type: void becomes Task, and T becomes Task&lt;T&gt;.</summary>
      static Signature GetAsyncAggregateSignature(Signature signature)
      {
          var retType = signature.ReturnType == typeof(void) ? typeof(Task) : typeof(Task<>).MakeGenericType(signature.ReturnType);
          return new Signature
          {
              Method = signature.Method,
              ReturnType = retType,
              ParameterTypes = signature.ParameterTypes
          };
      }

      /// <summary>
      /// Displays a method by its generic definition (undoing the wildcard substitution). Returns
      /// "UNKNOWN" when <paramref name="method"/> is null.
      /// </summary>
      static string ToString(MethodInfo method)
      {
          if (method == null)
          {
              return "UNKNOWN";
          }
          if (method.IsGenericMethod && !method.IsGenericMethodDefinition)
          {
              method = method.GetGenericMethodDefinition();
          }
          return method.ToString();
      }

      /// <summary>Methods of an operator class grouped by name and category.</summary>
      class Operators
      {
          public ILookup<string, MethodInfo> Factories;
          public ILookup<string, MethodInfo> QueryOperators;
          public ILookup<string, MethodInfo> Aggregates;
      }

      /// <summary>
      /// Structural method signature: return type plus ordered parameter types, with value equality.
      /// <see cref="Method"/> is carried along only for reporting and is ignored by equality.
      /// </summary>
      class Signature : IEquatable<Signature>
      {
          public MethodInfo Method;
          public Type ReturnType;
          public Type[] ParameterTypes;

          public static bool operator ==(Signature s1, Signature s2)
          {
              if ((object)s1 == null && (object)s2 == null)
              {
                  return true;
              }
              if ((object)s1 == null || (object)s2 == null)
              {
                  return false;
              }
              return s1.Equals(s2);
          }

          public static bool operator !=(Signature s1, Signature s2)
          {
              return !(s1 == s2);
          }

          public bool Equals(Signature s)
          {
              return (object)s != null && ReturnType.Equals(s.ReturnType) && ParameterTypes.SequenceEqual(s.ParameterTypes);
          }

          public override bool Equals(object obj)
          {
              if (obj is Signature s)
              {
                  return Equals(s);
              }
              return false;
          }

          public override int GetHashCode()
          {
              // The multiply-add is expected to wrap around; unchecked keeps it from throwing
              // OverflowException when the project is compiled with overflow checking enabled.
              return ParameterTypes.Concat(new[] { ReturnType }).Aggregate(0, (a, t) => unchecked(a * 17 + t.GetHashCode()));
          }

          public override string ToString()
          {
              return "(" + string.Join(", ", ParameterTypes.Select(t => t.ToCSharp())) + ") -> " + ReturnType.ToCSharp();
          }
      }

      /// <summary>
      /// Recursively rebuilds a type. Derived classes override individual Visit* methods to
      /// transform parts of the type. By default the type is rebuilt unchanged.
      /// </summary>
      class TypeVisitor
      {
          public virtual Type Visit(Type type)
          {
              if (type.IsArray)
              {
                  // Single-dimensional zero-based arrays round-trip through MakeArrayType(); others do not.
                  if (type.GetElementType().MakeArrayType() == type)
                  {
                      return VisitArray(type);
                  }
                  else
                  {
                      return VisitMultidimensionalArray(type);
                  }
              }
              else if (type.GetTypeInfo().IsGenericTypeDefinition)
              {
                  return VisitGenericTypeDefinition(type);
              }
              else if (type.IsConstructedGenericType)
              {
                  return VisitGeneric(type);
              }
              else if (type.IsByRef)
              {
                  return VisitByRef(type);
              }
              else if (type.IsPointer)
              {
                  return VisitPointer(type);
              }
              else
              {
                  return VisitSimple(type);
              }
          }

          protected virtual Type VisitArray(Type type)
          {
              return Visit(type.GetElementType()).MakeArrayType();
          }

          protected virtual Type VisitMultidimensionalArray(Type type)
          {
              return Visit(type.GetElementType()).MakeArrayType(type.GetArrayRank());
          }

          protected virtual Type VisitGenericTypeDefinition(Type type)
          {
              return type;
          }

          protected virtual Type VisitGeneric(Type type)
          {
              return Visit(type.GetGenericTypeDefinition()).MakeGenericType(Visit(type.GenericTypeArguments));
          }

          protected virtual Type VisitByRef(Type type)
          {
              return Visit(type.GetElementType()).MakeByRefType();
          }

          protected virtual Type VisitPointer(Type type)
          {
              return Visit(type.GetElementType()).MakePointerType();
          }

          protected virtual Type VisitSimple(Type type)
          {
              return type;
          }

          public Type[] Visit(Type[] types)
          {
              return types.Select(Visit).ToArray();
          }
      }

      /// <summary>
      /// Replaces any type (at any nesting depth) that is a key of the map with the mapped value.
      /// </summary>
      class TypeSubstitutor : TypeVisitor
      {
          private readonly Dictionary<Type, Type> map;

          public TypeSubstitutor(Dictionary<Type, Type> map)
          {
              this.map = map;
          }

          public override Type Visit(Type type)
          {
              if (map.TryGetValue(type, out var replacement))
              {
                  return replacement;
              }
              return base.Visit(type);
          }
      }

      // Positional placeholder types substituted for a generic method's type parameters, so that
      // e.g. Where<TSource> on two different classes yields equal signatures.
      private static readonly Type[] Wildcards = new[]
      {
          typeof(T1), typeof(T2), typeof(T3), typeof(T4),
          typeof(T5), typeof(T6), typeof(T7), typeof(T8),
      };

      class T1 { }
      class T2 { }
      class T3 { }
      class T4 { }
      class T5 { }
      class T6 { }
      class T7 { }
      class T8 { }
  }
  /// <summary>Helpers that render <see cref="Type"/> instances in a compact C#-like notation for diagnostics.</summary>
  static class TypeExtensions
  {
      /// <summary>
      /// Formats <paramref name="type"/> as follows:
      /// <list type="bullet">
      /// <item>Arrays become "Elem[]" or "Elem[,]".</item>
      /// <item>Constructed generics become "Name&lt;Arg1, Arg2&gt;", with the arity suffix removed.</item>
      /// <item>All other types become their simple <see cref="Type.Name"/>.</item>
      /// </list>
      /// </summary>
      public static string ToCSharp(this Type type)
      {
          if (type.IsArray)
          {
              var elementType = type.GetElementType();
              var elementName = elementType.ToCSharp();
              // Single-dimensional zero-based arrays round-trip through MakeArrayType().
              if (elementType.MakeArrayType() == type)
              {
                  return elementName + "[]";
              }
              return elementName + "[" + new string(',', type.GetArrayRank() - 1) + "]";
          }

          if (type.IsConstructedGenericType)
          {
              var definitionName = type.GetGenericTypeDefinition().Name;
              var tick = definitionName.IndexOf('`');
              // A type nested in a generic type (e.g. List<T>.Enumerator) is generic, but its own
              // name has no "`N" arity suffix, so tick is -1 and must not be passed to Substring.
              var baseName = tick < 0 ? definitionName : definitionName.Substring(0, tick);
              return baseName + "<" + string.Join(", ", type.GetGenericArguments().Select(ToCSharp)) + ">";
          }

          return type.Name;
      }
  }
  549. }