123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673 |
- // Licensed to the .NET Foundation under one or more agreements.
- // The .NET Foundation licenses this file to you under the Apache 2.0 License.
- // See the LICENSE file in the project root for more information.
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Linq.Expressions;
- using System.Reflection;
- using System.Runtime.CompilerServices;
- using System.Threading;
- using System.Threading.Tasks;
- namespace ApiCompare
- {
- class Program
- {
//
// Open generic shapes that mark a method's return type as a query result.
// Each async shape is paired with its sync counterpart by the substitution map below.
//
private static readonly Type asyncInterfaceType = typeof(IAsyncEnumerable<>);
private static readonly Type syncInterfaceType = typeof(IEnumerable<>);
private static readonly Type asyncOrderedInterfaceType = typeof(IOrderedAsyncEnumerable<>);
private static readonly Type syncOrderedInterfaceType = typeof(IOrderedEnumerable<>);

//
// Method names excluded from the comparison on both sides.
//
private static readonly string[] exceptions =
{
    "SkipLast",          // In .NET Core 2.0
    "TakeLast",          // In .NET Core 2.0
    "ToHashSet",         // In .NET Core 2.0
    "Cast",              // Non-generic methods
    "OfType",            // Non-generic methods
    "AsEnumerable",      // Trivially renamed
    "AsAsyncEnumerable", // Trivially renamed
    "ForEachAsync",      // "foreach await" language substitute for the time being
};

//
// Rewrites async shapes into their sync counterparts so that async signatures can be
// compared for equality with sync ones. Static field initializers run in textual order,
// so this field must stay below the interface-type fields it reads.
//
private static readonly TypeSubstitutor subst = new TypeSubstitutor(new Dictionary<Type, Type>
{
    [asyncInterfaceType] = syncInterfaceType,
    [asyncOrderedInterfaceType] = syncOrderedInterfaceType,
    [typeof(IAsyncGrouping<,>)] = typeof(IGrouping<,>),
});
/// <summary>
/// Entry point: compares the <c>AsyncEnumerable</c> operator surface against the
/// <c>Enumerable</c> reference surface and writes the differences to the console.
/// </summary>
static void Main()
{
    Type reference = typeof(Enumerable);
    Type candidate = typeof(AsyncEnumerable);
    Compare(reference, candidate);
}
/// <summary>
/// Classifies the public static methods of both operator classes into factories, query
/// operators and aggregates, then compares each category between the two sides.
/// </summary>
/// <param name="syncOperatorsType">The synchronous reference class (e.g. <c>Enumerable</c>).</param>
/// <param name="asyncOperatorsType">The asynchronous class under review (e.g. <c>AsyncEnumerable</c>).</param>
static void Compare(Type syncOperatorsType, Type asyncOperatorsType)
{
    Type[] syncShapes = { syncInterfaceType, syncOrderedInterfaceType };
    Type[] asyncShapes = { asyncInterfaceType, asyncOrderedInterfaceType };

    var syncSurface = GetQueryOperators(syncShapes, syncOperatorsType, exceptions);
    var asyncSurface = GetQueryOperators(asyncShapes, asyncOperatorsType, exceptions);

    CompareFactories(syncSurface.Factories, asyncSurface.Factories);
    CompareQueryOperators(syncSurface.QueryOperators, asyncSurface.QueryOperators);
    CompareAggregates(syncSurface.Aggregates, asyncSurface.Aggregates);
}
/// <summary>
/// Compares the factory methods of both sides name by name, using
/// <c>CompareFactoryOverloads</c> for names that exist on both sides.
/// </summary>
static void CompareFactories(ILookup<string, MethodInfo> syncFactories, ILookup<string, MethodInfo> asyncFactories) =>
    CompareSets(syncFactories, asyncFactories, CompareFactoryOverloads);
/// <summary>
/// Compares same-named factory overloads. Async signatures are first rewritten into their
/// sync shape, then overloads missing on the async side are reported as MISSING and extra
/// async overloads as EXCESS.
/// </summary>
/// <param name="name">The shared method name. It is not used here, but the parameter is kept for the comparer delegate shape.</param>
/// <param name="syncMethods">Reference overloads.</param>
/// <param name="asyncMethods">Overloads under review.</param>
static void CompareFactoryOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
{
    var sync = GetSignatures(syncMethods).ToArray();
    var async = GetRewrittenSignatures(asyncMethods).ToArray();

    //
    // Ensure that async is a superset of sync. Each Except query is enumerated exactly
    // once; an Any() pre-check would evaluate it a second time for no benefit.
    //
    foreach (var signature in sync.Except(async))
    {
        Console.WriteLine("MISSING " + ToString(signature.Method));
    }

    //
    // Check for excess overloads.
    //
    foreach (var signature in async.Except(sync))
    {
        Console.WriteLine("EXCESS " + ToString(signature.Method));
    }
}
/// <summary>
/// Compares the query operators of both sides name by name, using
/// <c>CompareQueryOperatorsOverloads</c> for names that exist on both sides.
/// </summary>
static void CompareQueryOperators(ILookup<string, MethodInfo> syncOperators, ILookup<string, MethodInfo> asyncOperators) =>
    CompareSets(syncOperators, asyncOperators, CompareQueryOperatorsOverloads);
/// <summary>
/// Compares same-named query-operator overloads. Every sync overload must exist on the async
/// side. A sync overload that takes a Func/Action delegate must also have a variant whose
/// delegates return tasks. Async overloads matching neither shape are reported as EXCESS.
/// </summary>
/// <param name="name">The shared method name, used when printing synthesized signatures.</param>
/// <param name="syncMethods">Reference overloads.</param>
/// <param name="asyncMethods">Overloads under review.</param>
static void CompareQueryOperatorsOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
{
    var sync = GetSignatures(syncMethods).ToArray();
    var async = GetRewrittenSignatures(asyncMethods).ToArray();

    //
    // Ensure that async is a superset of sync. Each Except query below is enumerated
    // exactly once (no Any() pre-check that would evaluate it twice).
    //
    foreach (var signature in sync.Except(async))
    {
        Console.WriteLine("MISSING " + ToString(signature.Method));
    }

    //
    // Task-based overloads: sync overloads taking a Func/Action get a counterpart whose
    // delegates return Task / Task<T>.
    //
    var taskBasedSignatures = sync
        .Where(s => s.ParameterTypes.Any(IsFuncOrActionType))
        .Select(s => GetAsyncVariant(s))
        .ToList();

    foreach (var signature in taskBasedSignatures.Except(async))
    {
        // Synthesized signatures have no MethodInfo, so print the shape itself.
        Console.WriteLine("MISSING " + name + " :: " + signature);
    }

    //
    // Excess overloads that are neither carbon copies of sync nor Task-based variants of sync.
    //
    foreach (var signature in async.Except(sync.Union(taskBasedSignatures)))
    {
        Console.WriteLine("EXCESS " + ToString(signature.Method));
    }
}
/// <summary>
/// Compares the aggregates of both sides name by name, using
/// <c>CompareAggregateOverloads</c> for names that exist on both sides.
/// </summary>
static void CompareAggregates(ILookup<string, MethodInfo> syncAggregates, ILookup<string, MethodInfo> asyncAggregates) =>
    CompareSets(syncAggregates, asyncAggregates, CompareAggregateOverloads);
/// <summary>
/// Compares same-named aggregate overloads. The async side is expected to offer each of
/// the following:
/// <list type="bullet">
/// <item>every sync overload with its result lifted to <c>Task</c>/<c>Task&lt;T&gt;</c>;</item>
/// <item>for overloads taking a Func/Action, a variant whose delegates return tasks;</item>
/// <item>each of the above with a trailing <see cref="CancellationToken"/>.</item>
/// </list>
/// Async overloads matching none of these shapes are reported as EXCESS.
/// </summary>
/// <param name="name">The shared method name, used when printing synthesized signatures.</param>
/// <param name="syncMethods">Reference overloads.</param>
/// <param name="asyncMethods">Overloads under review.</param>
static void CompareAggregateOverloads(string name, IEnumerable<MethodInfo> syncMethods, IEnumerable<MethodInfo> asyncMethods)
{
    // Sync aggregates are lifted to their expected async shape (T -> Task<T>, void -> Task).
    var sync = GetSignatures(syncMethods).Select(GetAsyncAggregateSignature).ToArray();
    var async = GetRewrittenSignatures(asyncMethods).ToArray();

    //
    // Ensure that async is a superset of sync. Each Except query below is enumerated
    // exactly once (no Any() pre-check that would evaluate it twice).
    //
    foreach (var signature in sync.Except(async))
    {
        Console.WriteLine("MISSING " + ToString(signature.Method));
    }

    //
    // Find Task-based overloads.
    //
    var taskBasedSignatures = sync
        .Where(s => s.ParameterTypes.Any(IsFuncOrActionType))
        .Select(s => GetAsyncVariant(s))
        .ToList();

    foreach (var signature in taskBasedSignatures.Except(async))
    {
        // Synthesized signatures have no MethodInfo, so print the shape itself.
        Console.WriteLine("MISSING " + name + " :: " + signature);
    }

    //
    // Check for overloads with CancellationToken (for both plain and Task-based shapes).
    //
    var withCancellationToken = sync
        .Concat(taskBasedSignatures)
        .Select(s => AppendCancellationToken(s))
        .ToList();

    foreach (var signature in withCancellationToken.Except(async))
    {
        Console.WriteLine("MISSING " + name + " :: " + signature);
    }

    //
    // Excess overloads that match none of the expected shapes above.
    //
    var expected = sync.Union(taskBasedSignatures).Union(withCancellationToken);
    foreach (var signature in async.Except(expected))
    {
        Console.WriteLine("EXCESS " + ToString(signature.Method));
    }
}
/// <summary>
/// Determines whether <paramref name="type"/> is one of the BCL callback delegate types:
/// the non-generic <see cref="Action"/>, or a constructed <c>Action&lt;...&gt;</c> or
/// <c>Func&lt;...&gt;</c> from the <c>System</c> namespace.
/// </summary>
/// <remarks>
/// Metadata names such as <c>Func`2</c> are compared ordinally, because they are identifiers
/// and not linguistic text. The namespace check keeps an unrelated type that happens to be
/// named <c>Func</c> or <c>Action</c> from being treated, and later rewritten, as the BCL delegate.
/// </remarks>
/// <param name="type">The parameter type to inspect.</param>
/// <returns><c>true</c> for BCL Func/Action delegate types; otherwise <c>false</c>.</returns>
private static bool IsFuncOrActionType(Type type)
{
    if (type == typeof(Action))
    {
        return true;
    }

    if (!type.IsConstructedGenericType)
    {
        return false;
    }

    var definition = type.GetGenericTypeDefinition();
    if (definition.Namespace != "System")
    {
        return false;
    }

    var defName = definition.Name;
    return defName.StartsWith("Func`", StringComparison.Ordinal)
        || defName.StartsWith("Action`", StringComparison.Ordinal);
}
/// <summary>
/// Builds the Task-based counterpart of a signature. Each delegate parameter is replaced by
/// its Task-returning form, and the return type is kept. The result has no
/// <c>Method</c> because it does not correspond to a reflected method.
/// </summary>
private static Signature GetAsyncVariant(Signature signature)
{
    Type[] asyncParameters = signature.ParameterTypes
        .Select(parameterType => GetAsyncVariant(parameterType))
        .ToArray();

    return new Signature
    {
        ParameterTypes = asyncParameters,
        ReturnType = signature.ReturnType,
    };
}
/// <summary>
/// Returns a copy of <paramref name="signature"/> with a trailing
/// <see cref="CancellationToken"/> parameter. The return type is unchanged and
/// <c>Method</c> is left unset.
/// </summary>
private static Signature AppendCancellationToken(Signature signature)
{
    var parameters = new List<Type>(signature.ParameterTypes) { typeof(CancellationToken) };

    return new Signature
    {
        ParameterTypes = parameters.ToArray(),
        ReturnType = signature.ReturnType,
    };
}
/// <summary>
/// Maps a synchronous callback type to its Task-returning counterpart:
/// <list type="bullet">
/// <item><see cref="Action"/> becomes <c>Func&lt;Task&gt;</c>;</item>
/// <item><c>Action&lt;T1..Tn&gt;</c> becomes <c>Func&lt;T1..Tn, Task&gt;</c>;</item>
/// <item><c>Func&lt;T1..Tn, TResult&gt;</c> becomes <c>Func&lt;T1..Tn, Task&lt;TResult&gt;&gt;</c>.</item>
/// </list>
/// Any type that is not a Func/Action delegate is returned unchanged.
/// </summary>
/// <param name="type">A parameter type from a synchronous signature.</param>
/// <returns>The expected parameter type on the asynchronous side.</returns>
private static Type GetAsyncVariant(Type type)
{
    if (!IsFuncOrActionType(type))
    {
        return type;
    }

    if (type == typeof(Action))
    {
        return typeof(Func<Task>);
    }

    var args = type.GetGenericArguments();
    var defName = type.GetGenericTypeDefinition().Name;

    // Ordinal comparison: these are metadata identifiers, not culture-sensitive text.
    if (defName.StartsWith("Func`", StringComparison.Ordinal))
    {
        // Only the last type argument of a Func is its result; wrap just that one.
        var ret = typeof(Task<>).MakeGenericType(args.Last());
        return Expression.GetFuncType(args.SkipLast(1).Append(ret).ToArray());
    }

    // Action<...> gains a Task result and becomes Func<..., Task>.
    return Expression.GetFuncType(args.Append(typeof(Task)).ToArray());
}
/// <summary>
/// Compares two name-keyed method groups:
/// <list type="bullet">
/// <item>every overload of a name found only in <paramref name="sync"/> is printed as MISSING;</item>
/// <item>names found on both sides go to <paramref name="compareCore"/> for an overload-level comparison;</item>
/// <item>every overload of a name found only in <paramref name="async"/> is printed as EXCESS.</item>
/// </list>
/// </summary>
static void CompareSets(ILookup<string, MethodInfo> sync, ILookup<string, MethodInfo> async, Action<string, IEnumerable<MethodInfo>, IEnumerable<MethodInfo>> compareCore)
{
    string[] syncNames = sync.Select(group => group.Key).ToArray();
    string[] asyncNames = async.Select(group => group.Key).ToArray();

    // Names the async side lacks entirely.
    IEnumerable<string> onlyInSync = syncNames.Except(asyncNames);
    foreach (string missingName in onlyInSync)
    {
        foreach (MethodInfo overload in sync[missingName])
        {
            Console.WriteLine("MISSING " + ToString(overload));
        }
    }

    // Names present on both sides: compare overload by overload.
    IEnumerable<string> shared = syncNames.Intersect(asyncNames);
    foreach (string sharedName in shared)
    {
        IEnumerable<MethodInfo> syncOverloads = sync[sharedName];
        IEnumerable<MethodInfo> asyncOverloads = async[sharedName];
        compareCore(sharedName, syncOverloads, asyncOverloads);
    }

    // Names the sync side does not have at all.
    IEnumerable<string> onlyInAsync = asyncNames.Except(syncNames);
    foreach (string extraName in onlyInAsync)
    {
        foreach (MethodInfo overload in async[extraName])
        {
            Console.WriteLine("EXCESS " + ToString(overload));
        }
    }
}
/// <summary>
/// Buckets the public static methods of <paramref name="operatorsType"/>, minus any whose
/// name appears in <paramref name="exclude"/>, into three categories:
/// <list type="bullet">
/// <item>factories: non-extension methods returning a sequence;</item>
/// <item>query operators: extension methods returning a sequence;</item>
/// <item>aggregates: extension methods returning anything else.</item>
/// </list>
/// Non-extension methods that do not return a sequence are dropped.
/// </summary>
/// <param name="interfaceTypes">Open generic sequence interfaces. A method returns a sequence when its return type closes one of them.</param>
/// <param name="operatorsType">The static operator class to reflect over.</param>
/// <param name="exclude">Method names to skip. <c>null</c> is treated as an empty list.</param>
/// <returns>The three categories, each grouped by method name.</returns>
static Operators GetQueryOperators(Type[] interfaceTypes, Type operatorsType, string[] exclude)
{
    // True when the method returns a closed instance of one of the sequence interfaces.
    bool ReturnsSequence(MethodInfo m) =>
        m.ReturnType.IsConstructedGenericType && interfaceTypes.Contains(m.ReturnType.GetGenericTypeDefinition());

    var excluded = new HashSet<string>(exclude ?? Array.Empty<string>());

    //
    // Get all the static methods. Materialized once so the exclusion filter runs a single
    // time even though the result is consumed twice below.
    //
    var methods = operatorsType
        .GetMethods(BindingFlags.Public | BindingFlags.Static)
        .Where(m => !excluded.Contains(m.Name))
        .ToArray();

    //
    // Get extension methods. These can be either operators or aggregates.
    //
    var extensionMethods = methods.Where(m => m.IsDefined(typeof(ExtensionAttribute))).ToArray();

    //
    // Static methods that aren't extension methods can be factories.
    //
    var factories = methods.Except(extensionMethods).Where(ReturnsSequence).ToArray();

    //
    // Extension methods that return the interface type are operators; the rest are aggregates.
    //
    var queryOperators = extensionMethods.Where(ReturnsSequence).ToArray();
    var aggregates = extensionMethods.Except(queryOperators).ToArray();

    return new Operators
    {
        Factories = factories.ToLookup(m => m.Name),
        QueryOperators = queryOperators.ToLookup(m => m.Name),
        Aggregates = aggregates.ToLookup(m => m.Name),
    };
}
/// <summary>
/// Lazily projects each method to its comparable <c>Signature</c>. Generic definitions are
/// closed over the wildcard placeholders by <c>GetSignature</c>.
/// </summary>
static IEnumerable<Signature> GetSignatures(IEnumerable<MethodInfo> methods) =>
    methods.Select(GetSignature);
/// <summary>
/// Lazily projects each method to its signature, then maps async interface types onto their
/// sync counterparts so the result can be compared directly with sync signatures.
/// </summary>
static IEnumerable<Signature> GetRewrittenSignatures(IEnumerable<MethodInfo> methods) =>
    GetSignatures(methods).Select(RewriteSignature);
/// <summary>
/// Captures the return and parameter types of <paramref name="method"/>. A generic method
/// definition is first closed over the <c>Wildcards</c> placeholders, by position, so that
/// equivalent generic overloads on both sides yield equal signatures.
/// </summary>
/// <param name="method">The reflected method.</param>
/// <returns>A signature whose <c>Method</c> is the (possibly closed) method.</returns>
/// <exception cref="NotSupportedException">
/// The method declares more generic parameters than there are wildcard placeholders.
/// </exception>
static Signature GetSignature(MethodInfo method)
{
    if (method.IsGenericMethodDefinition)
    {
        var arity = method.GetGenericArguments().Length;
        if (arity > Wildcards.Length)
        {
            // Fail with a diagnosable message instead of an IndexOutOfRangeException.
            throw new NotSupportedException(
                "Method '" + method + "' declares " + arity + " generic parameters, but only " +
                Wildcards.Length + " wildcard placeholder types are available.");
        }

        method = method.MakeGenericMethod(Wildcards.Take(arity).ToArray());
    }

    return new Signature
    {
        Method = method,
        ReturnType = method.ReturnType,
        ParameterTypes = method.GetParameters().Select(p => p.ParameterType).ToArray()
    };
}
/// <summary>
/// Returns a copy of <paramref name="signature"/> whose return and parameter types have
/// been passed through the async-to-sync type substitutor. <c>Method</c> is carried over as-is.
/// </summary>
static Signature RewriteSignature(Signature signature)
{
    Type rewrittenReturn = subst.Visit(signature.ReturnType);
    Type[] rewrittenParameters = subst.Visit(signature.ParameterTypes);

    return new Signature
    {
        Method = signature.Method,
        ReturnType = rewrittenReturn,
        ParameterTypes = rewrittenParameters,
    };
}
/// <summary>
/// Lifts a synchronous aggregate signature to the shape its async counterpart should have.
/// A <c>void</c> result becomes <see cref="Task"/> and any other result <c>T</c> becomes
/// <c>Task&lt;T&gt;</c>. Method and parameter types are carried over.
/// </summary>
static Signature GetAsyncAggregateSignature(Signature signature)
{
    Type asyncReturnType;
    if (signature.ReturnType == typeof(void))
    {
        asyncReturnType = typeof(Task);
    }
    else
    {
        asyncReturnType = typeof(Task<>).MakeGenericType(signature.ReturnType);
    }

    return new Signature
    {
        Method = signature.Method,
        ReturnType = asyncReturnType,
        ParameterTypes = signature.ParameterTypes,
    };
}
/// <summary>
/// Formats a method for a report line. A constructed generic method is shown as its open
/// generic definition, and <c>null</c> yields "UNKNOWN".
/// </summary>
static string ToString(MethodInfo method)
{
    if (method is null)
    {
        return "UNKNOWN";
    }

    bool isClosedGeneric = method.IsGenericMethod && !method.IsGenericMethodDefinition;
    MethodInfo shown = isClosedGeneric ? method.GetGenericMethodDefinition() : method;
    return shown.ToString();
}
/// <summary>
/// The public static members of an operator class, split into three categories and
/// grouped by method name.
/// </summary>
class Operators
{
    /// <summary>Non-extension static methods that return a sequence interface.</summary>
    public ILookup<string, MethodInfo> Factories;

    /// <summary>Extension methods that return a sequence interface.</summary>
    public ILookup<string, MethodInfo> QueryOperators;

    /// <summary>Extension methods that return anything else.</summary>
    public ILookup<string, MethodInfo> Aggregates;
}
/// <summary>
/// A method shape used in set comparisons. Equality and hashing consider only
/// <see cref="ReturnType"/> and <see cref="ParameterTypes"/>, never <see cref="Method"/>.
/// </summary>
class Signature : IEquatable<Signature>
{
    /// <summary>
    /// The reflected method this shape came from. It may be null for signatures synthesized
    /// from another signature.
    /// </summary>
    public MethodInfo Method;

    /// <summary>The (possibly rewritten) return type.</summary>
    public Type ReturnType;

    /// <summary>The (possibly rewritten) parameter types, in declaration order.</summary>
    public Type[] ParameterTypes;

    public static bool operator ==(Signature s1, Signature s2)
    {
        if (ReferenceEquals(s1, s2))
        {
            return true;
        }

        if (ReferenceEquals(s1, null) || ReferenceEquals(s2, null))
        {
            return false;
        }

        return s1.Equals(s2);
    }

    public static bool operator !=(Signature s1, Signature s2)
    {
        return !(s1 == s2);
    }

    public bool Equals(Signature s)
    {
        if (ReferenceEquals(s, null))
        {
            return false;
        }

        return ReferenceEquals(this, s)
            || (ReturnType.Equals(s.ReturnType) && ParameterTypes.SequenceEqual(s.ParameterTypes));
    }

    public override bool Equals(object obj)
    {
        return obj is Signature s && Equals(s);
    }

    public override int GetHashCode()
    {
        // The multiply-add is meant to wrap around. Without an explicit unchecked context,
        // builds with overflow checking enabled would throw OverflowException here.
        unchecked
        {
            var hash = 0;
            foreach (var parameterType in ParameterTypes)
            {
                hash = hash * 17 + parameterType.GetHashCode();
            }

            return hash * 17 + ReturnType.GetHashCode();
        }
    }

    public override string ToString()
    {
        return "(" + string.Join(", ", ParameterTypes.Select(t => t.ToCSharp())) + ") -> " + ReturnType.ToCSharp();
    }
}
/// <summary>
/// Structural rewriter over <see cref="Type"/> shapes. Each Visit* hook rebuilds its shape
/// from visited components, and leaves (simple types, generic type definitions) are returned
/// as-is. Subclasses override hooks to substitute types.
/// </summary>
class TypeVisitor
{
    public virtual Type Visit(Type type)
    {
        if (type.IsArray)
        {
            // A vector (T[]) round-trips through MakeArrayType(); any other array shape is
            // routed to the multidimensional hook.
            bool isVector = type.GetElementType().MakeArrayType() == type;
            return isVector ? VisitArray(type) : VisitMultidimensionalArray(type);
        }

        if (type.GetTypeInfo().IsGenericTypeDefinition)
        {
            return VisitGenericTypeDefinition(type);
        }

        if (type.IsConstructedGenericType)
        {
            return VisitGeneric(type);
        }

        if (type.IsByRef)
        {
            return VisitByRef(type);
        }

        return type.IsPointer ? VisitPointer(type) : VisitSimple(type);
    }

    protected virtual Type VisitArray(Type type)
    {
        Type element = Visit(type.GetElementType());
        return element.MakeArrayType();
    }

    protected virtual Type VisitMultidimensionalArray(Type type)
    {
        Type element = Visit(type.GetElementType());
        return element.MakeArrayType(type.GetArrayRank());
    }

    protected virtual Type VisitGenericTypeDefinition(Type type) => type;

    protected virtual Type VisitGeneric(Type type)
    {
        Type definition = Visit(type.GetGenericTypeDefinition());
        Type[] arguments = Visit(type.GenericTypeArguments);
        return definition.MakeGenericType(arguments);
    }

    protected virtual Type VisitByRef(Type type)
    {
        Type element = Visit(type.GetElementType());
        return element.MakeByRefType();
    }

    protected virtual Type VisitPointer(Type type)
    {
        Type element = Visit(type.GetElementType());
        return element.MakePointerType();
    }

    protected virtual Type VisitSimple(Type type) => type;

    /// <summary>Visits each type in order and returns the results as a new array.</summary>
    public Type[] Visit(Type[] types)
    {
        return types.Select(t => Visit(t)).ToArray();
    }
}
/// <summary>
/// A <c>TypeVisitor</c> that replaces any type found as a key in the supplied map, at any
/// nesting depth, with the mapped type. All other types are rebuilt structurally.
/// </summary>
class TypeSubstitutor : TypeVisitor
{
    private readonly Dictionary<Type, Type> _replacements;

    public TypeSubstitutor(Dictionary<Type, Type> map)
    {
        _replacements = map;
    }

    public override Type Visit(Type type) =>
        _replacements.TryGetValue(type, out var replacement) ? replacement : base.Visit(type);
}
//
// Placeholder types substituted, by position, for the generic parameters of generic method
// definitions. For example, Select<TSource, TResult> on both sides closes over the same
// types, so their signatures become directly comparable. Eight slots are provided, so
// methods with up to eight generic parameters can be closed.
//
private static readonly Type[] Wildcards = new[]
{
    typeof(T1), typeof(T2), typeof(T3), typeof(T4),
    typeof(T5), typeof(T6), typeof(T7), typeof(T8),
};
class T1 { }
class T2 { }
class T3 { }
class T4 { }
class T5 { }
class T6 { }
class T7 { }
class T8 { }
- }
static class TypeExtensions
{
    /// <summary>
    /// Renders a type as a short C#-like name built from simple (unqualified) type names.
    /// Examples: <c>Int32[]</c>, <c>Int32[,]</c>, <c>List&lt;Int32&gt;</c>, <c>Int32&amp;</c>.
    /// </summary>
    /// <param name="type">The type to render.</param>
    /// <returns>A compact, human-readable name.</returns>
    public static string ToCSharp(this Type type)
    {
        if (type.IsArray)
        {
            var element = type.GetElementType().ToCSharp();
            return type.GetElementType().MakeArrayType() == type
                ? element + "[]"
                : element + "[" + new string(',', type.GetArrayRank() - 1) + "]";
        }

        // By-ref and pointer types keep the "&" / "*" suffix that Type.Name would show, but
        // render their element type recursively, so generic elements read naturally.
        if (type.IsByRef)
        {
            return type.GetElementType().ToCSharp() + "&";
        }

        if (type.IsPointer)
        {
            return type.GetElementType().ToCSharp() + "*";
        }

        if (type.IsConstructedGenericType)
        {
            var defName = type.GetGenericTypeDefinition().Name;

            // Nested types of generic types (e.g. Dictionary<K, V>.KeyCollection) are generic
            // without carrying an arity suffix, so the backtick may be absent.
            var tick = defName.IndexOf('`');
            if (tick >= 0)
            {
                defName = defName.Substring(0, tick);
            }

            return defName + "<" + string.Join(", ", type.GetGenericArguments().Select(ToCSharp)) + ">";
        }

        return type.Name;
    }
}
- }
|