AsyncOverloadsGenerator.cs 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. using System.Collections.Generic;
  2. using System.IO;
  3. using System.Text;
  4. using Microsoft.CodeAnalysis;
  5. using Microsoft.CodeAnalysis.CSharp;
  6. using Microsoft.CodeAnalysis.CSharp.Syntax;
  7. using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
  namespace System.Linq.Async.SourceGenerator
  {
      /// <summary>
      /// Source generator that, for every method marked with <c>[GenerateAsyncOverload]</c>, emits a
      /// <c>public static</c> method with the same return type, type parameters, parameters and
      /// constraints, whose body simply forwards all arguments to the marked method.
      /// </summary>
      /// <remarks>
      /// This type lives inside <c>System.Linq.Async.SourceGenerator</c>; because <c>System.Linq</c> is an
      /// enclosing namespace, LINQ extension methods (<c>Select</c>, <c>Any</c>, <c>ToList</c>, ...) are in
      /// scope here without a <c>using System.Linq;</c> directive.
      /// </remarks>
      [Generator]
      public sealed class AsyncOverloadsGenerator : ISourceGenerator
      {
          // Marker attribute added to every compilation during post-initialization. Because the attribute
          // class is [Conditional("COMPILE_TIME_ONLY")], applications of it are not emitted into the compiled
          // assembly unless that symbol is defined.
          private const string AttributeSource =
              "using System;\n" +
              "using System.Diagnostics;\n" +
              "namespace System.Linq\n" +
              "{\n" +
              "    [AttributeUsage(AttributeTargets.Method)]\n" +
              "    [Conditional(\"COMPILE_TIME_ONLY\")]\n" +
              "    internal sealed class GenerateAsyncOverloadAttribute : Attribute { }\n" +
              "}\n";

          /// <summary>
          /// Registers the syntax receiver that collects candidate methods and injects the marker attribute source.
          /// </summary>
          public void Initialize(GeneratorInitializationContext context)
          {
              context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
              context.RegisterForPostInitialization(c => c.AddSource("GenerateAsyncOverloadAttribute", AttributeSource));
          }

          /// <summary>
          /// Emits one "&lt;FileName&gt;.AsyncOverloads" source per syntax tree that contains at least one
          /// method carrying <c>[GenerateAsyncOverload]</c>.
          /// </summary>
          public void Execute(GeneratorExecutionContext context)
          {
              if (context.SyntaxReceiver is not SyntaxReceiver syntaxReceiver)
              {
                  return;
              }

              var options = GetGenerationOptions(context);
              var attributeSymbol = GetAsyncOverloadAttributeSymbol(context);

              // Roslyn requires hint names to be unique per generator, but source files in different folders can
              // share a file name. Compare case-insensitively (the conservative choice) and append a numeric
              // suffix only on a collision, so non-colliding hint names are unchanged.
              var usedHintNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

              foreach (var grouping in GetMethodsGroupedBySyntaxTree(context, syntaxReceiver, attributeSymbol))
              {
                  context.CancellationToken.ThrowIfCancellationRequested();

                  var baseHintName = $"{Path.GetFileNameWithoutExtension(grouping.SyntaxTree.FilePath)}.AsyncOverloads";
                  var hintName = baseHintName;
                  for (var suffix = 2; !usedHintNames.Add(hintName); suffix++)
                  {
                      hintName = $"{baseHintName}{suffix}";
                  }

                  context.AddSource(hintName, GenerateOverloads(grouping, options, context, attributeSymbol));
              }
          }

          /// <summary>
          /// Reads generation switches from the compilation's preprocessor symbols.
          /// </summary>
          private static GenerationOptions GetGenerationOptions(GeneratorExecutionContext context)
              => new(SupportFlatAsyncApi: context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API"));

          /// <summary>
          /// Builds the complete generated source for one syntax tree: the tree's own using directives
          /// followed by a <c>partial class</c> in <c>System.Linq</c> holding one forwarding overload per method.
          /// </summary>
          private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options, GeneratorExecutionContext context, INamedTypeSymbol attributeSymbol)
          {
              var usings = grouping.SyntaxTree.GetRoot(context.CancellationToken) is CompilationUnitSyntax compilationUnit
                  ? compilationUnit.Usings.ToString()
                  : string.Empty;

              // This source generator gets used not just in System.Linq.Async, but also for code that has migrated from
              // System.Linq.Async to System.Interactive.Async. (E.g., we define overloads of AverageAsync that accept
              // selector callbacks. The .NET runtime library implementation offers no equivalents. We want to continue
              // to offer these even though we're deprecating System.Linq.Async, so they migrate into
              // System.Interactive.Async.) In those cases, the containing type is typically AsyncEnumerableEx,
              // but in System.Linq.Async it is AsyncEnumerable. So we need to discover the containing type name.
              var containingTypeName = grouping.Methods.FirstOrDefault()?.Symbol.ContainingType.Name ?? "AsyncEnumerable";

              // Every method in a grouping was declared in grouping.SyntaxTree, so one semantic model serves them all.
              var model = context.Compilation.GetSemanticModel(grouping.SyntaxTree);

              var overloads = new StringBuilder();
              overloads.AppendLine("#nullable enable");
              overloads.AppendLine(usings);
              overloads.AppendLine("namespace System.Linq");
              overloads.AppendLine("{");
              overloads.AppendLine($"    partial class {containingTypeName}");
              overloads.AppendLine("    {");

              foreach (var method in grouping.Methods)
              {
                  context.CancellationToken.ThrowIfCancellationRequested();
                  overloads.AppendLine(GenerateOverload(method, options, model, attributeSymbol));
              }

              overloads.AppendLine("    }");
              overloads.AppendLine("}");
              return overloads.ToString();
          }

          /// <summary>
          /// Renders a single <c>public static</c> overload that forwards every parameter, by name and in order,
          /// to <paramref name="method"/>.
          /// </summary>
          /// <param name="method">The marked method to wrap.</param>
          /// <param name="options">Controls how the overload's name is derived (see <see cref="GetMethodName"/>).</param>
          /// <param name="model">Semantic model for the method's syntax tree, used to resolve attribute types.</param>
          /// <param name="attributeSymbol">The <c>GenerateAsyncOverloadAttribute</c> symbol, which is not copied.</param>
          private static string GenerateOverload(AsyncMethod method, GenerationOptions options, SemanticModel model, INamedTypeSymbol attributeSymbol)
          {
              // Copy the method's attribute lists minus [GenerateAsyncOverload], dropping lists left empty.
              // The list's target specifier (e.g. "return:") is preserved so such attributes keep applying to
              // the same target instead of silently becoming method attributes.
              var attributeListsWithGenerateAsyncOverloadRemoved = SyntaxFactory.List(
                  from list in method.Syntax.AttributeLists
                  let remaining = SeparatedList(
                      from a in list.Attributes
                      let attributeType = model.GetSymbolInfo(a.Name).Symbol?.ContainingType
                      where !SymbolEqualityComparer.Default.Equals(attributeType, attributeSymbol)
                      select a)
                  where remaining.Count > 0
                  select AttributeList(list.Target, remaining));

              var forwardedArguments = SeparatedList(
                  method.Syntax.ParameterList.Parameters
                      .Select(p => Argument(IdentifierName(p.Identifier))));

              return MethodDeclaration(method.Syntax.ReturnType, GetMethodName(method.Symbol, options))
                  .WithAttributeLists(attributeListsWithGenerateAsyncOverloadRemoved)
                  .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
                  .WithTypeParameterList(method.Syntax.TypeParameterList)
                  .WithParameterList(method.Syntax.ParameterList)
                  .WithConstraintClauses(method.Syntax.ConstraintClauses)
                  .WithExpressionBody(ArrowExpressionClause(
                      InvocationExpression(
                          IdentifierName(method.Symbol.Name),
                          ArgumentList(forwardedArguments))))
                  .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
                  // Keep the original leading trivia (e.g. doc comments) but drop preprocessor directives and
                  // the disabled text between them.
                  .WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => !t.IsKind(SyntaxKind.DisabledTextTrivia) && t.GetStructure() is not DirectiveTriviaSyntax))
                  .NormalizeWhitespace()
                  .ToFullString();
          }

          /// <summary>
          /// Resolves the marker attribute type that post-initialization added to the compilation.
          /// </summary>
          /// <exception cref="InvalidOperationException">The attribute type cannot be found.</exception>
          private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
              => context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute")
                  ?? throw new InvalidOperationException("The type 'System.Linq.GenerateAsyncOverloadAttribute' was not found in the compilation.");

          /// <summary>
          /// Groups the receiver's candidate methods by syntax tree, keeping only methods that actually carry
          /// <paramref name="attributeSymbol"/>. Each group's methods are materialized once, and trees with no
          /// such method are omitted so no empty partial class is generated for them.
          /// </summary>
          private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver, INamedTypeSymbol attributeSymbol)
              => from candidate in syntaxReceiver.Candidates
                 group candidate by candidate.SyntaxTree into grouping
                 let model = context.Compilation.GetSemanticModel(grouping.Key)
                 let methods = (from methodSyntax in grouping
                                let methodSymbol = model.GetDeclaredSymbol(methodSyntax, context.CancellationToken) ?? throw new InvalidOperationException()
                                where methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeSymbol))
                                select new AsyncMethod(methodSymbol, methodSyntax)).ToList()
                 where methods.Count > 0
                 select new AsyncMethodGrouping(grouping.Key, methods);

          /// <summary>
          /// Derives the public overload's name from the marked method's name: every occurrence of "Core" is
          /// removed, and when the flat async API is enabled every "Await" and "WithCancellation" is removed too.
          /// </summary>
          private static string GetMethodName(IMethodSymbol methodSymbol, GenerationOptions options)
          {
              var methodName = methodSymbol.Name.Replace("Core", "");
              return options.SupportFlatAsyncApi
                  ? methodName.Replace("Await", "").Replace("WithCancellation", "")
                  : methodName;
          }
      }
  }