1
0

Program.cs 21 KB


  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. using System.Runtime.CompilerServices;
  10. using System.Threading;
  11. using System.Threading.Tasks;
  12. namespace ApiCompare
  13. {
  14. class Program
  15. {
  // Open generic interface definitions used to classify operator return types.
  private static readonly Type asyncInterfaceType = typeof(IAsyncEnumerable<>);
  private static readonly Type syncInterfaceType = typeof(IEnumerable<>);
  private static readonly Type asyncOrderedInterfaceType = typeof(IOrderedAsyncEnumerable<>);
  private static readonly Type syncOrderedInterfaceType = typeof(IOrderedEnumerable<>);

  // Method names left out of the comparison on both the sync and the async side.
  private static readonly string[] exceptions =
  {
      "SkipLast",          // In .NET Core 2.0
      "TakeLast",          // In .NET Core 2.0
      "ToHashSet",         // In .NET Core 2.0
      "Cast",              // Non-generic methods
      "OfType",            // Non-generic methods
      "AsEnumerable",      // Trivially renamed
      "AsAsyncEnumerable", // Trivially renamed
      "ForEachAsync",      // "foreach await" language substitute for the time being
  };

  // Maps async interface definitions onto their sync counterparts so that rewritten async
  // signatures can be compared structurally with sync ones. Declared after the fields it
  // reads, because static field initializers run in textual order.
  private static readonly TypeSubstitutor subst = new TypeSubstitutor(new Dictionary<Type, Type>
  {
      [asyncInterfaceType] = syncInterfaceType,
      [asyncOrderedInterfaceType] = syncOrderedInterfaceType,
      [typeof(IAsyncGrouping<,>)] = typeof(IGrouping<,>),
  });
  /// <summary>
  /// Entry point: compares the synchronous LINQ operators (<c>Enumerable</c>) with their
  /// asynchronous counterparts (<c>AsyncEnumerable</c>) and prints the differences.
  /// </summary>
  static void Main()
  {
      Compare(syncOperatorsType: typeof(Enumerable), asyncOperatorsType: typeof(AsyncEnumerable));
  }
  /// <summary>
  /// Splits the public static methods of both operator classes into factories, query operators
  /// and aggregates, then reports the differences category by category.
  /// </summary>
  /// <param name="syncOperatorsType">Class hosting the synchronous operators.</param>
  /// <param name="asyncOperatorsType">Class hosting the asynchronous operators.</param>
  static void Compare(Type syncOperatorsType, Type asyncOperatorsType)
  {
      Type[] syncShapes = { syncInterfaceType, syncOrderedInterfaceType };
      Type[] asyncShapes = { asyncInterfaceType, asyncOrderedInterfaceType };

      Operators syncOps = GetQueryOperators(syncShapes, syncOperatorsType, exceptions);
      Operators asyncOps = GetQueryOperators(asyncShapes, asyncOperatorsType, exceptions);

      // Report order: factories, then query operators, then aggregates.
      CompareFactories(syncOps.Factories, asyncOps.Factories);
      CompareQueryOperators(syncOps.QueryOperators, asyncOps.QueryOperators);
      CompareAggregates(syncOps.Aggregates, asyncOps.Aggregates);
  }
  /// <summary>
  /// Compares factory methods (non-extension statics returning a query interface), first by
  /// name and then overload by overload.
  /// </summary>
  static void CompareFactories(ILookup<string, MethodInfo> syncFactories, ILookup<string, MethodInfo> asyncFactories) =>
      CompareSets(syncFactories, asyncFactories, CompareFactoryOverloads);
  /// <summary>
  /// Compares the overloads of one factory method. Sync overloads that have no structurally
  /// equal async overload (after async→sync type rewriting) are printed as MISSING. Async
  /// overloads with no sync twin are printed as EXCESS.
  /// </summary>
  /// <param name="name">Factory method name (unused here; part of the CompareSets callback shape).</param>
  static void CompareFactoryOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      Signature[] syncSignatures = GetSignatures(syncMethods).ToArray();
      Signature[] asyncSignatures = GetRewrittenSignatures(asyncMethods).ToArray();

      // The async surface should be a superset of the sync one.
      foreach (Signature missing in syncSignatures.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + ToString(missing.Method));
      }

      // Anything left on the async side is additional surface.
      foreach (Signature extra in asyncSignatures.Except(syncSignatures))
      {
          Console.WriteLine("EXCESS " + ToString(extra.Method));
      }
  }
  /// <summary>
  /// Compares query operators (extension methods returning a query interface), first by name
  /// and then overload by overload.
  /// </summary>
  static void CompareQueryOperators(ILookup<string, MethodInfo> syncOperators, ILookup<string, MethodInfo> asyncOperators) =>
      CompareSets(syncOperators, asyncOperators, CompareQueryOperatorsOverloads);
  /// <summary>
  /// Compares the overloads of one query operator and prints, in this order:
  /// (1) sync overloads with no rewritten async copy (MISSING, shown via the sync method);
  /// (2) Task-based delegate variants of sync overloads that the async side lacks
  ///     (MISSING, shown as "name :: signature" because no method backs them);
  /// (3) async overloads that are neither copies nor Task-based variants (EXCESS).
  /// </summary>
  static void CompareQueryOperatorsOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      Signature[] syncSignatures = GetSignatures(syncMethods).ToArray();
      Signature[] asyncSignatures = GetRewrittenSignatures(asyncMethods).ToArray();

      // (1) Every sync overload should exist verbatim on the async side.
      foreach (Signature missing in syncSignatures.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + ToString(missing.Method));
      }

      // (2) Overloads taking a Func/Action should also exist with Task-returning delegates.
      List<Signature> taskBased = syncSignatures
          .Where(s => s.ParameterTypes.Any(IsFuncOrActionType))
          .Select(s => GetAsyncVariant(s))
          .ToList();

      foreach (Signature missing in taskBased.Except(asyncSignatures))
      {
          Console.WriteLine("MISSING " + name + " :: " + missing);
      }

      // (3) Whatever remains on the async side is extra surface.
      foreach (Signature extra in asyncSignatures.Except(syncSignatures.Union(taskBased)))
      {
          Console.WriteLine("EXCESS " + ToString(extra.Method));
      }
  }
  /// <summary>
  /// Compares aggregates (extension methods that do not return a query interface), first by
  /// name and then overload by overload.
  /// </summary>
  static void CompareAggregates(ILookup<string, MethodInfo> syncAggregates, ILookup<string, MethodInfo> asyncAggregates) =>
      CompareSets(syncAggregates, asyncAggregates, CompareAggregateOverloads);
  /// <summary>
  /// Compares the overloads of one aggregate. Sync signatures are first lifted to their
  /// Task-returning shape. The method then prints, in this order:
  /// (1) lifted sync overloads missing from the async side (MISSING, via the sync method);
  /// (2) missing Task-based delegate variants (MISSING, as "name :: signature");
  /// (3) missing variants of (1) and (2) with a trailing CancellationToken (MISSING, as "name :: signature");
  /// (4) async overloads matching none of (1)–(3) (EXCESS).
  /// </summary>
  static void CompareAggregateOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      Signature[] expected = GetSignatures(syncMethods).Select(GetAsyncAggregateSignature).ToArray();
      Signature[] actual = GetRewrittenSignatures(asyncMethods).ToArray();

      // (1) Plain async counterparts of each sync overload.
      foreach (Signature missing in expected.Except(actual))
      {
          Console.WriteLine("MISSING " + ToString(missing.Method));
      }

      // (2) Task-returning delegate variants of overloads that take a Func/Action.
      List<Signature> taskBased = expected
          .Where(s => s.ParameterTypes.Any(IsFuncOrActionType))
          .Select(s => GetAsyncVariant(s))
          .ToList();

      foreach (Signature missing in taskBased.Except(actual))
      {
          Console.WriteLine("MISSING " + name + " :: " + missing);
      }

      // (3) Every shape above, extended with a trailing CancellationToken.
      List<Signature> cancellable = expected
          .Concat(taskBased)
          .Select(AppendCancellationToken)
          .ToList();

      foreach (Signature missing in cancellable.Except(actual))
      {
          Console.WriteLine("MISSING " + name + " :: " + missing);
      }

      // (4) Async overloads that are neither plain copies, Task-based variants, nor
      //     CancellationToken-accepting variants of the sync surface.
      foreach (Signature extra in actual.Except(expected.Union(taskBased).Union(cancellable)))
      {
          Console.WriteLine("EXCESS " + ToString(extra.Method));
      }
  }
  /// <summary>
  /// Returns true when <paramref name="type"/> is the non-generic <see cref="Action"/> or a
  /// closed instantiation of one of the generic <c>System.Func</c>/<c>System.Action</c> delegates.
  /// Open generic definitions (e.g. <c>Func&lt;&gt;</c>) and every other type return false.
  /// </summary>
  /// <remarks>
  /// Definition names are metadata identifiers, so they are compared ordinally rather than with
  /// the current culture. The namespace is checked too, so that an unrelated type that happens
  /// to be called Func`N or Action`N is not mistaken for a BCL delegate.
  /// </remarks>
  private static bool IsFuncOrActionType(Type type)
  {
      if (type == typeof(Action))
      {
          return true;
      }

      if (!type.IsConstructedGenericType)
      {
          return false;
      }

      var definition = type.GetGenericTypeDefinition();
      if (!string.Equals(definition.Namespace, "System", StringComparison.Ordinal))
      {
          return false;
      }

      var defName = definition.Name;
      return defName.StartsWith("Func`", StringComparison.Ordinal)
          || defName.StartsWith("Action`", StringComparison.Ordinal);
  }
  /// <summary>
  /// Builds the Task-based counterpart of <paramref name="signature"/>. Each parameter type is
  /// passed through the <see cref="Type"/> overload of GetAsyncVariant; the return type is kept.
  /// <c>Method</c> is left unset because no real method backs the synthesized signature.
  /// </summary>
  private static Signature GetAsyncVariant(Signature signature)
  {
      Type[] rewritten = Array.ConvertAll(signature.ParameterTypes, parameterType => GetAsyncVariant(parameterType));
      return new Signature { ParameterTypes = rewritten, ReturnType = signature.ReturnType };
  }
  /// <summary>
  /// Returns a new signature equal to <paramref name="signature"/> plus a trailing
  /// <see cref="CancellationToken"/> parameter. The return type is unchanged and <c>Method</c>
  /// is left unset.
  /// </summary>
  private static Signature AppendCancellationToken(Signature signature)
  {
      var parameters = new List<Type>(signature.ParameterTypes) { typeof(CancellationToken) };
      return new Signature { ParameterTypes = parameters.ToArray(), ReturnType = signature.ReturnType };
  }
  /// <summary>
  /// Maps a delegate type to its Task-returning counterpart:
  /// <c>Action</c> becomes <c>Func&lt;Task&gt;</c>;
  /// <c>Action&lt;T1..Tn&gt;</c> becomes <c>Func&lt;T1..Tn, Task&gt;</c>;
  /// <c>Func&lt;T1..Tn, TResult&gt;</c> becomes <c>Func&lt;T1..Tn, Task&lt;TResult&gt;&gt;</c>.
  /// Types that IsFuncOrActionType does not recognize are returned unchanged.
  /// </summary>
  /// <remarks>
  /// The new type-argument list is assembled with plain array copies instead of
  /// <c>Enumerable.SkipLast</c>/<c>Append</c>. SkipLast is not available on every target framework
  /// (it is absent from .NET Framework). It can also become an ambiguous call when another
  /// referenced library defines a SkipLast extension in the System.Linq namespace.
  /// The definition name is compared ordinally because it is a metadata identifier.
  /// </remarks>
  private static Type GetAsyncVariant(Type type)
  {
      if (!IsFuncOrActionType(type))
      {
          return type;
      }

      if (type == typeof(Action))
      {
          return typeof(Func<Task>);
      }

      var args = type.GetGenericArguments();
      var isFunc = type.GetGenericTypeDefinition().Name.StartsWith("Func`", StringComparison.Ordinal);

      Type[] funcArgs;
      if (isFunc)
      {
          // Func<..., TResult>: keep the inputs and wrap the result in Task<TResult>.
          funcArgs = new Type[args.Length];
          Array.Copy(args, funcArgs, args.Length - 1);
          funcArgs[args.Length - 1] = typeof(Task<>).MakeGenericType(args[args.Length - 1]);
      }
      else
      {
          // Action<...>: keep the inputs and add Task as the result type.
          funcArgs = new Type[args.Length + 1];
          Array.Copy(args, funcArgs, args.Length);
          funcArgs[args.Length] = typeof(Task);
      }

      return Expression.GetFuncType(funcArgs);
  }
  /// <summary>
  /// Compares two name-keyed method sets and prints, in this order:
  /// every overload of a name found only in <paramref name="sync"/> as MISSING;
  /// then, for each name present on both sides (in sync order), a call to <paramref name="compareCore"/>;
  /// then every overload of a name found only in <paramref name="async"/> as EXCESS.
  /// </summary>
  static void CompareSets(ILookup<string, MethodInfo> sync, ILookup<string, MethodInfo> async, Action<string, IEnumerable<MethodInfo>, IEnumerable<MethodInfo>> compareCore)
  {
      string[] syncNames = sync.Select(group => group.Key).ToArray();
      string[] asyncNames = async.Select(group => group.Key).ToArray();

      var syncNameSet = new HashSet<string>(syncNames);
      var asyncNameSet = new HashSet<string>(asyncNames);

      // Names the async side lacks entirely.
      foreach (string missingName in syncNames.Where(candidate => !asyncNameSet.Contains(candidate)))
      {
          foreach (MethodInfo overload in sync[missingName])
          {
              Console.WriteLine("MISSING " + ToString(overload));
          }
      }

      // Names on both sides: compare overload by overload.
      foreach (string sharedName in syncNames.Where(candidate => asyncNameSet.Contains(candidate)))
      {
          compareCore(sharedName, sync[sharedName], async[sharedName]);
      }

      // Names that only the async side has.
      foreach (string extraName in asyncNames.Where(candidate => !syncNameSet.Contains(candidate)))
      {
          foreach (MethodInfo overload in async[extraName])
          {
              Console.WriteLine("EXCESS " + ToString(overload));
          }
      }
  }
  /// <summary>
  /// Classifies the public static methods of <paramref name="operatorsType"/>, skipping names
  /// listed in <paramref name="exclude"/>:
  /// non-extension methods returning one of <paramref name="interfaceTypes"/> are factories;
  /// extension methods returning one of them are query operators;
  /// all other extension methods are aggregates.
  /// Non-extension methods with any other return type are ignored. Each group is keyed by method name.
  /// </summary>
  static Operators GetQueryOperators(Type[] interfaceTypes, Type operatorsType, string[] exclude)
  {
      // True when the return type is a closed instantiation of one of interfaceTypes.
      bool ReturnsQueryInterface(MethodInfo method) =>
          method.ReturnType.IsConstructedGenericType
          && interfaceTypes.Contains(method.ReturnType.GetGenericTypeDefinition());

      MethodInfo[] candidates = operatorsType
          .GetMethods(BindingFlags.Public | BindingFlags.Static)
          .Where(method => !exclude.Contains(method.Name))
          .ToArray();

      MethodInfo[] extensionMethods = candidates
          .Where(method => method.IsDefined(typeof(ExtensionAttribute)))
          .ToArray();

      MethodInfo[] queryOperators = extensionMethods.Where(ReturnsQueryInterface).ToArray();

      return new Operators
      {
          Factories = candidates.Except(extensionMethods).Where(ReturnsQueryInterface).ToLookup(method => method.Name),
          QueryOperators = queryOperators.ToLookup(method => method.Name),
          Aggregates = extensionMethods.Except(queryOperators).ToLookup(method => method.Name),
      };
  }
  /// <summary>Lazily projects each method to its <c>Signature</c>.</summary>
  static IEnumerable<Signature> GetSignatures(IEnumerable<MethodInfo> methods) =>
      methods.Select(GetSignature);
  /// <summary>
  /// Lazily projects each method to its signature, with async interface types replaced by their
  /// sync equivalents so the result is comparable with sync signatures.
  /// </summary>
  static IEnumerable<Signature> GetRewrittenSignatures(IEnumerable<MethodInfo> methods) =>
      from signature in GetSignatures(methods)
      select RewriteSignature(signature);
  /// <summary>
  /// Captures the return and parameter types of <paramref name="method"/>. A generic method
  /// definition is first closed over the shared placeholder types in <c>Wildcards</c>, assigned by
  /// position. As a result, generic methods with the same shape produce equal signatures no matter
  /// what their type parameters are called.
  /// </summary>
  /// <exception cref="NotSupportedException">
  /// The method declares more generic type parameters than there are placeholder types.
  /// </exception>
  /// <remarks>
  /// MakeGenericMethod throws ArgumentException if a placeholder does not satisfy a type
  /// parameter's constraints. The placeholders are plain classes.
  /// </remarks>
  static Signature GetSignature(MethodInfo method)
  {
      if (method.IsGenericMethodDefinition)
      {
          var genericArity = method.GetGenericArguments().Length;
          if (genericArity > Wildcards.Length)
          {
              // Previously surfaced as an IndexOutOfRangeException from inside a lambda.
              throw new NotSupportedException(
                  $"{method} declares {genericArity} generic type parameters, but only {Wildcards.Length} wildcard types are available.");
          }
          method = method.MakeGenericMethod(Wildcards.Take(genericArity).ToArray());
      }

      return new Signature
      {
          Method = method,
          ReturnType = method.ReturnType,
          ParameterTypes = method.GetParameters().Select(p => p.ParameterType).ToArray()
      };
  }
  /// <summary>
  /// Returns a copy of <paramref name="signature"/> whose return and parameter types have been
  /// passed through the async→sync substitutor <c>subst</c>. The originating method is kept.
  /// </summary>
  static Signature RewriteSignature(Signature signature)
  {
      Type rewrittenReturn = subst.Visit(signature.ReturnType);
      Type[] rewrittenParameters = subst.Visit(signature.ParameterTypes);
      return new Signature
      {
          Method = signature.Method,
          ReturnType = rewrittenReturn,
          ParameterTypes = rewrittenParameters
      };
  }
  /// <summary>
  /// Lifts a synchronous aggregate signature to the shape expected of its async counterpart:
  /// a <c>void</c> return becomes <see cref="Task"/>, and any other return type T becomes
  /// <c>Task&lt;T&gt;</c>. The method and the parameter-type array (the same instance) are carried over.
  /// </summary>
  static Signature GetAsyncAggregateSignature(Signature signature)
  {
      Type lifted;
      if (signature.ReturnType == typeof(void))
      {
          lifted = typeof(Task);
      }
      else
      {
          lifted = typeof(Task<>).MakeGenericType(signature.ReturnType);
      }

      return new Signature
      {
          Method = signature.Method,
          ReturnType = lifted,
          ParameterTypes = signature.ParameterTypes
      };
  }
  /// <summary>
  /// Formats <paramref name="method"/> for the report. A closed generic method is shown as its
  /// generic method definition, so placeholder type arguments do not appear in the output.
  /// A null method is shown as "UNKNOWN".
  /// </summary>
  static string ToString(MethodInfo method)
  {
      if (method is null)
      {
          return "UNKNOWN";
      }

      var display = method.IsGenericMethod && !method.IsGenericMethodDefinition
          ? method.GetGenericMethodDefinition()
          : method;
      return display.ToString();
  }
  /// <summary>
  /// The public static methods of one operator class, grouped by method name into three categories:
  /// factories (non-extension methods returning a query interface), query operators
  /// (extension methods returning a query interface), and aggregates (all other extension methods).
  /// </summary>
  class Operators
  {
      // Declared in order: Factories, QueryOperators, Aggregates.
      public ILookup<string, MethodInfo> Factories, QueryOperators, Aggregates;
  }
  /// <summary>
  /// Structural description of a method: its return type and ordered parameter types.
  /// Equality and hashing use only those types. <see cref="Method"/> is carried along for
  /// reporting and is null for signatures synthesized rather than read from a real method.
  /// </summary>
  /// <remarks>
  /// The fields are mutable, but the hash code is derived from them. An instance must therefore
  /// not be modified after it has been added to a hashed collection; LINQ Except/Union build hash sets.
  /// </remarks>
  class Signature : IEquatable<Signature>
  {
      /// <summary>The method this signature was read from, or null when synthesized.</summary>
      public MethodInfo Method;
      /// <summary>The method's return type (possibly rewritten).</summary>
      public Type ReturnType;
      /// <summary>The method's parameter types, in declaration order (possibly rewritten).</summary>
      public Type[] ParameterTypes;

      public static bool operator ==(Signature s1, Signature s2)
      {
          if (ReferenceEquals(s1, s2))
          {
              return true;
          }
          if (s1 is null || s2 is null)
          {
              return false;
          }
          return s1.Equals(s2);
      }

      public static bool operator !=(Signature s1, Signature s2)
      {
          return !(s1 == s2);
      }

      /// <summary>True when both return types and parameter-type sequences are equal.</summary>
      public bool Equals(Signature s)
      {
          if (s is null)
          {
              return false;
          }
          if (ReferenceEquals(this, s))
          {
              return true;
          }
          return ReturnType.Equals(s.ReturnType) && ParameterTypes.SequenceEqual(s.ParameterTypes);
      }

      public override bool Equals(object obj)
      {
          return obj is Signature s && Equals(s);
      }

      public override int GetHashCode()
      {
          // The multiply-add is expected to wrap around. Without 'unchecked', a build with
          // overflow checking enabled would throw OverflowException for longer signatures.
          // The fold order (parameters, then return type) and multiplier match the original.
          unchecked
          {
              var hash = 0;
              foreach (var parameterType in ParameterTypes)
              {
                  hash = hash * 17 + parameterType.GetHashCode();
              }
              return hash * 17 + ReturnType.GetHashCode();
          }
      }

      /// <summary>Renders the signature as "(P1, P2) -> R" using C#-like type names.</summary>
      public override string ToString()
      {
          return "(" + string.Join(", ", ParameterTypes.Select(t => t.ToCSharp())) + ") -> " + ReturnType.ToCSharp();
      }
  }
  /// <summary>
  /// Recursively rebuilds a <see cref="Type"/> by dispatching on its shape: vector array,
  /// multi-dimensional array, generic type definition, constructed generic, by-ref, pointer,
  /// or anything else. Each hook rebuilds the same shape from visited components, so without
  /// overrides the result is equivalent to the input. Derived classes override hooks (or
  /// <see cref="Visit(Type)"/>) to substitute types.
  /// </summary>
  class TypeVisitor
  {
      /// <summary>Dispatches <paramref name="type"/> to the hook matching its shape.</summary>
      public virtual Type Visit(Type type)
      {
          if (type.IsArray)
          {
              // A vector (T[]) is exactly what MakeArrayType() yields. Every other array,
              // including a rank-1 array created with MakeArrayType(1), takes the multi-dimensional path.
              bool isVector = type.GetElementType().MakeArrayType() == type;
              return isVector ? VisitArray(type) : VisitMultidimensionalArray(type);
          }

          if (type.GetTypeInfo().IsGenericTypeDefinition)
          {
              return VisitGenericTypeDefinition(type);
          }

          if (type.IsConstructedGenericType)
          {
              return VisitGeneric(type);
          }

          if (type.IsByRef)
          {
              return VisitByRef(type);
          }

          if (type.IsPointer)
          {
              return VisitPointer(type);
          }

          return VisitSimple(type);
      }

      protected virtual Type VisitArray(Type type) =>
          Visit(type.GetElementType()).MakeArrayType();

      protected virtual Type VisitMultidimensionalArray(Type type) =>
          Visit(type.GetElementType()).MakeArrayType(type.GetArrayRank());

      protected virtual Type VisitGenericTypeDefinition(Type type) => type;

      protected virtual Type VisitGeneric(Type type) =>
          Visit(type.GetGenericTypeDefinition()).MakeGenericType(Visit(type.GenericTypeArguments));

      protected virtual Type VisitByRef(Type type) =>
          Visit(type.GetElementType()).MakeByRefType();

      protected virtual Type VisitPointer(Type type) =>
          Visit(type.GetElementType()).MakePointerType();

      protected virtual Type VisitSimple(Type type) => type;

      /// <summary>Visits each element of <paramref name="types"/>, returning a new array.</summary>
      public Type[] Visit(Type[] types) =>
          Array.ConvertAll(types, t => Visit(t));
  }
  /// <summary>
  /// A <see cref="TypeVisitor"/> that replaces any type found as a key in the supplied map with
  /// the mapped value. The lookup runs before shape dispatch. Because the base class visits the
  /// generic definition of a constructed generic through the virtual Visit, mapped open
  /// definitions (such as <c>IAsyncEnumerable&lt;&gt;</c>) are also replaced inside constructed types.
  /// </summary>
  class TypeSubstitutor : TypeVisitor
  {
      private readonly Dictionary<Type, Type> map;

      /// <param name="map">Replacement table; stored by reference, not copied.</param>
      public TypeSubstitutor(Dictionary<Type, Type> map) => this.map = map;

      public override Type Visit(Type type) =>
          map.TryGetValue(type, out var replacement) ? replacement : base.Visit(type);
  }
  // Placeholder types substituted, by position, for the type parameters of generic method
  // definitions so that signatures from the sync and async APIs compare structurally.
  private static readonly Type[] Wildcards = { typeof(T1), typeof(T2), typeof(T3), typeof(T4) };

  /// <summary>Placeholder for the first generic type parameter.</summary>
  class T1 { }

  /// <summary>Placeholder for the second generic type parameter.</summary>
  class T2 { }

  /// <summary>Placeholder for the third generic type parameter.</summary>
  class T3 { }

  /// <summary>Placeholder for the fourth generic type parameter.</summary>
  class T4 { }
  522. }
  /// <summary>
  /// Formatting helpers that render <see cref="Type"/> objects in a compact, C#-like form for the report.
  /// </summary>
  static class TypeExtensions
  {
      /// <summary>
      /// Renders <paramref name="type"/> using simple (namespace-free) names. Vectors are shown as
      /// <c>T[]</c>, multi-dimensional arrays as <c>T[,]</c>, and constructed generics as
      /// <c>Name&lt;A, B&gt;</c> with the arity suffix (<c>`N</c>) removed. Every other type is
      /// rendered as its <c>Type.Name</c>.
      /// </summary>
      /// <remarks>
      /// A constructed generic type's name does not always contain a backtick. A non-generic class
      /// nested in a generic one (e.g. <c>Dictionary&lt;K, V&gt;.KeyCollection</c>) is named
      /// <c>KeyCollection</c> but is still constructed over the outer type's arguments. Such names
      /// are used as-is instead of passing -1 to <c>Substring</c>, which throws
      /// <see cref="ArgumentOutOfRangeException"/>.
      /// </remarks>
      public static string ToCSharp(this Type type)
      {
          if (type.IsArray)
          {
              var element = type.GetElementType();
              if (element.MakeArrayType() == type)
              {
                  // Single-dimensional, zero-based array (a vector).
                  return element.ToCSharp() + "[]";
              }
              return element.ToCSharp() + "[" + new string(',', type.GetArrayRank() - 1) + "]";
          }

          if (type.IsConstructedGenericType)
          {
              var name = type.GetGenericTypeDefinition().Name;
              var arityMarker = name.IndexOf('`');
              if (arityMarker >= 0)
              {
                  name = name.Substring(0, arityMarker);
              }
              return name + "<" + string.Join(", ", type.GetGenericArguments().Select(ToCSharp)) + ">";
          }

          return type.Name;
      }
  }
  550. }