AsyncOverloadsGenerator.cs 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. using System.Collections.Generic;
  2. using System.IO;
  3. using System.Text;
  4. using Microsoft.CodeAnalysis;
  5. using Microsoft.CodeAnalysis.CSharp;
  6. using Microsoft.CodeAnalysis.CSharp.Syntax;
  7. using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
  namespace System.Linq.Async.SourceGenerator
  {
      /// <summary>
      /// Source generator that, for every method marked with <c>[GenerateAsyncOverload]</c>, emits a
      /// <c>public static</c> method into <c>partial class System.Linq.AsyncEnumerable</c> that forwards all of
      /// its arguments to the marked method.
      /// </summary>
      /// <remarks>
      /// The generated method copies the marked method's return type, type parameters, parameter list, constraint
      /// clauses and leading comments (such as XML docs). Its name is derived from the marked method's name by
      /// <c>GetMethodName</c>.
      /// </remarks>
      [Generator]
      public sealed class AsyncOverloadsGenerator : ISourceGenerator
      {
          /// <summary>
          /// Source of the marker attribute, added to the compilation during post-initialization.
          /// The attribute class is <c>[Conditional("COMPILE_TIME_ONLY")]</c>, so its applications are only kept in
          /// metadata when that symbol is defined.
          /// </summary>
          private const string AttributeSource =
              "using System;\n" +
              "using System.Diagnostics;\n" +
              "namespace System.Linq\n" +
              "{\n" +
              " [AttributeUsage(AttributeTargets.Method)]\n" +
              " [Conditional(\"COMPILE_TIME_ONLY\")]\n" +
              " internal sealed class GenerateAsyncOverloadAttribute : Attribute { }\n" +
              "}\n";

          /// <summary>
          /// Registers the syntax receiver that collects candidate methods and injects the marker attribute source.
          /// </summary>
          /// <param name="context">The generator initialization context.</param>
          public void Initialize(GeneratorInitializationContext context)
          {
              context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
              context.RegisterForPostInitialization(c => c.AddSource("GenerateAsyncOverloadAttribute", AttributeSource));
          }

          /// <summary>
          /// Emits one generated source file per syntax tree that contains candidate methods.
          /// </summary>
          /// <param name="context">The generator execution context.</param>
          public void Execute(GeneratorExecutionContext context)
          {
              if (context.SyntaxReceiver is not SyntaxReceiver syntaxReceiver)
              {
                  return;
              }

              var options = GetGenerationOptions(context);
              var methodsBySyntaxTree = GetMethodsGroupedBySyntaxTree(context, syntaxReceiver);

              // AddSource requires unique hint names, but the hint name is derived from the file name alone.
              // Without this set, two files with the same name in different folders would make AddSource throw.
              var usedHintNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

              foreach (var grouping in methodsBySyntaxTree)
              {
                  context.AddSource(
                      GetUniqueHintName(grouping.SyntaxTree, usedHintNames),
                      GenerateOverloads(grouping, options));
              }
          }

          /// <summary>
          /// Returns <c>"{FileName}.AsyncOverloads"</c> for <paramref name="syntaxTree"/>. If that name was already
          /// used, a numeric suffix (<c>.2</c>, <c>.3</c>, ...) is appended until the name is unused.
          /// </summary>
          /// <param name="syntaxTree">The tree whose file name seeds the hint name.</param>
          /// <param name="usedHintNames">Hint names handed out so far; the returned name is added to it.</param>
          /// <returns>A hint name not previously present in <paramref name="usedHintNames"/>.</returns>
          private static string GetUniqueHintName(SyntaxTree syntaxTree, HashSet<string> usedHintNames)
          {
              var baseName = $"{Path.GetFileNameWithoutExtension(syntaxTree.FilePath)}.AsyncOverloads";
              var hintName = baseName;

              for (var suffix = 2; !usedHintNames.Add(hintName); suffix++)
              {
                  hintName = $"{baseName}.{suffix}";
              }

              return hintName;
          }

          /// <summary>
          /// Reads generation switches from the compilation's parse options.
          /// <c>SupportFlatAsyncApi</c> is set when the <c>SUPPORT_FLAT_ASYNC_API</c> preprocessor symbol is defined.
          /// </summary>
          private static GenerationOptions GetGenerationOptions(GeneratorExecutionContext context)
              => new(SupportFlatAsyncApi: context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API"));

          /// <summary>
          /// Resolves the marker attribute symbol, then groups the receiver's candidates by syntax tree.
          /// </summary>
          private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver)
              => GetMethodsGroupedBySyntaxTree(
                  context,
                  syntaxReceiver,
                  GetAsyncOverloadAttributeSymbol(context));

          /// <summary>
          /// Builds the text of one generated file. The file contains <c>#nullable enable</c>, the source file's
          /// top-level using directives, and a <c>partial class AsyncEnumerable</c> in <c>System.Linq</c> holding
          /// one overload per method in <paramref name="grouping"/>.
          /// </summary>
          /// <param name="grouping">The methods declared in a single syntax tree.</param>
          /// <param name="options">Generation switches that affect overload names.</param>
          /// <returns>The generated C# source text.</returns>
          private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options)
          {
              // Only compilation-unit-level usings are copied; usings declared inside a namespace block of the
              // source file are not carried over.
              var usings = grouping.SyntaxTree.GetRoot() is CompilationUnitSyntax compilationUnit
                  ? compilationUnit.Usings.ToString()
                  : string.Empty;

              var overloads = new StringBuilder();
              overloads.AppendLine("#nullable enable");
              overloads.AppendLine(usings);
              overloads.AppendLine("namespace System.Linq");
              overloads.AppendLine("{");
              overloads.AppendLine(" partial class AsyncEnumerable");
              overloads.AppendLine(" {");

              foreach (var method in grouping.Methods)
              {
                  overloads.AppendLine(GenerateOverload(method, options));
              }

              overloads.AppendLine(" }");
              overloads.AppendLine("}");
              return overloads.ToString();
          }

          /// <summary>
          /// Generates a <c>public static</c> expression-bodied method that forwards every parameter, by name and
          /// in order, to the marked method. The marked method's attributes are not copied; its leading comments are.
          /// </summary>
          /// <param name="method">The marked method (symbol and declaration syntax).</param>
          /// <param name="options">Generation switches that affect the overload's name.</param>
          /// <returns>The normalized source text of the generated method.</returns>
          private static string GenerateOverload(AsyncMethod method, GenerationOptions options)
              => MethodDeclaration(method.Syntax.ReturnType, GetMethodName(method.Symbol, options))
                  .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
                  .WithTypeParameterList(method.Syntax.TypeParameterList)
                  .WithParameterList(method.Syntax.ParameterList)
                  .WithConstraintClauses(method.Syntax.ConstraintClauses)
                  .WithExpressionBody(ArrowExpressionClause(
                      InvocationExpression(
                          GetInvocationTarget(method),
                          ArgumentList(
                              SeparatedList(
                                  method.Syntax.ParameterList.Parameters
                                      .Select(p => Argument(IdentifierName(p.Identifier))))))))
                  .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
                  .WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => IsCopiedTrivia(t)))
                  .NormalizeWhitespace()
                  .ToFullString();

          /// <summary>
          /// Builds the name that the generated overload invokes. For a generic method, the overload's own type
          /// parameters are passed explicitly, so the call does not rely on type-argument inference. Inference
          /// cannot infer a type parameter that appears only in the return type.
          /// </summary>
          /// <param name="method">The marked method being forwarded to.</param>
          /// <returns><c>Name</c> or <c>Name&lt;T1, ..., Tn&gt;</c>.</returns>
          private static SimpleNameSyntax GetInvocationTarget(AsyncMethod method)
          {
              var typeParameterList = method.Syntax.TypeParameterList;
              if (typeParameterList is null || typeParameterList.Parameters.Count == 0)
              {
                  return IdentifierName(method.Symbol.Name);
              }

              return GenericName(
                  Identifier(method.Symbol.Name),
                  TypeArgumentList(
                      SeparatedList<TypeSyntax>(
                          typeParameterList.Parameters.Select(p => IdentifierName(p.Identifier)))));
          }

          /// <summary>
          /// Decides whether a leading trivia of the marked method is copied onto the generated overload.
          /// Preprocessor directives are dropped. The text of inactive <c>#if</c>/<c>#else</c> branches is also
          /// dropped, because once its directives are removed it would be emitted as live code.
          /// </summary>
          /// <param name="trivia">A leading trivia of the marked method's declaration.</param>
          /// <returns><see langword="true"/> if the trivia should be copied.</returns>
          private static bool IsCopiedTrivia(SyntaxTrivia trivia)
              => trivia.GetStructure() is not DirectiveTriviaSyntax
                 && !trivia.IsKind(SyntaxKind.DisabledTextTrivia);

          /// <summary>
          /// Looks up the marker attribute type injected during post-initialization.
          /// </summary>
          /// <exception cref="InvalidOperationException">The attribute type is not present in the compilation.</exception>
          private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
              => context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute")
                 ?? throw new InvalidOperationException("System.Linq.GenerateAsyncOverloadAttribute was not found in the compilation.");

          /// <summary>
          /// Groups the receiver's candidate methods by syntax tree. Within each group, only methods whose symbol
          /// carries <paramref name="attributeSymbol"/> are kept. The per-group method sequence is lazily evaluated.
          /// </summary>
          /// <exception cref="InvalidOperationException">A candidate has no declared symbol (thrown on enumeration).</exception>
          private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver, INamedTypeSymbol attributeSymbol)
              => from candidate in syntaxReceiver.Candidates
                 group candidate by candidate.SyntaxTree into grouping
                 let model = context.Compilation.GetSemanticModel(grouping.Key)
                 select new AsyncMethodGrouping(
                     grouping.Key,
                     from methodSyntax in grouping
                     let methodSymbol = model.GetDeclaredSymbol(methodSyntax)
                         ?? throw new InvalidOperationException("Unable to get the declared symbol of a candidate method.")
                     where methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass!, attributeSymbol))
                     select new AsyncMethod(methodSymbol, methodSyntax));

          /// <summary>
          /// Derives the public overload name from the marked method's name. A trailing <c>Core</c> suffix is
          /// removed. When <c>SupportFlatAsyncApi</c> is set, every <c>Await</c> and <c>WithCancellation</c> is
          /// also removed.
          /// </summary>
          /// <param name="methodSymbol">The marked method.</param>
          /// <param name="options">Generation switches.</param>
          /// <returns>The name for the generated overload.</returns>
          private static string GetMethodName(IMethodSymbol methodSymbol, GenerationOptions options)
          {
              const string CoreSuffix = "Core";

              // Only strip "Core" as a suffix so names containing "Core" elsewhere are not mangled.
              var methodName = methodSymbol.Name;
              if (methodName.EndsWith(CoreSuffix, StringComparison.Ordinal))
              {
                  methodName = methodName.Substring(0, methodName.Length - CoreSuffix.Length);
              }

              return options.SupportFlatAsyncApi
                  ? methodName.Replace("Await", "").Replace("WithCancellation", "")
                  : methodName;
          }
      }
  }