Selaa lähdekoodia

Refactor and split source generator into multiple files

Ruben Schmidmeister 4 vuotta sitten
vanhempi
sitoutus
5261258a4b

+ 7 - 0
Ix.NET/Source/System.Linq.Async.SourceGenerator/AsyncMethod.cs

@@ -0,0 +1,7 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+
+namespace System.Linq.Async.SourceGenerator
+{
+    internal sealed record AsyncMethod(IMethodSymbol Symbol, MethodDeclarationSyntax Syntax);
+}

+ 8 - 0
Ix.NET/Source/System.Linq.Async.SourceGenerator/AsyncMethodGrouping.cs

@@ -0,0 +1,8 @@
+using System.Collections.Generic;
+
+using Microsoft.CodeAnalysis;
+
+namespace System.Linq.Async.SourceGenerator
+{
+    internal sealed record AsyncMethodGrouping(SyntaxTree SyntaxTree, IEnumerable<AsyncMethod> Methods);
+}

+ 77 - 62
Ix.NET/Source/System.Linq.Async.SourceGenerator/AsyncOverloadsGenerator.cs

@@ -25,77 +25,92 @@ namespace System.Linq.Async.SourceGenerator
         public void Initialize(GeneratorInitializationContext context)
         {
             context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
-            context.RegisterForPostInitialization(c => c.AddSource("Attribute.cs", AttributeSource));
+            context.RegisterForPostInitialization(c => c.AddSource("GenerateAsyncOverloadAttribute", AttributeSource));
         }
 
         public void Execute(GeneratorExecutionContext context)
         {
             if (context.SyntaxReceiver is not SyntaxReceiver syntaxReceiver) return;
 
-            var supportFlatAsyncApi = context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API");
-            var attributeSymbol = context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute");
-
-            foreach (var grouping in syntaxReceiver.Candidates.GroupBy(c => c.SyntaxTree))
-            {
-                var model = context.Compilation.GetSemanticModel(grouping.Key);
-                var methodsBuilder = new StringBuilder();
-
-                foreach (var candidate in grouping)
-                {
-                    var methodSymbol = model.GetDeclaredSymbol(candidate) ?? throw new NullReferenceException();
-
-                    if (!methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass!, attributeSymbol))) continue;
-
-                    var shortName = methodSymbol.Name.Replace("Core", "");
-                    if (supportFlatAsyncApi)
-                    {
-                        shortName = shortName.Replace("Await", "").Replace("WithCancellation", "");
-                    }
-
-                    var publicMethod = MethodDeclaration(candidate.ReturnType, shortName)
-                        .WithModifiers(TokenList(Token(TriviaList(), SyntaxKind.PublicKeyword, TriviaList(Space)), Token(TriviaList(), SyntaxKind.StaticKeyword, TriviaList(Space))))
-                        .WithTypeParameterList(candidate.TypeParameterList)
-                        .WithParameterList(candidate.ParameterList)
-                        .WithConstraintClauses(candidate.ConstraintClauses)
-                        .WithExpressionBody(ArrowExpressionClause(InvocationExpression(IdentifierName(methodSymbol.Name), ArgumentList(SeparatedList(candidate.ParameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier))))))))
-                        .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
-                        .WithLeadingTrivia(candidate.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax));
-
-                    methodsBuilder.AppendLine(publicMethod.ToFullString());
-                }
-
-                if (methodsBuilder.Length == 0) continue;
-
-                var usings = grouping.Key.GetRoot() is CompilationUnitSyntax compilationUnit
-                    ? compilationUnit.Usings
-                    : List<UsingDirectiveSyntax>();
-
-                var overloads = new StringBuilder();
-                overloads.AppendLine("#nullable enable");
-                overloads.AppendLine(usings.ToString());
-                overloads.AppendLine("namespace System.Linq");
-                overloads.AppendLine("{");
-                overloads.AppendLine("    partial class AsyncEnumerable");
-                overloads.AppendLine("    {");
-                overloads.AppendLine(methodsBuilder.ToString());
-                overloads.AppendLine("    }");
-                overloads.AppendLine("}");
-
-                context.AddSource($"{Path.GetFileNameWithoutExtension(grouping.Key.FilePath)}.AsyncOverloads.cs", overloads.ToString());
-            }
+            var options = GetGenerationOptions(context);
+            var methodsBySyntaxTree = GetMethodsGroupedBySyntaxTree(context, syntaxReceiver);
+
+            foreach (var grouping in methodsBySyntaxTree)
+                context.AddSource(
+                    $"{Path.GetFileNameWithoutExtension(grouping.SyntaxTree.FilePath)}.AsyncOverloads",
+                    GenerateOverloads(grouping, options));
+        }
+
+        private static GenerationOptions GetGenerationOptions(GeneratorExecutionContext context)
+            => new(SupportFlatAsyncApi: context.ParseOptions.PreprocessorSymbolNames.Contains("SUPPORT_FLAT_ASYNC_API"));
+
+        private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver)
+            => GetMethodsGroupedBySyntaxTree(
+                context,
+                syntaxReceiver,
+                GetAsyncOverloadAttributeSymbol(context));
+
+        private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options)
+        {
+            var usings = grouping.SyntaxTree.GetRoot() is CompilationUnitSyntax compilationUnit
+                ? compilationUnit.Usings.ToString()
+                : string.Empty;
+
+            var overloads = new StringBuilder();
+            overloads.AppendLine("#nullable enable");
+            overloads.AppendLine(usings);
+            overloads.AppendLine("namespace System.Linq");
+            overloads.AppendLine("{");
+            overloads.AppendLine("    partial class AsyncEnumerable");
+            overloads.AppendLine("    {");
+
+            foreach (var method in grouping.Methods)
+                overloads.AppendLine(GenerateOverload(method, options));
+
+            overloads.AppendLine("    }");
+            overloads.AppendLine("}");
+
+            return overloads.ToString();
         }
 
