AsyncOverloadsGenerator.cs 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.IO;
  6. using System.Text;
  7. using Microsoft.CodeAnalysis;
  8. using Microsoft.CodeAnalysis.CSharp;
  9. using Microsoft.CodeAnalysis.CSharp.Syntax;
  10. using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
  11. namespace System.Linq.Async.SourceGenerator
  12. {
  13. [Generator]
  14. public sealed class AsyncOverloadsGenerator : ISourceGenerator
  15. {
  16. private const string GenerateAsyncOverloadAttributeSource =
  17. "using System;\n" +
  18. "using System.Diagnostics;\n" +
  19. "namespace System.Linq\n" +
  20. "{\n" +
  21. " [AttributeUsage(AttributeTargets.Method)]\n" +
  22. " [Conditional(\"COMPILE_TIME_ONLY\")]\n" +
  23. " internal sealed class GenerateAsyncOverloadAttribute : Attribute { }\n" +
  24. "}\n";
  25. private const string DuplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttributeSource =
  26. "using System;\n" +
  27. "using System.Diagnostics;\n" +
  28. "namespace System.Linq\n" +
  29. "{\n" +
  30. " [AttributeUsage(AttributeTargets.Assembly)]\n" +
  31. " [Conditional(\"COMPILE_TIME_ONLY\")]\n" +
  32. " internal sealed class DuplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttribute : Attribute { }\n" +
  33. "}\n";
  34. public void Initialize(GeneratorInitializationContext context)
  35. {
  36. context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
  37. context.RegisterForPostInitialization(c =>
  38. {
  39. c.AddSource("GenerateAsyncOverloadAttribute", GenerateAsyncOverloadAttributeSource);
  40. c.AddSource("DuplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttribute", DuplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttributeSource);
  41. });
  42. }
  43. public void Execute(GeneratorExecutionContext context)
  44. {
  45. if (context.SyntaxReceiver is not SyntaxReceiver syntaxReceiver) return;
  46. var options = GetGenerationOptions(context);
  47. var asyncOverloadAttributeSymbol = GetAsyncOverloadAttributeSymbol(context);
  48. var methodsBySyntaxTree = GetMethodsGroupedBySyntaxTree(context, syntaxReceiver);
  49. foreach (var grouping in methodsBySyntaxTree.Where(g => g.Methods.Any()))
  50. context.AddSource(
  51. $"{Path.GetFileNameWithoutExtension(grouping.SyntaxTree.FilePath)}.AsyncOverloads",
  52. GenerateOverloads(grouping, options, context, asyncOverloadAttributeSymbol));
  53. var duplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttributeSymbol = GetDuplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttributeSymbol(context);
  54. var dupBuilder = new DeprecatedDuplicateBuilder(
  55. context,
  56. options,
  57. asyncOverloadAttributeSymbol,
  58. duplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttributeSymbol,
  59. syntaxReceiver);
  60. dupBuilder.BuildDuplicatesIfRequired();
  61. }
  62. private static GenerationOptions GetGenerationOptions(GeneratorExecutionContext context)
  63. => new(SupportFlatAsyncApi: context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API"));
  64. private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver)
  65. => GetMethodsGroupedBySyntaxTree(
  66. context,
  67. syntaxReceiver,
  68. GetAsyncOverloadAttributeSymbol(context));
  69. private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options, GeneratorExecutionContext context, INamedTypeSymbol attributeSymbol)
  70. {
  71. var usings = grouping.SyntaxTree.GetRoot() is CompilationUnitSyntax compilationUnit
  72. ? compilationUnit.Usings.ToString()
  73. : string.Empty;
  74. // This source generator gets used not just in System.Linq.Async, but also for code that has migrated from
  75. // System.Linq.Async to System.Interactive.Async. (E.g., we define overloads of AverageAsync that accept
  76. // selector callbacks. The .NET runtime library implementation offers no equivalents. We want to continue
  77. // to offer these even though we're decprecating System.Linq.Async, so they migrate into
  78. // System.Interactive.Async.) In those cases, the containing type is typically AsyncEnumerableEx,
  79. // but in System.Linq.Async it is AsyncEnumerable. So we need to discover the containing type name.
  80. var containingTypeName = grouping.Methods.FirstOrDefault()?.Symbol.ContainingType.Name ?? "AsyncEnumerable";
  81. var overloads = new StringBuilder();
  82. overloads.AppendLine("#nullable enable");
  83. overloads.AppendLine(usings);
  84. overloads.AppendLine("namespace System.Linq");
  85. overloads.AppendLine("{");
  86. overloads.AppendLine($" partial class {containingTypeName}");
  87. overloads.AppendLine(" {");
  88. foreach (var method in grouping.Methods)
  89. {
  90. var model = context.Compilation.GetSemanticModel(method.Syntax.SyntaxTree);
  91. overloads.AppendLine(GenerateOverload(method, options, model, attributeSymbol));
  92. }
  93. overloads.AppendLine(" }");
  94. overloads.AppendLine("}");
  95. return overloads.ToString();
  96. }
  97. private static string GenerateOverload(AsyncMethod method, GenerationOptions options, SemanticModel model, INamedTypeSymbol attributeSymbol)
  98. {
  99. var attributeListsWithGenerateAsyncOverloadRemoved = SyntaxFactory.List(method.Syntax.AttributeLists
  100. .Select(list => AttributeList(SeparatedList(
  101. (from a in list.Attributes
  102. let am = model.GetSymbolInfo(a.Name).Symbol?.ContainingType
  103. where !SymbolEqualityComparer.Default.Equals(am, attributeSymbol)
  104. select a))))
  105. .Where(list => list.Attributes.Count > 0));
  106. return MethodDeclaration(method.Syntax.ReturnType, GetMethodNameForGeneratedAsyncMethod(method.Symbol, options))
  107. .WithAttributeLists(attributeListsWithGenerateAsyncOverloadRemoved)
  108. .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
  109. .WithTypeParameterList(method.Syntax.TypeParameterList)
  110. .WithParameterList(method.Syntax.ParameterList)
  111. .WithConstraintClauses(method.Syntax.ConstraintClauses)
  112. .WithExpressionBody(ArrowExpressionClause(
  113. InvocationExpression(
  114. IdentifierName(method.Symbol.Name),
  115. ArgumentList(
  116. SeparatedList(
  117. method.Syntax.ParameterList.Parameters
  118. .Select(p => Argument(IdentifierName(p.Identifier))))))))
  119. .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
  120. .WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => !t.IsKind(SyntaxKind.DisabledTextTrivia) && t.GetStructure() is not DirectiveTriviaSyntax))
  121. .NormalizeWhitespace()
  122. .ToFullString();
  123. }
  124. private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
  125. => context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute") ?? throw new InvalidOperationException();
  126. private static INamedTypeSymbol GetDuplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttributeSymbol(GeneratorExecutionContext context)
  127. => context.Compilation.GetTypeByMetadataName("System.Linq.DuplicateAsyncEnumerableAsAsyncEnumerableDeprecatedAttribute") ?? throw new InvalidOperationException();
  128. private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver, INamedTypeSymbol attributeSymbol)
  129. => from candidate in syntaxReceiver.CandidateMethods
  130. group candidate by candidate.SyntaxTree into grouping
  131. let model = context.Compilation.GetSemanticModel(grouping.Key)
  132. select new AsyncMethodGrouping(
  133. grouping.Key,
  134. from methodSyntax in grouping
  135. let methodSymbol = model.GetDeclaredSymbol(methodSyntax) ?? throw new InvalidOperationException()
  136. where methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass!, attributeSymbol))
  137. select new AsyncMethod(methodSymbol, methodSyntax));
  138. internal static string GetMethodNameForGeneratedAsyncMethod(IMethodSymbol methodSymbol, GenerationOptions options)
  139. {
  140. var methodName = methodSymbol.Name.Replace("Core", "");
  141. return options.SupportFlatAsyncApi
  142. ? methodName.Replace("Await", "").Replace("WithCancellation", "")
  143. : methodName;
  144. }
  145. }
  146. }