Browse Source

Add parameterless RequireCors overload (#47341)

Safia Abdalla 3 years ago
parent
commit
414e3eb9c5

+ 18 - 9
src/Middleware/CORS/src/Infrastructure/CorsEndpointConventionBuilderExtensions.cs

@@ -11,6 +11,22 @@ namespace Microsoft.AspNetCore.Builder;
 /// </summary>
 public static class CorsEndpointConventionBuilderExtensions
 {
+    /// <summary>
+    /// Adds a CORS policy with the default policy name to the endpoint(s).
+    /// </summary>
+    /// <param name="builder">The endpoint convention builder.</param>
+    /// <returns>The original convention builder parameter.</returns>
+    public static TBuilder RequireCors<TBuilder>(this TBuilder builder) where TBuilder : IEndpointConventionBuilder
+    {
+        ArgumentNullException.ThrowIfNull(builder);
+
+        builder.Add(endpointBuilder =>
+        {
+            endpointBuilder.Metadata.Add(new EnableCorsAttribute());
+        });
+        return builder;
+    }
+
     /// <summary>
     /// Adds a CORS policy with the specified name to the endpoint(s).
     /// </summary>
@@ -19,10 +35,7 @@ public static class CorsEndpointConventionBuilderExtensions
     /// <returns>The original convention builder parameter.</returns>
     public static TBuilder RequireCors<TBuilder>(this TBuilder builder, string policyName) where TBuilder : IEndpointConventionBuilder
     {
-        if (builder == null)
-        {
-            throw new ArgumentNullException(nameof(builder));
-        }
+        ArgumentNullException.ThrowIfNull(builder);
 
         builder.Add(endpointBuilder =>
         {
@@ -39,11 +52,7 @@ public static class CorsEndpointConventionBuilderExtensions
     /// <returns>The original convention builder parameter.</returns>
     public static TBuilder RequireCors<TBuilder>(this TBuilder builder, Action<CorsPolicyBuilder> configurePolicy) where TBuilder : IEndpointConventionBuilder
     {
-        if (builder == null)
-        {
-            throw new ArgumentNullException(nameof(builder));
-        }
-
+        ArgumentNullException.ThrowIfNull(builder);
         ArgumentNullException.ThrowIfNull(configurePolicy);
 
         var policyBuilder = new CorsPolicyBuilder();

+ 1 - 0
src/Middleware/CORS/src/PublicAPI.Unshipped.txt

@@ -1 +1,2 @@
 #nullable enable
+static Microsoft.AspNetCore.Builder.CorsEndpointConventionBuilderExtensions.RequireCors<TBuilder>(this TBuilder builder) -> TBuilder

+ 21 - 0
src/Middleware/CORS/test/UnitTests/CorsEndpointConventionBuilderExtensionsTests.cs

@@ -51,6 +51,27 @@ public class CorsEndpointConventionBuilderExtensionsTests
         Assert.True(metadata.Policy.AllowAnyOrigin);
     }
 
+    [Fact]
+    public void RequireCors_NoParameter_MetadataAdded()
+    {
+        // Arrange
+        var testConventionBuilder = new TestEndpointConventionBuilder();
+
+        // Act
+        testConventionBuilder.RequireCors();
+
+        // Assert
+        var addCorsPolicy = Assert.Single(testConventionBuilder.Conventions);
+
+        var endpointModel = new TestEndpointBuilder();
+        addCorsPolicy(endpointModel);
+        var endpoint = endpointModel.Build();
+
+        var metadata = endpoint.Metadata.GetMetadata<IEnableCorsAttribute>();
+        Assert.NotNull(metadata);
+        Assert.Null(metadata.PolicyName);
+    }
+
     [Fact]
     public void RequireCors_ChainedCall_ReturnedBuilderIsDerivedType()
     {