|
@@ -25,77 +25,92 @@ namespace System.Linq.Async.SourceGenerator
|
|
|
public void Initialize(GeneratorInitializationContext context)
|
|
public void Initialize(GeneratorInitializationContext context)
|
|
|
{
|
|
{
|
|
|
context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
|
|
context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
|
|
|
- context.RegisterForPostInitialization(c => c.AddSource("Attribute.cs", AttributeSource));
|
|
|
|
|
|
|
+ context.RegisterForPostInitialization(c => c.AddSource("GenerateAsyncOverloadAttribute", AttributeSource));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
public void Execute(GeneratorExecutionContext context)
|
|
public void Execute(GeneratorExecutionContext context)
|
|
|
{
|
|
{
|
|
|
if (context.SyntaxReceiver is not SyntaxReceiver syntaxReceiver) return;
|
|
if (context.SyntaxReceiver is not SyntaxReceiver syntaxReceiver) return;
|
|
|
|
|
|
|
|
- var supportFlatAsyncApi = context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API");
|
|
|
|
|
- var attributeSymbol = context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute");
|
|
|
|
|
-
|
|
|
|
|
- foreach (var grouping in syntaxReceiver.Candidates.GroupBy(c => c.SyntaxTree))
|
|
|
|
|
- {
|
|
|
|
|
- var model = context.Compilation.GetSemanticModel(grouping.Key);
|
|
|
|
|
- var methodsBuilder = new StringBuilder();
|
|
|
|
|
-
|
|
|
|
|
- foreach (var candidate in grouping)
|
|
|
|
|
- {
|
|
|
|
|
- var methodSymbol = model.GetDeclaredSymbol(candidate) ?? throw new NullReferenceException();
|
|
|
|
|
-
|
|
|
|
|
- if (!methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass!, attributeSymbol))) continue;
|
|
|
|
|
-
|
|
|
|
|
- var shortName = methodSymbol.Name.Replace("Core", "");
|
|
|
|
|
- if (supportFlatAsyncApi)
|
|
|
|
|
- {
|
|
|
|
|
- shortName = shortName.Replace("Await", "").Replace("WithCancellation", "");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- var publicMethod = MethodDeclaration(candidate.ReturnType, shortName)
|
|
|
|
|
- .WithModifiers(TokenList(Token(TriviaList(), SyntaxKind.PublicKeyword, TriviaList(Space)), Token(TriviaList(), SyntaxKind.StaticKeyword, TriviaList(Space))))
|
|
|
|
|
- .WithTypeParameterList(candidate.TypeParameterList)
|
|
|
|
|
- .WithParameterList(candidate.ParameterList)
|
|
|
|
|
- .WithConstraintClauses(candidate.ConstraintClauses)
|
|
|
|
|
- .WithExpressionBody(ArrowExpressionClause(InvocationExpression(IdentifierName(methodSymbol.Name), ArgumentList(SeparatedList(candidate.ParameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier))))))))
|
|
|
|
|
- .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
|
|
|
|
|
- .WithLeadingTrivia(candidate.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax));
|
|
|
|
|
-
|
|
|
|
|
- methodsBuilder.AppendLine(publicMethod.ToFullString());
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (methodsBuilder.Length == 0) continue;
|
|
|
|
|
-
|
|
|
|
|
- var usings = grouping.Key.GetRoot() is CompilationUnitSyntax compilationUnit
|
|
|
|
|
- ? compilationUnit.Usings
|
|
|
|
|
- : List<UsingDirectiveSyntax>();
|
|
|
|
|
-
|
|
|
|
|
- var overloads = new StringBuilder();
|
|
|
|
|
- overloads.AppendLine("#nullable enable");
|
|
|
|
|
- overloads.AppendLine(usings.ToString());
|
|
|
|
|
- overloads.AppendLine("namespace System.Linq");
|
|
|
|
|
- overloads.AppendLine("{");
|
|
|
|
|
- overloads.AppendLine(" partial class AsyncEnumerable");
|
|
|
|
|
- overloads.AppendLine(" {");
|
|
|
|
|
- overloads.AppendLine(methodsBuilder.ToString());
|
|
|
|
|
- overloads.AppendLine(" }");
|
|
|
|
|
- overloads.AppendLine("}");
|
|
|
|
|
-
|
|
|
|
|
- context.AddSource($"{Path.GetFileNameWithoutExtension(grouping.Key.FilePath)}.AsyncOverloads.cs", overloads.ToString());
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ var options = GetGenerationOptions(context);
|
|
|
|
|
+ var methodsBySyntaxTree = GetMethodsGroupedBySyntaxTree(context, syntaxReceiver);
|
|
|
|
|
+
|
|
|
|
|
+ foreach (var grouping in methodsBySyntaxTree)
|
|
|
|
|
+ context.AddSource(
|
|
|
|
|
+ $"{Path.GetFileNameWithoutExtension(grouping.SyntaxTree.FilePath)}.AsyncOverloads",
|
|
|
|
|
+ GenerateOverloads(grouping, options));
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private static GenerationOptions GetGenerationOptions(GeneratorExecutionContext context)
|
|
|
|
|
+ => new(SupportFlatAsyncApi: context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API"));
|
|
|
|
|
+
|
|
|
|
|
+ private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver)
|
|
|
|
|
+ => GetMethodsGroupedBySyntaxTree(
|
|
|
|
|
+ context,
|
|
|
|
|
+ syntaxReceiver,
|
|
|
|
|
+ GetAsyncOverloadAttributeSymbol(context));
|
|
|
|
|
+
|
|
|
|
|
+ private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options)
|
|
|
|
|
+ {
|
|
|
|
|
+ var usings = grouping.SyntaxTree.GetRoot() is CompilationUnitSyntax compilationUnit
|
|
|
|
|
+ ? compilationUnit.Usings.ToString()
|
|
|
|
|
+ : string.Empty;
|
|
|
|
|
+
|
|
|
|
|
+ var overloads = new StringBuilder();
|
|
|
|
|
+ overloads.AppendLine("#nullable enable");
|
|
|
|
|
+ overloads.AppendLine(usings);
|
|
|
|
|
+ overloads.AppendLine("namespace System.Linq");
|
|
|
|
|
+ overloads.AppendLine("{");
|
|
|
|
|
+ overloads.AppendLine(" partial class AsyncEnumerable");
|
|
|
|
|
+ overloads.AppendLine(" {");
|
|
|
|
|
+
|
|
|
|
|
+ foreach (var method in grouping.Methods)
|
|
|
|
|
+ overloads.AppendLine(GenerateOverload(method, options));
|
|
|
|
|
+
|
|
|
|
|
+ overloads.AppendLine(" }");
|
|
|
|
|
+ overloads.AppendLine("}");
|
|
|
|
|
+
|
|
|
|
|
+ return overloads.ToString();
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- private sealed class SyntaxReceiver : ISyntaxReceiver
|
|
|
|
|
|
|
+ private static string GenerateOverload(AsyncMethod method, GenerationOptions options)
|
|
|
|
|
+ => MethodDeclaration(method.Syntax.ReturnType, GetMethodName(method.Symbol, options))
|
|
|
|
|
+ .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
|
|
|
|
|
+ .WithTypeParameterList(method.Syntax.TypeParameterList)
|
|
|
|
|
+ .WithParameterList(method.Syntax.ParameterList)
|
|
|
|
|
+ .WithConstraintClauses(method.Syntax.ConstraintClauses)
|
|
|
|
|
+ .WithExpressionBody(ArrowExpressionClause(
|
|
|
|
|
+ InvocationExpression(
|
|
|
|
|
+ IdentifierName(method.Symbol.Name),
|
|
|
|
|
+ ArgumentList(
|
|
|
|
|
+ SeparatedList(
|
|
|
|
|
+ method.Syntax.ParameterList.Parameters
|
|
|
|
|
+ .Select(p => Argument(IdentifierName(p.Identifier))))))))
|
|
|
|
|
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
|
|
|
|
|
+ .WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax))
|
|
|
|
|
+ .NormalizeWhitespace()
|
|
|
|
|
+ .ToFullString();
|
|
|
|
|
+
|
|
|
|
|
+ private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
|
|
|
|
|
+ => context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute") ?? throw new InvalidOperationException();
|
|
|
|
|
+
|
|
|
|
|
+ private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver, INamedTypeSymbol attributeSymbol)
|
|
|
|
|
+ => from candidate in syntaxReceiver.Candidates
|
|
|
|
|
+ group candidate by candidate.SyntaxTree into grouping
|
|
|
|
|
+ let model = context.Compilation.GetSemanticModel(grouping.Key)
|
|
|
|
|
+ select new AsyncMethodGrouping(
|
|
|
|
|
+ grouping.Key,
|
|
|
|
|
+ from methodSyntax in grouping
|
|
|
|
|
+ let methodSymbol = model.GetDeclaredSymbol(methodSyntax) ?? throw new InvalidOperationException()
|
|
|
|
|
+ where methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass!, attributeSymbol))
|
|
|
|
|
+ select new AsyncMethod(methodSymbol, methodSyntax));
|
|
|
|
|
+
|
|
|
|
|
+ private static string GetMethodName(IMethodSymbol methodSymbol, GenerationOptions options)
|
|
|
{
|
|
{
|
|
|
- public IList<MethodDeclarationSyntax> Candidates { get; } = new List<MethodDeclarationSyntax>();
|
|
|
|
|
-
|
|
|
|
|
- public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
|
|
|
|
|
- {
|
|
|
|
|
- if (syntaxNode is MethodDeclarationSyntax { AttributeLists: { Count: >0 } } methodDeclarationSyntax)
|
|
|
|
|
- {
|
|
|
|
|
- Candidates.Add(methodDeclarationSyntax);
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ var methodName = methodSymbol.Name.Replace("Core", "");
|
|
|
|
|
+ return options.SupportFlatAsyncApi
|
|
|
|
|
+ ? methodName.Replace("Await", "").Replace("WithCancellation", "")
|
|
|
|
|
+ : methodName;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|