Browse Source

Fix minimal API validation for record structs (#64514)

* Initial plan

* Fix minimal API validation for record structs by renaming IsClass to IsComplexType

Co-authored-by: captainsafia <[email protected]>

---------

Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: captainsafia <[email protected]>
Copilot 3 months ago
parent
commit
cb092d8ad2

+ 6 - 4
src/Validation/src/RuntimeValidatableParameterInfoResolver.cs

@@ -43,7 +43,7 @@ internal sealed class RuntimeValidatableParameterInfoResolver : IValidatableInfo
         // If there are no validation attributes and this type is not a complex type
         // we don't need to validate it. Complex types without attributes are still
         // validatable because we want to run the validations on the properties.
-        if (validationAttributes.Length == 0 && !IsClass(parameterInfo.ParameterType))
+        if (validationAttributes.Length == 0 && !IsComplexType(parameterInfo.ParameterType))
         {
             validatableInfo = null;
             return false;
@@ -80,7 +80,7 @@ internal sealed class RuntimeValidatableParameterInfoResolver : IValidatableInfo
         private readonly ValidationAttribute[] _validationAttributes = validationAttributes;
     }
 
-    private static bool IsClass(Type type)
+    private static bool IsComplexType(Type type)
     {
         // Skip primitives, enums, common built-in types, and types that are specially
         // handled by RDF/RDG that don't need validation if they don't have attributes
@@ -105,9 +105,11 @@ internal sealed class RuntimeValidatableParameterInfoResolver : IValidatableInfo
         // Check if the underlying type in a nullable is valid
         if (Nullable.GetUnderlyingType(type) is { } nullableType)
         {
-            return IsClass(nullableType);
+            return IsComplexType(nullableType);
         }
 
-        return type.IsClass;
+        // Complex types include both reference types (classes) and value types (structs, record structs)
+        // that aren't in the exclusion list above
+        return type.IsClass || type.IsValueType;
     }
 }

+ 142 - 0
src/Validation/test/Microsoft.Extensions.Validation.GeneratorTests/ValidationsGenerator.RecordType.cs

@@ -384,4 +384,146 @@ public record ValidatableRecord(
         });
 
     }
+
+    [Fact]
+    public async Task CanValidateRecordStructTypes()
+    {
+        // Arrange
+        var source = """
+using System;
+using System.ComponentModel.DataAnnotations;
+using System.Collections.Generic;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.Http;
+using Microsoft.Extensions.Validation;
+using Microsoft.AspNetCore.Routing;
+using Microsoft.Extensions.DependencyInjection;
+
+var builder = WebApplication.CreateBuilder();
+
+builder.Services.AddValidation();
+
+var app = builder.Build();
+
+app.MapPost("/validatable-record-struct", (ValidatableRecordStruct validatableRecordStruct) => Results.Ok("Passed"));
+
+app.Run();
+
+public record struct SubRecordStruct([Required] string RequiredProperty, [StringLength(10)] string? StringWithLength);
+
+public record struct ValidatableRecordStruct(
+    [Range(10, 100)]
+    int IntegerWithRange,
+    [Range(10, 100), Display(Name = "Valid identifier")]
+    int IntegerWithRangeAndDisplayName,
+    SubRecordStruct SubProperty
+);
+""";
+        await Verify(source, out var compilation);
+        await VerifyEndpoint(compilation, "/validatable-record-struct", async (endpoint, serviceProvider) =>
+        {
+            await InvalidIntegerWithRangeProducesError(endpoint);
+            await InvalidIntegerWithRangeAndDisplayNameProducesError(endpoint);
+            await InvalidSubPropertyProducesError(endpoint);
+            await ValidInputProducesNoWarnings(endpoint);
+
+            async Task InvalidIntegerWithRangeProducesError(Endpoint endpoint)
+            {
+                var payload = """
+                    {
+                        "IntegerWithRange": 5,
+                        "IntegerWithRangeAndDisplayName": 50,
+                        "SubProperty": {
+                            "RequiredProperty": "valid",
+                            "StringWithLength": "valid"
+                        }
+                    }
+                    """;
+                var context = CreateHttpContextWithPayload(payload, serviceProvider);
+
+                await endpoint.RequestDelegate(context);
+
+                var problemDetails = await AssertBadRequest(context);
+                Assert.Collection(problemDetails.Errors, kvp =>
+                {
+                    Assert.Equal("IntegerWithRange", kvp.Key);
+                    Assert.Equal("The field IntegerWithRange must be between 10 and 100.", kvp.Value.Single());
+                });
+            }
+
+            async Task InvalidIntegerWithRangeAndDisplayNameProducesError(Endpoint endpoint)
+            {
+                var payload = """
+                    {
+                        "IntegerWithRange": 50,
+                        "IntegerWithRangeAndDisplayName": 5,
+                        "SubProperty": {
+                            "RequiredProperty": "valid",
+                            "StringWithLength": "valid"
+                        }
+                    }
+                    """;
+                var context = CreateHttpContextWithPayload(payload, serviceProvider);
+
+                await endpoint.RequestDelegate(context);
+
+                var problemDetails = await AssertBadRequest(context);
+                Assert.Collection(problemDetails.Errors, kvp =>
+                {
+                    Assert.Equal("IntegerWithRangeAndDisplayName", kvp.Key);
+                    Assert.Equal("The field Valid identifier must be between 10 and 100.", kvp.Value.Single());
+                });
+            }
+
+            async Task InvalidSubPropertyProducesError(Endpoint endpoint)
+            {
+                var payload = """
+                    {
+                        "IntegerWithRange": 50,
+                        "IntegerWithRangeAndDisplayName": 50,
+                        "SubProperty": {
+                            "RequiredProperty": "",
+                            "StringWithLength": "way-too-long"
+                        }
+                    }
+                    """;
+                var context = CreateHttpContextWithPayload(payload, serviceProvider);
+
+                await endpoint.RequestDelegate(context);
+
+                var problemDetails = await AssertBadRequest(context);
+                Assert.Collection(problemDetails.Errors,
+                kvp =>
+                {
+                    Assert.Equal("SubProperty.RequiredProperty", kvp.Key);
+                    Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single());
+                },
+                kvp =>
+                {
+                    Assert.Equal("SubProperty.StringWithLength", kvp.Key);
+                    Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single());
+                });
+            }
+
+            async Task ValidInputProducesNoWarnings(Endpoint endpoint)
+            {
+                var payload = """
+                    {
+                        "IntegerWithRange": 50,
+                        "IntegerWithRangeAndDisplayName": 50,
+                        "SubProperty": {
+                            "RequiredProperty": "valid",
+                            "StringWithLength": "valid"
+                        }
+                    }
+                    """;
+                var context = CreateHttpContextWithPayload(payload, serviceProvider);
+                await endpoint.RequestDelegate(context);
+
+                Assert.Equal(200, context.Response.StatusCode);
+            }
+        });
+
+    }
 }

