123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- using System.Collections.Generic;
- using System.IO;
- using System.Text;
- using Microsoft.CodeAnalysis;
- using Microsoft.CodeAnalysis.CSharp;
- using Microsoft.CodeAnalysis.CSharp.Syntax;
- using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
namespace System.Linq.Async.SourceGenerator
{
    /// <summary>
    /// Source generator that, for every method annotated with <c>[GenerateAsyncOverload]</c>, emits a
    /// <c>public static</c> overload in a <c>partial class</c> of namespace <c>System.Linq</c> whose expression
    /// body forwards all parameters to the annotated method.
    /// </summary>
    /// <remarks>
    /// The generated method's name is the annotated method's name with every <c>"Core"</c> removed and, when the
    /// <c>SUPPORT_FLAT_ASYNC_API</c> preprocessor symbol is defined, every <c>"Await"</c> and
    /// <c>"WithCancellation"</c> removed as well. One generated source file is produced per source file that
    /// contains at least one annotated method.
    /// </remarks>
    [Generator]
    public sealed class AsyncOverloadsGenerator : ISourceGenerator
    {
        // Added to every compilation during post-initialization so the annotated code always compiles.
        // Because the attribute class is [Conditional("COMPILE_TIME_ONLY")], usages of it are only kept in
        // metadata when COMPILE_TIME_ONLY is defined.
        private const string AttributeSource =
            "using System;\n" +
            "using System.Diagnostics;\n" +
            "namespace System.Linq\n" +
            "{\n" +
            " [AttributeUsage(AttributeTargets.Method)]\n" +
            " [Conditional(\"COMPILE_TIME_ONLY\")]\n" +
            " internal sealed class GenerateAsyncOverloadAttribute : Attribute { }\n" +
            "}\n";

        /// <summary>
        /// Registers the syntax receiver that collects candidate methods and injects the
        /// <c>GenerateAsyncOverloadAttribute</c> source.
        /// </summary>
        /// <param name="context">The generator initialization context.</param>
        public void Initialize(GeneratorInitializationContext context)
        {
            context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
            context.RegisterForPostInitialization(c => c.AddSource("GenerateAsyncOverloadAttribute", AttributeSource));
        }

        /// <summary>
        /// Emits one <c>&lt;FileName&gt;.AsyncOverloads</c> source per syntax tree that contains annotated methods.
        /// </summary>
        /// <param name="context">The generator execution context.</param>
        /// <exception cref="InvalidOperationException">
        /// The attribute type cannot be found in the compilation, or a candidate method has no declared symbol.
        /// </exception>
        public void Execute(GeneratorExecutionContext context)
        {
            if (context.SyntaxReceiver is not SyntaxReceiver syntaxReceiver)
            {
                return;
            }

            var options = GetGenerationOptions(context);
            var attributeSymbol = GetAsyncOverloadAttributeSymbol(context);

            // AddSource rejects a hint name that has already been used. Track the names handed out so far
            // (compared case-insensitively to be safe) so same-named files in different folders don't collide.
            var usedHintNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            foreach (var grouping in GetMethodsGroupedBySyntaxTree(context, syntaxReceiver, attributeSymbol))
            {
                context.AddSource(
                    GetUniqueHintName(grouping.SyntaxTree, usedHintNames),
                    GenerateOverloads(grouping, options, context, attributeSymbol));
            }
        }

        /// <summary>
        /// Builds the hint name <c>&lt;FileName&gt;.AsyncOverloads</c> for <paramref name="syntaxTree"/>. If that
        /// name was already used, appends the smallest suffix 2, 3, ... that is still free. The returned name is
        /// recorded in <paramref name="usedHintNames"/>.
        /// </summary>
        private static string GetUniqueHintName(SyntaxTree syntaxTree, HashSet<string> usedHintNames)
        {
            var fileName = Path.GetFileNameWithoutExtension(syntaxTree.FilePath);
            var hintName = $"{fileName}.AsyncOverloads";

            // The first file with a given name keeps the unsuffixed hint name, matching the historical output.
            for (var suffix = 2; !usedHintNames.Add(hintName); suffix++)
            {
                hintName = $"{fileName}.AsyncOverloads{suffix}";
            }

            return hintName;
        }

        /// <summary>Reads the generation options from the compilation's preprocessor symbols.</summary>
        private static GenerationOptions GetGenerationOptions(GeneratorExecutionContext context)
            => new(SupportFlatAsyncApi: context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API"));

        /// <summary>
        /// Produces the full text of the generated file for one syntax tree: that file's <c>using</c> directives
        /// followed by a <c>partial class</c> in <c>System.Linq</c> containing one overload per annotated method.
        /// </summary>
        /// <remarks>
        /// The partial class is named after the containing type of the first method in the grouping. This assumes
        /// all annotated methods in one file belong to the same type.
        /// </remarks>
        private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options, GeneratorExecutionContext context, INamedTypeSymbol attributeSymbol)
        {
            var usings = grouping.SyntaxTree.GetRoot() is CompilationUnitSyntax compilationUnit
                ? compilationUnit.Usings.ToString()
                : string.Empty;

            // This source generator gets used not just in System.Linq.Async, but also for code that has migrated from
            // System.Linq.Async to System.Interactive.Async. (E.g., we define overloads of AverageAsync that accept
            // selector callbacks. The .NET runtime library implementation offers no equivalents. We want to continue
            // to offer these even though we're deprecating System.Linq.Async, so they migrate into
            // System.Interactive.Async.) In those cases, the containing type is typically AsyncEnumerableEx,
            // but in System.Linq.Async it is AsyncEnumerable. So we need to discover the containing type name.
            var containingTypeName = grouping.Methods.FirstOrDefault()?.Symbol.ContainingType.Name ?? "AsyncEnumerable";

            // Every method in a grouping was declared in grouping.SyntaxTree, so one semantic model serves them all.
            var model = context.Compilation.GetSemanticModel(grouping.SyntaxTree);

            var overloads = new StringBuilder();
            overloads.AppendLine("#nullable enable");
            overloads.AppendLine(usings);
            overloads.AppendLine("namespace System.Linq");
            overloads.AppendLine("{");
            overloads.AppendLine($" partial class {containingTypeName}");
            overloads.AppendLine(" {");

            foreach (var method in grouping.Methods)
            {
                overloads.AppendLine(GenerateOverload(method, options, model, attributeSymbol));
            }

            overloads.AppendLine(" }");
            overloads.AppendLine("}");
            return overloads.ToString();
        }

