Program.cs 22 KB


  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. using System.Runtime.CompilerServices;
  10. using System.Threading;
  11. using System.Threading.Tasks;
  12. namespace ApiCompare
  13. {
  14. internal class Program
  15. {
  // Open generic interface definitions. They classify operators by return type and drive
  // the async-to-sync type rewriting below. Declared first: Subst reads them during static
  // initialization, which runs in textual order.
  private static readonly Type AsyncInterfaceType = typeof(IAsyncEnumerable<>);
  private static readonly Type SyncInterfaceType = typeof(IEnumerable<>);
  private static readonly Type AsyncOrderedInterfaceType = typeof(IOrderedAsyncEnumerable<>);
  private static readonly Type SyncOrderedInterfaceType = typeof(IOrderedEnumerable<>);

  // Operator names left out of the comparison on both the sync and the async side.
  private static readonly string[] Exceptions = new string[]
  {
      // Sync operators introduced in .NET Core 2.0.
      "SkipLast",
      "TakeLast",
      "ToHashSet",

      // Operators over the non-generic IEnumerable.
      "Cast",
      "OfType",

      // The same conversion under a different name on each side.
      "AsEnumerable",
      "AsAsyncEnumerable",

      // Stand-in for the `await foreach` language construct.
      "ForEachAsync",

      // Conversions between kinds of sequences.
      "ToAsyncEnumerable",
      "ToEnumerable",
      "ToObservable",
  };

  // Maps async interface definitions onto their sync counterparts so that a rewritten
  // async signature can be compared structurally against a sync one.
  private static readonly TypeSubstitutor Subst = new TypeSubstitutor(new Dictionary<Type, Type>
  {
      [AsyncInterfaceType] = SyncInterfaceType,
      [AsyncOrderedInterfaceType] = SyncOrderedInterfaceType,
      [typeof(IAsyncGrouping<,>)] = typeof(IGrouping<,>),
  });
  /// <summary>
  /// Entry point: compares the <see cref="Enumerable"/> operator surface against
  /// <see cref="AsyncEnumerable"/> and prints the differences to the console.
  /// </summary>
  private static void Main() =>
      Compare(syncOperatorsType: typeof(Enumerable), asyncOperatorsType: typeof(AsyncEnumerable));
  /// <summary>
  /// Splits both operator classes into factories, query operators and aggregates, then
  /// compares each category in turn. Reports are printed in that category order.
  /// </summary>
  /// <param name="syncOperatorsType">Static class holding the synchronous operators.</param>
  /// <param name="asyncOperatorsType">Static class holding the asynchronous operators.</param>
  private static void Compare(Type syncOperatorsType, Type asyncOperatorsType)
  {
      var (syncFactories, syncQueryOperators, syncAggregates) =
          GetQueryOperators([SyncInterfaceType, SyncOrderedInterfaceType], syncOperatorsType, Exceptions);
      var (asyncFactories, asyncQueryOperators, asyncAggregates) =
          GetQueryOperators([AsyncInterfaceType, AsyncOrderedInterfaceType], asyncOperatorsType, Exceptions);

      CompareFactories(syncFactories, asyncFactories);
      CompareQueryOperators(syncQueryOperators, asyncQueryOperators);
      CompareAggregates(syncAggregates, asyncAggregates);
  }
  /// <summary>Compares factory methods (non-extension methods that return a sequence), name by name.</summary>
  private static void CompareFactories(ILookup<string, MethodInfo> syncFactories, ILookup<string, MethodInfo> asyncFactories) =>
      CompareSets(syncFactories, asyncFactories, CompareFactoryOverloads);
  /// <summary>
  /// Compares the overloads of one factory name present on both sides. Sync overloads with no
  /// structurally equal async overload are printed as MISSING. Async overloads with no sync
  /// twin are printed as EXCESS. Async signatures are rewritten to sync types before comparing.
  /// </summary>
  /// <param name="name">Factory name (not used in the output for this category).</param>
  /// <param name="syncMethods">Sync overloads sharing <paramref name="name"/>.</param>
  /// <param name="asyncMethods">Async overloads sharing <paramref name="name"/>.</param>
  private static void CompareFactoryOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      Signature[] expected = GetSignatures(syncMethods).ToArray();
      Signature[] actual = GetRewrittenSignatures(asyncMethods).ToArray();

      // Every sync overload should have an async counterpart.
      foreach (Signature missing in expected.Except(actual))
      {
          Console.WriteLine("MISSING " + ToString(missing.Method));
      }

      // Async overloads without a sync counterpart are surplus API surface.
      foreach (Signature excess in actual.Except(expected))
      {
          Console.WriteLine("EXCESS " + ToString(excess.Method));
      }
  }
  /// <summary>Compares query operators (extension methods that return a sequence), name by name.</summary>
  private static void CompareQueryOperators(ILookup<string, MethodInfo> syncOperators, ILookup<string, MethodInfo> asyncOperators) =>
      CompareSets(syncOperators, asyncOperators, CompareQueryOperatorsOverloads);
  /// <summary>
  /// Compares the overloads of one query operator present on both sides. Reports:
  /// <list type="number">
  /// <item>sync overloads with no structurally equal async overload (MISSING, by method);</item>
  /// <item>expected Task-based variants of delegate-taking sync overloads that the async side
  /// lacks (MISSING, by <paramref name="name"/> and signature);</item>
  /// <item>async overloads that are neither copies nor Task-based variants (EXCESS).</item>
  /// </list>
  /// </summary>
  private static void CompareQueryOperatorsOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      Signature[] expected = GetSignatures(syncMethods).ToArray();
      Signature[] actual = GetRewrittenSignatures(asyncMethods).ToArray();

      // 1. Straight copies of every sync overload.
      foreach (Signature missing in expected.Except(actual))
      {
          Console.WriteLine("MISSING " + ToString(missing.Method));
      }

      // 2. For each sync overload taking a Func/Action, the variant whose delegates return Task.
      List<Signature> taskBased = expected
          .Where(s => s.ParameterTypes.Any(IsFuncOrActionType))
          .Select(s => GetAsyncVariant(s))
          .ToList();

      foreach (Signature missing in taskBased.Except(actual))
      {
          Console.WriteLine("MISSING " + name + " :: " + missing);
      }

      // 3. Anything on the async side that is neither a copy nor a Task-based variant.
      foreach (Signature excess in actual.Except(expected.Union(taskBased)))
      {
          Console.WriteLine("EXCESS " + ToString(excess.Method));
      }
  }
  /// <summary>Compares aggregates (extension methods that do not return a sequence), name by name.</summary>
  private static void CompareAggregates(ILookup<string, MethodInfo> syncAggregates, ILookup<string, MethodInfo> asyncAggregates) =>
      CompareSets(syncAggregates, asyncAggregates, CompareAggregateOverloads);
  /// <summary>
  /// Compares the overloads of one aggregate present on both sides. Sync signatures are first
  /// lifted to Task-returning form (see GetAsyncAggregateSignature). Reports, in this order:
  /// <list type="number">
  /// <item>lifted sync overloads missing on the async side;</item>
  /// <item>missing Task-based variants of delegate-taking overloads;</item>
  /// <item>missing overloads with a trailing <see cref="CancellationToken"/>;</item>
  /// <item>async overloads matching none of the above (EXCESS).</item>
  /// </list>
  /// </summary>
  private static void CompareAggregateOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
  {
      Signature[] expected = GetSignatures(syncMethods).Select(GetAsyncAggregateSignature).ToArray();
      Signature[] actual = GetRewrittenSignatures(asyncMethods).ToArray();

      // 1. Lifted copies of every sync overload.
      foreach (Signature missing in expected.Except(actual))
      {
          Console.WriteLine("MISSING " + ToString(missing.Method));
      }

      // 2. Task-based variants of overloads that take a Func/Action.
      List<Signature> taskBased = expected
          .Where(s => s.ParameterTypes.Any(IsFuncOrActionType))
          .Select(s => GetAsyncVariant(s))
          .ToList();

      foreach (Signature missing in taskBased.Except(actual))
      {
          Console.WriteLine("MISSING " + name + " :: " + missing);
      }

      // 3. Every expected signature (plain first, then Task-based) with a trailing CancellationToken.
      List<Signature> withCancellation = expected
          .Concat(taskBased)
          .Select(AppendCancellationToken)
          .ToList();

      foreach (Signature missing in withCancellation.Except(actual))
      {
          Console.WriteLine("MISSING " + name + " :: " + missing);
      }

      // 4. Async overloads that match none of the expected shapes.
      foreach (Signature excess in actual.Except(expected.Union(taskBased).Union(withCancellation)))
      {
          Console.WriteLine("EXCESS " + ToString(excess.Method));
      }
  }
  /// <summary>
  /// Returns <see langword="true"/> when <paramref name="type"/> is the non-generic
  /// <see cref="Action"/> or a constructed <c>System.Func`N</c> / <c>System.Action`N</c> delegate.
  /// Open generic definitions (e.g. <c>typeof(Func&lt;&gt;)</c>) do not match.
  /// </summary>
  /// <param name="type">Parameter type to classify.</param>
  private static bool IsFuncOrActionType(Type type)
  {
      if (type == typeof(Action))
      {
          return true;
      }

      if (!type.IsConstructedGenericType)
      {
          return false;
      }

      var definition = type.GetGenericTypeDefinition();

      // Only the BCL delegates count. A user type that happens to be named Func`1 must not.
      if (definition.Namespace != "System")
      {
          return false;
      }

      // Arity-suffixed metadata names are identifiers: compare ordinally, not culture-sensitively.
      return definition.Name.StartsWith("Func`", StringComparison.Ordinal)
          || definition.Name.StartsWith("Action`", StringComparison.Ordinal);
  }
  /// <summary>
  /// Produces the Task-based counterpart of <paramref name="signature"/>. Each Func/Action
  /// parameter becomes its Task-returning form, other parameters and the return type are
  /// kept, and the result carries no <see cref="MethodInfo"/>.
  /// </summary>
  private static Signature GetAsyncVariant(Signature signature)
  {
      Type[] rewrittenParameters = Array.ConvertAll(signature.ParameterTypes, t => GetAsyncVariant(t));
      return new Signature(returnType: signature.ReturnType, parameterTypes: rewrittenParameters);
  }
  /// <summary>
  /// Returns a copy of <paramref name="signature"/> with a trailing <see cref="CancellationToken"/>
  /// parameter. The return type is kept and the source method is not carried over.
  /// </summary>
  private static Signature AppendCancellationToken(Signature signature)
  {
      Type[] original = signature.ParameterTypes;
      Type[] extended = new Type[original.Length + 1];
      original.CopyTo(extended, 0);
      extended[^1] = typeof(CancellationToken);

      return new Signature(returnType: signature.ReturnType, parameterTypes: extended);
  }
  /// <summary>
  /// Maps a delegate parameter type to its Task-based form:
  /// <list type="bullet">
  /// <item><see cref="Action"/> becomes <c>Func&lt;Task&gt;</c>;</item>
  /// <item><c>Action&lt;T1..Tn&gt;</c> becomes <c>Func&lt;T1..Tn, Task&gt;</c>;</item>
  /// <item><c>Func&lt;T1..Tn, R&gt;</c> becomes <c>Func&lt;T1..Tn, Task&lt;R&gt;&gt;</c>.</item>
  /// </list>
  /// Any type that is not a Func/Action (per IsFuncOrActionType) is returned unchanged.
  /// </summary>
  /// <param name="type">A parameter type from a method signature.</param>
  private static Type GetAsyncVariant(Type type)
  {
      if (!IsFuncOrActionType(type))
      {
          return type;
      }

      if (type == typeof(Action))
      {
          return typeof(Func<Task>);
      }

      var args = type.GetGenericArguments();
      var definitionName = type.GetGenericTypeDefinition().Name;

      // Metadata names are identifiers: use an ordinal comparison rather than the culture-sensitive default.
      if (definitionName.StartsWith("Func`", StringComparison.Ordinal))
      {
          // The last type argument of a Func is its result; wrap it in Task<>.
          var asyncResult = typeof(Task<>).MakeGenericType(args[^1]);
          return Expression.GetFuncType([.. args[..^1], asyncResult]);
      }

      // Action`N: keep every argument and add a Task result.
      return Expression.GetFuncType([.. args, typeof(Task)]);
  }
  /// <summary>
  /// Compares two name-keyed method lookups in three passes, each completed before the next starts:
  /// <list type="number">
  /// <item>names only on the sync side: every overload is printed as MISSING;</item>
  /// <item>names on both sides: handed to <paramref name="compareCore"/> for overload-level comparison;</item>
  /// <item>names only on the async side: every overload is printed as EXCESS.</item>
  /// </list>
  /// Passes 1 and 2 follow the sync lookup's order; pass 3 follows the async lookup's order.
  /// </summary>
  private static void CompareSets(ILookup<string, MethodInfo> sync, ILookup<string, MethodInfo> async, Action<string, IEnumerable<MethodInfo>, IEnumerable<MethodInfo>> compareCore)
  {
      // Pass 1: sync-only names.
      foreach (IGrouping<string, MethodInfo> syncGroup in sync)
      {
          if (async.Contains(syncGroup.Key))
          {
              continue;
          }

          foreach (MethodInfo overload in syncGroup)
          {
              Console.WriteLine("MISSING " + ToString(overload));
          }
      }

      // Pass 2: shared names, compared overload by overload.
      foreach (IGrouping<string, MethodInfo> syncGroup in sync)
      {
          if (async.Contains(syncGroup.Key))
          {
              compareCore(syncGroup.Key, syncGroup, async[syncGroup.Key]);
          }
      }

      // Pass 3: async-only names.
      foreach (IGrouping<string, MethodInfo> asyncGroup in async)
      {
          if (sync.Contains(asyncGroup.Key))
          {
              continue;
          }

          foreach (MethodInfo overload in asyncGroup)
          {
              Console.WriteLine("EXCESS " + ToString(overload));
          }
      }
  }
  /// <summary>
  /// Classifies the public static methods of <paramref name="operatorsType"/>, skipping any
  /// whose name is in <paramref name="exclude"/>:
  /// <list type="bullet">
  /// <item>factories: non-extension methods returning one of <paramref name="interfaceTypes"/>;</item>
  /// <item>query operators: extension methods returning one of <paramref name="interfaceTypes"/>;</item>
  /// <item>aggregates: all other extension methods.</item>
  /// </list>
  /// Non-extension methods that do not return a sequence are dropped. Each category is grouped by method name.
  /// </summary>
  /// <param name="interfaceTypes">Open generic sequence interfaces, e.g. <c>IEnumerable&lt;&gt;</c>.</param>
  /// <param name="operatorsType">Static class declaring the operators.</param>
  /// <param name="exclude">Method names to ignore.</param>
  private static Operators GetQueryOperators(Type[] interfaceTypes, Type operatorsType, string[] exclude)
  {
      bool ReturnsSequence(MethodInfo candidate) =>
          candidate.ReturnType.IsConstructedGenericType &&
          interfaceTypes.Contains(candidate.ReturnType.GetGenericTypeDefinition());

      var factories = new List<MethodInfo>();
      var queryOperators = new List<MethodInfo>();
      var aggregates = new List<MethodInfo>();

      foreach (MethodInfo candidate in operatorsType.GetMethods(BindingFlags.Public | BindingFlags.Static))
      {
          if (exclude.Contains(candidate.Name))
          {
              continue;
          }

          bool isExtension = candidate.IsDefined(typeof(ExtensionAttribute));
          bool returnsSequence = ReturnsSequence(candidate);

          if (isExtension)
          {
              (returnsSequence ? queryOperators : aggregates).Add(candidate);
          }
          else if (returnsSequence)
          {
              factories.Add(candidate);
          }
      }

      return new Operators(
          Factories: factories.ToLookup(m => m.Name),
          QueryOperators: queryOperators.ToLookup(m => m.Name),
          Aggregates: aggregates.ToLookup(m => m.Name));
  }
  /// <summary>Lazily maps each method to its structural signature (see GetSignature).</summary>
  private static IEnumerable<Signature> GetSignatures(IEnumerable<MethodInfo> methods) =>
      methods.Select(GetSignature);
  /// <summary>
  /// Lazily maps each method to its signature, with async interface types substituted by
  /// their sync equivalents via Subst, so the result is directly comparable to sync signatures.
  /// </summary>
  private static IEnumerable<Signature> GetRewrittenSignatures(IEnumerable<MethodInfo> methods) =>
      GetSignatures(methods).Select(RewriteSignature);
  /// <summary>
  /// Builds the structural signature of <paramref name="method"/>. A generic method definition
  /// is first closed over the placeholder types in <c>Wildcards</c>, by position. Equivalent
  /// methods on different classes then compare equal however their type parameters are named.
  /// </summary>
  /// <param name="method">A method from one of the operator classes.</param>
  /// <returns>A signature that keeps a reference to the (possibly closed) method.</returns>
  /// <exception cref="NotSupportedException">
  /// The method has more type parameters than there are wildcard types.
  /// </exception>
  private static Signature GetSignature(MethodInfo method)
  {
      if (method.IsGenericMethodDefinition)
      {
          int arity = method.GetGenericArguments().Length;
          if (arity > Wildcards.Length)
          {
              // Give a descriptive error rather than an IndexOutOfRangeException from indexing Wildcards.
              throw new NotSupportedException(
                  $"{method} has {arity} type parameters, but only {Wildcards.Length} wildcard types are defined.");
          }

          method = method.MakeGenericMethod(Wildcards[..arity]);
      }

      return new Signature(
          returnType: method.ReturnType,
          parameterTypes: method.GetParameters().Select(p => p.ParameterType).ToArray(),
          method: method);
  }
  /// <summary>
  /// Returns <paramref name="signature"/> with its return and parameter types passed through
  /// Subst (async interface types mapped to sync ones). The source method is kept.
  /// </summary>
  private static Signature RewriteSignature(Signature signature) =>
      new Signature(
          returnType: Subst.Visit(signature.ReturnType),
          parameterTypes: Subst.Visit(signature.ParameterTypes),
          method: signature.Method);
  /// <summary>
  /// Lifts a sync aggregate signature to its expected async shape: <c>void</c> becomes
  /// <see cref="Task"/> and any other return type <c>R</c> becomes <c>Task&lt;R&gt;</c>.
  /// Parameters and the source method are kept as-is.
  /// </summary>
  private static Signature GetAsyncAggregateSignature(Signature signature)
  {
      Type asyncReturnType;
      if (signature.ReturnType == typeof(void))
      {
          asyncReturnType = typeof(Task);
      }
      else
      {
          asyncReturnType = typeof(Task<>).MakeGenericType(signature.ReturnType);
      }

      return new Signature(asyncReturnType, signature.ParameterTypes, signature.Method);
  }
  /// <summary>
  /// Formats a method for reports. A closed generic method is shown as its generic definition,
  /// so the original type parameter names appear rather than wildcard types.
  /// <see langword="null"/> renders as <c>UNKNOWN</c>.
  /// </summary>
  private static string ToString(MethodInfo? method)
  {
      MethodInfo? display = method is { IsGenericMethod: true, IsGenericMethodDefinition: false }
          ? method.GetGenericMethodDefinition()
          : method;

      return display?.ToString() ?? "UNKNOWN";
  }
  /// <summary>The public operator surface of one operator class, split into three categories.</summary>
  /// <param name="Factories">Non-extension methods that return a sequence, grouped by name.</param>
  /// <param name="QueryOperators">Extension methods that return a sequence, grouped by name.</param>
  /// <param name="Aggregates">Extension methods that return anything else, grouped by name.</param>
  private record Operators(
      ILookup<string, MethodInfo> Factories,
      ILookup<string, MethodInfo> QueryOperators,
      ILookup<string, MethodInfo> Aggregates);
  /// <summary>
  /// Structural description of a method: its return type and its ordered parameter types.
  /// Equality and hashing use only those types. <see cref="Method"/> is kept for reporting
  /// and is <see langword="null"/> for synthesized signatures.
  /// </summary>
  private class Signature(
      Type returnType,
      Type[] parameterTypes,
      MethodInfo? method = null) : IEquatable<Signature>
  {
      /// <summary>The reflected method this signature was built from, if any. Not part of equality.</summary>
      public MethodInfo? Method { get; } = method;

      /// <summary>The method's return type.</summary>
      public Type ReturnType { get; } = returnType;

      /// <summary>The method's parameter types, in declaration order.</summary>
      public Type[] ParameterTypes { get; } = parameterTypes;

      // Operands are nullable because the operators deliberately handle null on either side.
      public static bool operator ==(Signature? s1, Signature? s2) =>
          s1 is null ? s2 is null : s1.Equals(s2);

      public static bool operator !=(Signature? s1, Signature? s2) => !(s1 == s2);

      /// <summary>Structural equality: same return type and the same parameter types in the same order.</summary>
      public bool Equals(Signature? s) =>
          s is not null &&
          ReturnType.Equals(s.ReturnType) &&
          ParameterTypes.SequenceEqual(s.ParameterTypes);

      public override bool Equals(object? obj) => obj is Signature s && Equals(s);

      public override int GetHashCode()
      {
          // System.HashCode combines without allocating. A manual `a * 17 + h` fold would
          // throw OverflowException if the project compiles with checked arithmetic.
          var hash = new HashCode();
          foreach (var parameterType in ParameterTypes)
          {
              hash.Add(parameterType);
          }

          hash.Add(ReturnType);
          return hash.ToHashCode();
      }

      /// <summary>Renders the signature as <c>(P1, P2, ...) -&gt; R</c> using C#-style type names.</summary>
      public override string ToString() =>
          "(" + string.Join(", ", ParameterTypes.Select(t => t.ToCSharp())) + ") -> " + ReturnType.ToCSharp();
  }
  /// <summary>
  /// Structural rebuilder for <see cref="Type"/> values. Each shape (array, generic,
  /// by-ref, pointer, ...) gets a virtual hook whose default visits the component types
  /// and reassembles the same shape, so the base visitor is an identity transform.
  /// Subclasses override hooks, or <see cref="Visit(Type)"/> itself, to substitute types.
  /// </summary>
  private class TypeVisitor
  {
      /// <summary>Dispatches <paramref name="type"/> to the hook for its shape.</summary>
      public virtual Type Visit(Type type)
      {
          if (type.IsArray)
          {
              // A single-dimensional zero-based array is exactly what MakeArrayType() produces.
              bool isVector = type.GetElementType()!.MakeArrayType() == type;
              return isVector ? VisitArray(type) : VisitMultidimensionalArray(type);
          }

          if (type.GetTypeInfo().IsGenericTypeDefinition)
          {
              return VisitGenericTypeDefinition(type);
          }

          if (type.IsConstructedGenericType)
          {
              return VisitGeneric(type);
          }

          if (type.IsByRef)
          {
              return VisitByRef(type);
          }

          if (type.IsPointer)
          {
              return VisitPointer(type);
          }

          return VisitSimple(type);
      }

      // Element type of an array/by-ref/pointer type. Throws if there is none.
      private static Type ElementOf(Type type) =>
          type.GetElementType() ?? throw new ArgumentException($"{type} does not have an element type");

      protected virtual Type VisitArray(Type type) => Visit(ElementOf(type)).MakeArrayType();

      protected virtual Type VisitMultidimensionalArray(Type type) =>
          Visit(ElementOf(type)).MakeArrayType(type.GetArrayRank());

      protected virtual Type VisitGenericTypeDefinition(Type type) => type;

      protected virtual Type VisitGeneric(Type type)
      {
          Type definition = Visit(type.GetGenericTypeDefinition());
          Type[] arguments = Visit(type.GenericTypeArguments);
          return definition.MakeGenericType(arguments);
      }

      protected virtual Type VisitByRef(Type type) => Visit(ElementOf(type)).MakeByRefType();

      protected virtual Type VisitPointer(Type type) => Visit(ElementOf(type)).MakePointerType();

      protected virtual Type VisitSimple(Type type) => type;

      /// <summary>Visits each type in order and returns the results as a new array.</summary>
      public Type[] Visit(Type[] types) => types.Select(t => Visit(t)).ToArray();
  }
  /// <summary>
  /// A <see cref="TypeVisitor"/> that replaces any type found as a key in <c>map</c> with the
  /// mapped value, at any nesting depth. Unmapped types are rebuilt structurally. For
  /// constructed generics this lets the generic definition itself be swapped.
  /// </summary>
  private class TypeSubstitutor(Dictionary<Type, Type> map) : TypeVisitor
  {
      public override Type Visit(Type type) =>
          map.TryGetValue(type, out Type? replacement) ? replacement : base.Visit(type);
  }
  // Placeholder type arguments used to close generic method definitions by position. Two
  // methods whose type parameters have different names but the same shape then produce
  // equal signatures. There are four, so at most four type parameters can be closed.
  private static readonly Type[] Wildcards = new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4) };

  // Empty marker classes; only their identity (and Name, for display) matters.
  private class T1 { }

  private class T2 { }

  private class T3 { }

  private class T4 { }
  502. }
  /// <summary>Helpers for rendering reflection types in a C#-like notation for reports.</summary>
  internal static class TypeExtensions
  {
      /// <summary>
      /// Formats <paramref name="type"/> with C#-style syntax, without namespaces:
      /// <list type="bullet">
      /// <item><c>T[]</c> for single-dimensional zero-based arrays;</item>
      /// <item><c>T[,]</c> (rank − 1 commas) for other arrays;</item>
      /// <item><c>Name&lt;A, B&gt;</c> for constructed generic types;</item>
      /// <item>the bare <c>Type.Name</c> for everything else.</item>
      /// </list>
      /// </summary>
      /// <param name="type">The type to render.</param>
      public static string ToCSharp(this Type type)
      {
          if (type.IsArray)
          {
              var elementType = type.GetElementType()!;
              return type.IsSZArray
                  ? elementType.ToCSharp() + "[]"
                  : elementType.ToCSharp() + "[" + new string(',', type.GetArrayRank() - 1) + "]";
          }

          if (type.IsConstructedGenericType)
          {
              // A generic type nested in a generic type (e.g. Dictionary<K,V>.KeyCollection)
              // has no backtick arity suffix in its name. Slicing to IndexOf == -1 would throw.
              var definitionName = type.GetGenericTypeDefinition().Name;
              var tick = definitionName.IndexOf('`');
              var baseName = tick >= 0 ? definitionName[..tick] : definitionName;
              return baseName + "<" + string.Join(", ", type.GetGenericArguments().Select(ToCSharp)) + ">";
          }

          return type.Name;
      }
  }
  531. }