-        private sealed class SyntaxReceiver : ISyntaxReceiver
+        private static string GenerateOverload(AsyncMethod method, GenerationOptions options)
+            => MethodDeclaration(method.Syntax.ReturnType, GetMethodName(method.Symbol, options))
+                .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
+                .WithTypeParameterList(method.Syntax.TypeParameterList)
+                .WithParameterList(method.Syntax.ParameterList)
+                .WithConstraintClauses(method.Syntax.ConstraintClauses)
+                .WithExpressionBody(ArrowExpressionClause(
+                    InvocationExpression(
+                        IdentifierName(method.Symbol.Name),
+                        ArgumentList(
+                            SeparatedList(
+                                method.Syntax.ParameterList.Parameters
+                                    .Select(p => Argument(IdentifierName(p.Identifier))))))))
+                .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
+                .WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax))
+                .NormalizeWhitespace()
+                .ToFullString();
+
+        private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
+            => context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute") ?? throw new InvalidOperationException();
+
+        private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver, INamedTypeSymbol attributeSymbol)
+            => from candidate in syntaxReceiver.Candidates
+               group candidate by candidate.SyntaxTree into grouping
+               let model = context.Compilation.GetSemanticModel(grouping.Key)
+               select new AsyncMethodGrouping(
+                  grouping.Key,
+                  from methodSyntax in grouping
+                  let methodSymbol = model.GetDeclaredSymbol(methodSyntax) ?? throw new InvalidOperationException()
+                  where methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass!, attributeSymbol))
+                  select new AsyncMethod(methodSymbol, methodSyntax));
+
+        private static string GetMethodName(IMethodSymbol methodSymbol, GenerationOptions options)
         {
-            public IList<MethodDeclarationSyntax> Candidates { get; } = new List<MethodDeclarationSyntax>();
-
-            public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
-            {
-                if (syntaxNode is MethodDeclarationSyntax { AttributeLists: { Count: >0 } } methodDeclarationSyntax)
-                {
-                    Candidates.Add(methodDeclarationSyntax);
-                }
-            }
+            var methodName = methodSymbol.Name.Replace("Core", "");
+            return options.SupportFlatAsyncApi
+                ? methodName.Replace("Await", "").Replace("WithCancellation", "")
+                : methodName;
         }
     }
 }

+ 4 - 0
Ix.NET/Source/System.Linq.Async.SourceGenerator/GenerationOptions.cs

@@ -0,0 +1,4 @@
+namespace System.Linq.Async.SourceGenerator
+{
+    internal sealed record GenerationOptions(bool SupportFlatAsyncApi);
+}

+ 20 - 0
Ix.NET/Source/System.Linq.Async.SourceGenerator/SyntaxReceiver.cs

@@ -0,0 +1,20 @@
+using System.Collections.Generic;
+
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+
+namespace System.Linq.Async.SourceGenerator
+{
+    internal sealed class SyntaxReceiver : ISyntaxReceiver
+    {
+        public IList<MethodDeclarationSyntax> Candidates { get; } = new List<MethodDeclarationSyntax>();
+
+        public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
+        {
+            if (syntaxNode is MethodDeclarationSyntax { AttributeLists: { Count: >0 } } methodDeclarationSyntax)
+            {
+                Candidates.Add(methodDeclarationSyntax);
+            }
+        }
+    }
+}