        /// <summary>
        /// Generates the source for a single <c>public static</c> overload of <paramref name="method"/>.
        /// </summary>
        /// <remarks>
        /// The overload reuses the annotated method's return type, type parameters, parameter list and constraint
        /// clauses. It copies the method's attributes except <c>[GenerateAsyncOverload]</c>, and its expression body
        /// calls the annotated method, passing each parameter by name in declaration order.
        /// NOTE(review): ref/out/in modifiers are not forwarded on the arguments.
        /// </remarks>
        private static string GenerateOverload(AsyncMethod method, GenerationOptions options, SemanticModel model, INamedTypeSymbol attributeSymbol)
        {
            // Attribute names are bound to the attribute constructor, whose containing type is the attribute class.
            bool IsGenerateAsyncOverload(AttributeSyntax attribute)
                => SymbolEqualityComparer.Default.Equals(
                    model.GetSymbolInfo(attribute.Name).Symbol?.ContainingType,
                    attributeSymbol);

            // Rebuild each attribute list without [GenerateAsyncOverload], and drop lists that end up empty.
            // The fresh bracket tokens carry no trivia (so no stray preprocessor directives), but the target
            // specifier is preserved so that e.g. [return: ...] still applies to the return value.
            var attributeListsWithGenerateAsyncOverloadRemoved = SyntaxFactory.List(
                method.Syntax.AttributeLists
                    .Select(list => AttributeList(SeparatedList(list.Attributes.Where(a => !IsGenerateAsyncOverload(a))))
                        .WithTarget(list.Target))
                    .Where(list => list.Attributes.Count > 0));

            return MethodDeclaration(method.Syntax.ReturnType, GetMethodName(method.Symbol, options))
                .WithAttributeLists(attributeListsWithGenerateAsyncOverloadRemoved)
                .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
                .WithTypeParameterList(method.Syntax.TypeParameterList)
                .WithParameterList(method.Syntax.ParameterList)
                .WithConstraintClauses(method.Syntax.ConstraintClauses)
                .WithExpressionBody(ArrowExpressionClause(
                    InvocationExpression(
                        IdentifierName(method.Symbol.Name),
                        ArgumentList(
                            SeparatedList(
                                method.Syntax.ParameterList.Parameters
                                    .Select(p => Argument(IdentifierName(p.Identifier))))))))
                .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
                // Keep comments (e.g. XML docs) but drop directives and disabled text, which would otherwise
                // appear in the generated file without their matching #if/#endif.
                .WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => !t.IsKind(SyntaxKind.DisabledTextTrivia) && t.GetStructure() is not DirectiveTriviaSyntax))
                .NormalizeWhitespace()
                .ToFullString();
        }

        /// <summary>
        /// Looks up the <c>System.Linq.GenerateAsyncOverloadAttribute</c> type in the compilation.
        /// </summary>
        /// <exception cref="InvalidOperationException">The attribute type is not present in the compilation.</exception>
        private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
            => context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute")
               ?? throw new InvalidOperationException("The type 'System.Linq.GenerateAsyncOverloadAttribute' was not found in the compilation.");

        /// <summary>
        /// Groups the receiver's candidates by syntax tree and keeps only the methods carrying
        /// <paramref name="attributeSymbol"/>.
        /// </summary>
        /// <remarks>
        /// Trees whose candidates include no annotated method are omitted, so no empty partial class is generated
        /// for them. Each grouping's method list is materialized so that repeated enumeration does not repeat the
        /// semantic queries.
        /// </remarks>
        /// <exception cref="InvalidOperationException">A candidate has no declared symbol.</exception>
        private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver, INamedTypeSymbol attributeSymbol)
            => from candidate in syntaxReceiver.Candidates
               group candidate by candidate.SyntaxTree into grouping
               let model = context.Compilation.GetSemanticModel(grouping.Key)
               let methods = (from methodSyntax in grouping
                              let methodSymbol = model.GetDeclaredSymbol(methodSyntax) ?? throw new InvalidOperationException()
                              where methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeSymbol))
                              select new AsyncMethod(methodSymbol, methodSyntax)).ToList()
               where methods.Count > 0
               select new AsyncMethodGrouping(grouping.Key, methods);

        /// <summary>
        /// Derives the generated overload's name. Every occurrence of <c>"Core"</c> is removed (ordinal replace,
        /// anywhere in the name). With the flat async API enabled, every <c>"Await"</c> and
        /// <c>"WithCancellation"</c> is removed as well.
        /// </summary>
        private static string GetMethodName(IMethodSymbol methodSymbol, GenerationOptions options)
        {
            var methodName = methodSymbol.Name.Replace("Core", "");
            return options.SupportFlatAsyncApi
                ? methodName.Replace("Await", "").Replace("WithCancellation", "")
                : methodName;
        }
    }
}
|