+ 222 - 0
src/Validation/test/Microsoft.Extensions.Validation.GeneratorTests/snapshots/ValidationsGeneratorTests.CanValidateRecordStructTypes#ValidatableInfoResolver.g.verified.cs

@@ -0,0 +1,222 @@
+//HintName: ValidatableInfoResolver.g.cs
+#nullable enable annotations
+//------------------------------------------------------------------------------
+// <auto-generated>
+//     This code was generated by a tool.
+//
+//     Changes to this file may cause incorrect behavior and will be lost if
+//     the code is regenerated.
+// </auto-generated>
+//------------------------------------------------------------------------------
+#nullable enable
+#pragma warning disable ASP0029
+
+namespace System.Runtime.CompilerServices
+{
+    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Validation.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")]
+    [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
+    file sealed class InterceptsLocationAttribute : System.Attribute
+    {
+        public InterceptsLocationAttribute(int version, string data)
+        {
+        }
+    }
+}
+
+namespace Microsoft.Extensions.Validation.Generated
+{
+    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Validation.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")]
+    file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.Extensions.Validation.ValidatablePropertyInfo
+    {
+        public GeneratedValidatablePropertyInfo(
+            [param: global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicProperties | global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)]
+            global::System.Type containingType,
+            global::System.Type propertyType,
+            string name,
+            string displayName) : base(containingType, propertyType, name, displayName)
+        {
+            ContainingType = containingType;
+            Name = name;
+        }
+
+        [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicProperties | global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)]
+        internal global::System.Type ContainingType { get; }
+        internal string Name { get; }
+
+        protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes()
+            => ValidationAttributeCache.GetPropertyValidationAttributes(ContainingType, Name);
+    }
+
+    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Validation.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")]
+    file sealed class GeneratedValidatableTypeInfo : global::Microsoft.Extensions.Validation.ValidatableTypeInfo
+    {
+        public GeneratedValidatableTypeInfo(
+            [param: global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.Interfaces)]
+            global::System.Type type,
+            ValidatablePropertyInfo[] members) : base(type, members)
+        {
+            Type = type;
+        }
+
+        [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.Interfaces)]
+        internal global::System.Type Type { get; }
+
+        protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes()
+            => ValidationAttributeCache.GetTypeValidationAttributes(Type);
+    }
+
+    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Validation.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")]
+    file class GeneratedValidatableInfoResolver : global::Microsoft.Extensions.Validation.IValidatableInfoResolver
+    {
+        public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.Extensions.Validation.IValidatableInfo? validatableInfo)
+        {
+            validatableInfo = null;
+            if (type == typeof(global::SubRecordStruct))
+            {
+                validatableInfo = new GeneratedValidatableTypeInfo(
+                    type: typeof(global::SubRecordStruct),
+                    members: [
+                        new GeneratedValidatablePropertyInfo(
+                            containingType: typeof(global::SubRecordStruct),
+                            propertyType: typeof(string),
+                            name: "RequiredProperty",
+                            displayName: "RequiredProperty"
+                        ),
+                        new GeneratedValidatablePropertyInfo(
+                            containingType: typeof(global::SubRecordStruct),
+                            propertyType: typeof(string),
+                            name: "StringWithLength",
+                            displayName: "StringWithLength"
+                        ),
+                    ]
+                );
+                return true;
+            }
+            if (type == typeof(global::ValidatableRecordStruct))
+            {
+                validatableInfo = new GeneratedValidatableTypeInfo(
+                    type: typeof(global::ValidatableRecordStruct),
+                    members: [
+                        new GeneratedValidatablePropertyInfo(
+                            containingType: typeof(global::ValidatableRecordStruct),
+                            propertyType: typeof(int),
+                            name: "IntegerWithRange",
+                            displayName: "IntegerWithRange"
+                        ),
+                        new GeneratedValidatablePropertyInfo(
+                            containingType: typeof(global::ValidatableRecordStruct),
+                            propertyType: typeof(int),
+                            name: "IntegerWithRangeAndDisplayName",
+                            displayName: "Valid identifier"
+                        ),
+                        new GeneratedValidatablePropertyInfo(
+                            containingType: typeof(global::ValidatableRecordStruct),
+                            propertyType: typeof(global::SubRecordStruct),
+                            name: "SubProperty",
+                            displayName: "SubProperty"
+                        ),
+                    ]
+                );
+                return true;
+            }
+
+            return false;
+        }
+
+        // No-ops, rely on runtime code for ParameterInfo-based resolution
+        public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.Extensions.Validation.IValidatableInfo? validatableInfo)
+        {
+            validatableInfo = null;
+            return false;
+        }
+    }
+
+    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Validation.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")]
+    file static class GeneratedServiceCollectionExtensions
+    {
+        [InterceptsLocation]
+        public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action<global::Microsoft.Extensions.Validation.ValidationOptions>? configureOptions = null)
+        {
+            // Use non-extension method to avoid infinite recursion.
+            return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options =>
+            {
+                options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver());
+                if (configureOptions is not null)
+                {
+                    configureOptions(options);
+                }
+            });
+        }
+    }
+
+    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Validation.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")]
+    file static class ValidationAttributeCache
+    {
+        private sealed record CacheKey(
+            [param: global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicProperties | global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)]
+            [property: global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicProperties | global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)]
+            global::System.Type ContainingType,
+            string PropertyName);
+        private static readonly global::System.Collections.Concurrent.ConcurrentDictionary<CacheKey, global::System.ComponentModel.DataAnnotations.ValidationAttribute[]> _propertyCache = new();
+        private static readonly global::System.Lazy<global::System.Collections.Concurrent.ConcurrentDictionary<global::System.Type, global::System.ComponentModel.DataAnnotations.ValidationAttribute[]>> _lazyTypeCache = new (() => new ());
+        private static global::System.Collections.Concurrent.ConcurrentDictionary<global::System.Type, global::System.ComponentModel.DataAnnotations.ValidationAttribute[]> TypeCache => _lazyTypeCache.Value;
+
+        public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetPropertyValidationAttributes(
+            [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicProperties | global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)]
+            global::System.Type containingType,
+            string propertyName)
+        {
+            var key = new CacheKey(containingType, propertyName);
+            return _propertyCache.GetOrAdd(key, static k =>
+            {
+                var results = new global::System.Collections.Generic.List<global::System.ComponentModel.DataAnnotations.ValidationAttribute>();
+
+                // Get attributes from the property
+                var property = k.ContainingType.GetProperty(k.PropertyName);
+                if (property != null)
+                {
+                    var propertyAttributes = global::System.Reflection.CustomAttributeExtensions
+                        .GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(property, inherit: true);
+
+                    results.AddRange(propertyAttributes);
+                }
+
+                // Check constructors for parameters that match the property name
+                // to handle record scenarios
+                foreach (var constructor in k.ContainingType.GetConstructors())
+                {
+                    // Look for parameter with matching name (case insensitive)
+                    var parameter = global::System.Linq.Enumerable.FirstOrDefault(
+                        constructor.GetParameters(),
+                        p => string.Equals(p.Name, k.PropertyName, global::System.StringComparison.OrdinalIgnoreCase));
+
+                    if (parameter != null)
+                    {
+                        var paramAttributes = global::System.Reflection.CustomAttributeExtensions
+                            .GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(parameter, inherit: true);
+
+                        results.AddRange(paramAttributes);
+
+                        break;
+                    }
+                }
+
+                return results.ToArray();
+            });
+        }
+
+
+        public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetTypeValidationAttributes(
+            [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.Interfaces)]
+            global::System.Type type
+        )
+        {
+            return TypeCache.GetOrAdd(type, static t =>
+            {
+                var typeAttributes = global::System.Reflection.CustomAttributeExtensions
+                        .GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(t, inherit: true);
+                return global::System.Linq.Enumerable.ToArray(typeAttributes);
+            });
+        }
+    }
+}