Browse Source

[Blazor] Add IPersistentComponentStateSerializer<T> interface for custom serialization extensibility (#62559)

* Implements support for custom serialization on the declarative persistent component model.
* Developers can register an instance of `builder.Services.AddSingleton<PersistentComponentStateSerializer<TData>, CustomSerializer>();` to handle serialization of persistent component state properties of that type.
Copilot 8 months ago
parent
commit
8883b98afe

+ 12 - 0
activate.sh

@@ -0,0 +1,12 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+
+namespace Microsoft.AspNetCore.Components;
+
+internal interface IPersistentComponentStateSerializer
+{
+    void Persist(Type type, object value, IBufferWriter<byte> writer);
+    object Restore(Type type, ReadOnlySequence<byte> data);
+}

+ 1 - 0
src/Components/Components/src/Microsoft.AspNetCore.Components.csproj

@@ -19,6 +19,7 @@
     <Compile Include="$(ComponentsSharedSourceRoot)src\HotReloadManager.cs" LinkBase="HotReload" />
     <Compile Include="$(ComponentsSharedSourceRoot)src\RootTypeCache.cs" LinkBase="Shared" />
     <Compile Include="$(SharedSourceRoot)LinkerFlags.cs" LinkBase="Shared" />
+    <Compile Include="$(SharedSourceRoot)PooledArrayBufferWriter.cs" LinkBase="Shared" />
     <Compile Include="$(SharedSourceRoot)QueryStringEnumerable.cs" LinkBase="Shared" />
     <Compile Include="$(SharedSourceRoot)Debugger\DictionaryItemDebugView.cs" LinkBase="Shared" />
     <Compile Include="$(SharedSourceRoot)Debugger\DictionaryDebugView.cs" LinkBase="Shared" />

+ 35 - 0
src/Components/Components/src/PersistentComponentState.cs

@@ -110,6 +110,28 @@ public class PersistentComponentState
         _currentState.Add(key, JsonSerializer.SerializeToUtf8Bytes(instance, type, JsonSerializerOptionsProvider.Options));
     }
 
+    /// <summary>
+    /// Persists the provided byte array under the given key.
+    /// </summary>
+    /// <param name="key">The key to use to persist the state.</param>
+    /// <param name="data">The byte array to persist.</param>
+    internal void PersistAsBytes(string key, byte[] data)
+    {
+        ArgumentNullException.ThrowIfNull(key);
+
+        if (!PersistingState)
+        {
+            throw new InvalidOperationException("Persisting state is only allowed during an OnPersisting callback.");
+        }
+
+        if (_currentState.ContainsKey(key))
+        {
+            throw new ArgumentException($"There is already a persisted object under the same key '{key}'");
+        }
+
+        _currentState.Add(key, data);
+    }
+
     /// <summary>
     /// Tries to retrieve the persisted state as JSON with the given <paramref name="key"/> and deserializes it into an
     /// instance of type <typeparamref name="TValue"/>.
@@ -155,6 +177,19 @@ public class PersistentComponentState
         }
     }
 
+    /// <summary>
+    /// Tries to retrieve the persisted state as raw bytes with the given <paramref name="key"/>.
+    /// When the key is present, the raw bytes are successfully returned via <paramref name="data"/>
+    /// and removed from the <see cref="PersistentComponentState"/>.
+    /// </summary>
+    /// <param name="key">The key used to persist the data.</param>
+    /// <param name="data">The persisted raw bytes.</param>
+    /// <returns><c>true</c> if the state was found; <c>false</c> otherwise.</returns>
+    internal bool TryTakeBytes(string key, [MaybeNullWhen(false)] out byte[]? data)
+    {
+        return TryTake(key, out data);
+    }
+
     private bool TryTake(string key, out byte[]? value)
     {
         ArgumentNullException.ThrowIfNull(key);

+ 40 - 0
src/Components/Components/src/PersistentComponentStateSerializer.cs

@@ -0,0 +1,40 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+
+namespace Microsoft.AspNetCore.Components;
+
+/// <summary>
+/// Provides custom serialization logic for persistent component state values of type <typeparamref name="T"/>.
+/// </summary>
+/// <typeparam name="T">The type of the value to serialize.</typeparam>
+public abstract class PersistentComponentStateSerializer<T> : IPersistentComponentStateSerializer
+{
+    /// <summary>
+    /// Serializes the provided <paramref name="value"/> and writes it to the <paramref name="writer"/>.
+    /// </summary>
+    /// <param name="value">The value to serialize.</param>
+    /// <param name="writer">The buffer writer to write the serialized data to.</param>
+    public abstract void Persist(T value, IBufferWriter<byte> writer);
+
+    /// <summary>
+    /// Deserializes a value of type <typeparamref name="T"/> from the provided <paramref name="data"/>.
+    /// This method must be synchronous to avoid UI tearing during component state restoration.
+    /// </summary>
+    /// <param name="data">The serialized data to deserialize.</param>
+    /// <returns>The deserialized value.</returns>
+    public abstract T Restore(ReadOnlySequence<byte> data);
+
+    /// <summary>
+    /// Explicit interface implementation for non-generic serialization.
+    /// </summary>
+    void IPersistentComponentStateSerializer.Persist(Type type, object value, IBufferWriter<byte> writer)
+        => Persist((T)value, writer);
+
+    /// <summary>
+    /// Explicit interface implementation for non-generic deserialization.
+    /// </summary>
+    object IPersistentComponentStateSerializer.Restore(Type type, ReadOnlySequence<byte> data)
+        => Restore(data)!;
+}

+ 84 - 1
src/Components/Components/src/PersistentStateValueProvider.cs

@@ -15,10 +15,11 @@ using Microsoft.AspNetCore.Internal;
 
 namespace Microsoft.AspNetCore.Components.Infrastructure;
 
-internal sealed class PersistentStateValueProvider(PersistentComponentState state) : ICascadingValueSupplier
+internal sealed class PersistentStateValueProvider(PersistentComponentState state, IServiceProvider serviceProvider) : ICascadingValueSupplier
 {
     private static readonly ConcurrentDictionary<(string, string, string), byte[]> _keyCache = new();
     private static readonly ConcurrentDictionary<(Type, string), PropertyGetter> _propertyGetterCache = new();
+    private readonly ConcurrentDictionary<Type, IPersistentComponentStateSerializer?> _serializerCache = new();
 
     private readonly Dictionary<ComponentState, PersistingComponentStateSubscription> _subscriptions = [];
 
@@ -42,6 +43,20 @@ internal sealed class PersistentStateValueProvider(PersistentComponentState stat
         var componentState = (ComponentState)key!;
         var storageKey = ComputeKey(componentState, parameterInfo.PropertyName);
 
+        // Try to get a custom serializer for this type first
+        var customSerializer = _serializerCache.GetOrAdd(parameterInfo.PropertyType, SerializerFactory);
+        
+        if (customSerializer != null)
+        {
+            if (state.TryTakeBytes(storageKey, out var data))
+            {
+                var sequence = new ReadOnlySequence<byte>(data!);
+                return customSerializer.Restore(parameterInfo.PropertyType, sequence);
+            }
+            return null;
+        }
+
+        // Fallback to JSON serialization
         return state.TryTakeFromJson(storageKey, parameterInfo.PropertyType, out var value) ? value : null;
     }
 
@@ -52,6 +67,10 @@ internal sealed class PersistentStateValueProvider(PersistentComponentState stat
     {
         var propertyName = parameterInfo.PropertyName;
         var propertyType = parameterInfo.PropertyType;
+        
+        // Resolve serializer outside the lambda
+        var customSerializer = _serializerCache.GetOrAdd(propertyType, SerializerFactory);
+        
         _subscriptions[subscriber] = state.RegisterOnPersisting(() =>
             {
                 var storageKey = ComputeKey(subscriber, propertyName);
@@ -61,6 +80,16 @@ internal sealed class PersistentStateValueProvider(PersistentComponentState stat
                 {
                     return Task.CompletedTask;
                 }
+
+                if (customSerializer != null)
+                {
+                    using var writer = new PooledArrayBufferWriter<byte>();
+                    customSerializer.Persist(propertyType, property, writer);
+                    state.PersistAsBytes(storageKey, writer.WrittenMemory.ToArray());
+                    return Task.CompletedTask;
+                }
+
+                // Fallback to JSON serialization
                 state.PersistAsJson(storageKey, property, propertyType);
                 return Task.CompletedTask;
             }, subscriber.Renderer.GetComponentRenderMode(subscriber.Component));
@@ -71,6 +100,15 @@ internal sealed class PersistentStateValueProvider(PersistentComponentState stat
         return _propertyGetterCache.GetOrAdd((type, propertyName), PropertyGetterFactory);
     }
 
+    private IPersistentComponentStateSerializer? SerializerFactory(Type type)
+    {
+        var serializerType = typeof(PersistentComponentStateSerializer<>).MakeGenericType(type);
+        var serializer = serviceProvider.GetService(serializerType);
+        
+        // The generic class now inherits from the internal interface, so we can cast directly
+        return serializer as IPersistentComponentStateSerializer;
+    }
+
     [UnconditionalSuppressMessage(
     "Trimming",
     "IL2077:Target parameter argument does not satisfy 'DynamicallyAccessedMembersAttribute' in call to target method. The source field does not have matching annotations.",
@@ -281,4 +319,49 @@ internal sealed class PersistentStateValueProvider(PersistentComponentState stat
 
         return result;
     }
+
+    /// <summary>
+    /// Serializes <paramref name="instance"/> using the provided <paramref name="serializer"/> and persists it under the given <paramref name="key"/>.
+    /// </summary>
+    /// <typeparam name="TValue">The <paramref name="instance"/> type.</typeparam>
+    /// <param name="key">The key to use to persist the state.</param>
+    /// <param name="instance">The instance to persist.</param>
+    /// <param name="serializer">The custom serializer to use for serialization.</param>
+    internal void PersistAsync<TValue>(string key, TValue instance, PersistentComponentStateSerializer<TValue> serializer)
+    {
+        ArgumentNullException.ThrowIfNull(key);
+        ArgumentNullException.ThrowIfNull(serializer);
+
+        using var writer = new PooledArrayBufferWriter<byte>();
+        serializer.Persist(instance, writer);
+        state.PersistAsBytes(key, writer.WrittenMemory.ToArray());
+    }
+
+    /// <summary>
+    /// Tries to retrieve the persisted state with the given <paramref name="key"/> and deserializes it using the provided <paramref name="serializer"/> into an
+    /// instance of type <typeparamref name="TValue"/>.
+    /// When the key is present, the state is successfully returned via <paramref name="instance"/>
+    /// and removed from the <see cref="PersistentComponentState"/>.
+    /// </summary>
+    /// <param name="key">The key used to persist the instance.</param>
+    /// <param name="serializer">The custom serializer to use for deserialization.</param>
+    /// <param name="instance">The persisted instance.</param>
+    /// <returns><c>true</c> if the state was found; <c>false</c> otherwise.</returns>
+    internal bool TryTake<TValue>(string key, PersistentComponentStateSerializer<TValue> serializer, [MaybeNullWhen(false)] out TValue instance)
+    {
+        ArgumentNullException.ThrowIfNull(key);
+        ArgumentNullException.ThrowIfNull(serializer);
+
+        if (state.TryTakeBytes(key, out var data))
+        {
+            var sequence = new ReadOnlySequence<byte>(data!);
+            instance = serializer.Restore(sequence);
+            return true;
+        }
+        else
+        {
+            instance = default;
+            return false;
+        }
+    }
 }

+ 4 - 0
src/Components/Components/src/PublicAPI.Unshipped.txt

@@ -16,6 +16,10 @@ Microsoft.AspNetCore.Components.Infrastructure.RegisterPersistentComponentStateS
 Microsoft.AspNetCore.Components.PersistentStateAttribute
 Microsoft.AspNetCore.Components.PersistentStateAttribute.PersistentStateAttribute() -> void
 Microsoft.AspNetCore.Components.Infrastructure.PersistentStateProviderServiceCollectionExtensions
+Microsoft.AspNetCore.Components.PersistentComponentStateSerializer<T>
+Microsoft.AspNetCore.Components.PersistentComponentStateSerializer<T>.PersistentComponentStateSerializer() -> void
+abstract Microsoft.AspNetCore.Components.PersistentComponentStateSerializer<T>.Persist(T value, System.Buffers.IBufferWriter<byte>! writer) -> void
+abstract Microsoft.AspNetCore.Components.PersistentComponentStateSerializer<T>.Restore(System.Buffers.ReadOnlySequence<byte> data) -> T
 static Microsoft.AspNetCore.Components.Infrastructure.RegisterPersistentComponentStateServiceCollectionExtensions.AddPersistentServiceRegistration<TService>(Microsoft.Extensions.DependencyInjection.IServiceCollection! services, Microsoft.AspNetCore.Components.IComponentRenderMode! componentRenderMode) -> Microsoft.Extensions.DependencyInjection.IServiceCollection!
 static Microsoft.AspNetCore.Components.Infrastructure.ComponentsMetricsServiceCollectionExtensions.AddComponentsMetrics(Microsoft.Extensions.DependencyInjection.IServiceCollection! services) -> Microsoft.Extensions.DependencyInjection.IServiceCollection!
 static Microsoft.AspNetCore.Components.Infrastructure.ComponentsMetricsServiceCollectionExtensions.AddComponentsTracing(Microsoft.Extensions.DependencyInjection.IServiceCollection! services) -> Microsoft.Extensions.DependencyInjection.IServiceCollection!

+ 79 - 0
src/Components/Components/test/IPersistentComponentStateSerializerTests.cs

@@ -0,0 +1,79 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Text;
+using System.Text.Json;
+using Microsoft.AspNetCore.Components.Infrastructure;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Microsoft.AspNetCore.Components;
+
+public class IPersistentComponentStateSerializerTests
+{
+    [Fact]
+    public void PersistAsync_CanUseCustomSerializer()
+    {
+        // Arrange
+        var currentState = new Dictionary<string, byte[]>();
+        var state = new PersistentComponentState(currentState, []);
+        var serviceProvider = new ServiceCollection().BuildServiceProvider();
+        var stateValueProvider = new PersistentStateValueProvider(state, serviceProvider);
+        var customSerializer = new TestStringSerializer();
+        var testValue = "Hello, World!";
+
+        state.PersistingState = true;
+
+        // Act
+        stateValueProvider.PersistAsync("test-key", testValue, customSerializer);
+
+        // Assert
+        state.PersistingState = false;
+        
+        // Simulate the state transfer that happens between persist and restore phases
+        var newState = new PersistentComponentState(new Dictionary<string, byte[]>(), []);
+        newState.InitializeExistingState(currentState);
+        var newStateValueProvider = new PersistentStateValueProvider(newState, serviceProvider);
+        
+        Assert.True(newStateValueProvider.TryTake("test-key", customSerializer, out var retrievedValue));
+        Assert.Equal(testValue, retrievedValue);
+    }
+
+    [Fact]
+    public void TryTake_CanUseCustomSerializer()
+    {
+        // Arrange
+        var customData = "Custom Data";
+        var customBytes = Encoding.UTF8.GetBytes(customData);
+        var existingState = new Dictionary<string, byte[]> { { "test-key", customBytes } };
+        
+        var state = new PersistentComponentState(new Dictionary<string, byte[]>(), []);
+        state.InitializeExistingState(existingState);
+        
+        var serviceProvider = new ServiceCollection().BuildServiceProvider();
+        var stateValueProvider = new PersistentStateValueProvider(state, serviceProvider);
+        var customSerializer = new TestStringSerializer();
+
+        // Act
+        var success = stateValueProvider.TryTake("test-key", customSerializer, out var retrievedValue);
+
+        // Assert
+        Assert.True(success);
+        Assert.Equal(customData, retrievedValue);
+    }
+
+    private class TestStringSerializer : PersistentComponentStateSerializer<string>
+    {
+        public override void Persist(string value, IBufferWriter<byte> writer)
+        {
+            var bytes = Encoding.UTF8.GetBytes(value);
+            writer.Write(bytes);
+        }
+
+        public override string Restore(ReadOnlySequence<byte> data)
+        {
+            var bytes = data.ToArray();
+            return Encoding.UTF8.GetString(bytes);
+        }
+    }
+}

+ 15 - 15
src/Components/Components/test/PersistentStateValueProviderTests.cs

@@ -25,7 +25,7 @@ public class PersistentStateValueProviderTests
             new Dictionary<string, byte[]>(),
             []);
 
-        var provider = new PersistentStateValueProvider(state);
+        var provider = new PersistentStateValueProvider(state, new ServiceCollection().BuildServiceProvider());
         var renderer = new TestRenderer();
         var component = new TestComponent();
         // Update the method call to match the correct signature
@@ -53,7 +53,7 @@ public class PersistentStateValueProviderTests
         var state = new PersistentComponentState(
             new Dictionary<string, byte[]>(),
             []);
-        var provider = new PersistentStateValueProvider(state);
+        var provider = new PersistentStateValueProvider(state, new ServiceCollection().BuildServiceProvider());
         var renderer = new TestRenderer();
         var component = new TestComponent();
         var componentStates = CreateComponentState(renderer, [(component, null)], null);
@@ -75,7 +75,7 @@ public class PersistentStateValueProviderTests
         var state = new PersistentComponentState(
             new Dictionary<string, byte[]>(),
             []);
-        var provider = new PersistentStateValueProvider(state);
+        var provider = new PersistentStateValueProvider(state, new ServiceCollection().BuildServiceProvider());
         var renderer = new TestRenderer();
         var component = new TestComponent();
         var componentStates = CreateComponentState(renderer, [(component, null)], null);
@@ -108,7 +108,7 @@ public class PersistentStateValueProviderTests
         var componentState = componentStates.First();
 
         // Create the provider and subscribe the component
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(TestComponent.State), typeof(string));
         provider.Subscribe(componentState, cascadingParameterInfo);
 
@@ -147,7 +147,7 @@ public class PersistentStateValueProviderTests
         var componentState = componentStates.First();
 
         // Create the provider and subscribe the component
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(TestComponent.State), typeof(string));
         provider.Subscribe(componentState, cascadingParameterInfo);
 
@@ -187,7 +187,7 @@ public class PersistentStateValueProviderTests
         var componentState2 = componentStates.Last();
 
         // Create the provider and subscribe the component
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(TestComponent.State), typeof(string));
         provider.Subscribe(componentState1, cascadingParameterInfo);
         provider.Subscribe(componentState2, cascadingParameterInfo);
@@ -260,7 +260,7 @@ public class PersistentStateValueProviderTests
         var componentState2 = componentStates.Last();
 
         // Create the provider and subscribe the component
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(TestComponent.State), typeof(string));
         provider.Subscribe(componentState1, cascadingParameterInfo);
         provider.Subscribe(componentState2, cascadingParameterInfo);
@@ -305,7 +305,7 @@ public class PersistentStateValueProviderTests
         var componentState2 = componentStates.Last();
 
         // Create the provider and subscribe the components
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(TestComponent.State), typeof(string));
         provider.Subscribe(componentState1, cascadingParameterInfo);
         provider.Subscribe(componentState2, cascadingParameterInfo);
@@ -346,7 +346,7 @@ public class PersistentStateValueProviderTests
         var componentState2 = componentStates.Last();
 
         // Create the provider and subscribe the components
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(TestComponent.State), typeof(string));
         provider.Subscribe(componentState1, cascadingParameterInfo);
         provider.Subscribe(componentState2, cascadingParameterInfo);
@@ -379,7 +379,7 @@ public class PersistentStateValueProviderTests
         var componentState2 = componentStates.Last();
 
         // Create the provider and subscribe the components
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(TestComponent.State), typeof(string));
         provider.Subscribe(componentState1, cascadingParameterInfo);
         provider.Subscribe(componentState2, cascadingParameterInfo);
@@ -419,7 +419,7 @@ public class PersistentStateValueProviderTests
         var componentState2 = componentStates.Last();
 
         // Create the provider and subscribe the components
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(TestComponent.State), typeof(string));
         provider.Subscribe(componentState1, cascadingParameterInfo);
         provider.Subscribe(componentState2, cascadingParameterInfo);
@@ -448,7 +448,7 @@ public class PersistentStateValueProviderTests
         var componentState = componentStates.First();
 
         // Create the provider and subscribe the component
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(ValueTypeTestComponent.IntValue), typeof(int));
         provider.Subscribe(componentState, cascadingParameterInfo);
 
@@ -483,7 +483,7 @@ public class PersistentStateValueProviderTests
         var componentState = componentStates.First();
 
         // Create the provider and subscribe the component
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(ValueTypeTestComponent.NullableIntValue), typeof(int?));
         provider.Subscribe(componentState, cascadingParameterInfo);
 
@@ -518,7 +518,7 @@ public class PersistentStateValueProviderTests
         var componentState = componentStates.First();
 
         // Create the provider and subscribe the component
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(ValueTypeTestComponent.TupleValue), typeof((string, int)));
         provider.Subscribe(componentState, cascadingParameterInfo);
 
@@ -553,7 +553,7 @@ public class PersistentStateValueProviderTests
         var componentState = componentStates.First();
 
         // Create the provider and subscribe the component
-        var provider = new PersistentStateValueProvider(persistenceManager.State);
+        var provider = new PersistentStateValueProvider(persistenceManager.State, new ServiceCollection().BuildServiceProvider());
         var cascadingParameterInfo = CreateCascadingParameterInfo(nameof(ValueTypeTestComponent.NullableTupleValue), typeof((string, int)?));
         provider.Subscribe(componentState, cascadingParameterInfo);
 

+ 4 - 0
src/Components/test/E2ETest/ServerRenderingTests/InteractivityTest.cs

@@ -1059,6 +1059,7 @@ public class InteractivityTest : ServerTestBase<BasicTestAppServerSiteFixture<Ra
         Navigate($"{ServerPathBase}/persist-state?server=true&declarative=true");
 
         Browser.Equal("restored", () => Browser.FindElement(By.Id("server")).Text);
+        Browser.Equal("42", () => Browser.FindElement(By.Id("custom-server")).Text);
         Browser.Equal("Server", () => Browser.FindElement(By.Id("render-mode-server")).Text);
     }
 
@@ -1077,6 +1078,7 @@ public class InteractivityTest : ServerTestBase<BasicTestAppServerSiteFixture<Ra
         Navigate($"{ServerPathBase}/persist-state?wasm=true&declarative=true");
 
         Browser.Equal("restored", () => Browser.FindElement(By.Id("wasm")).Text);
+        Browser.Equal("42", () => Browser.FindElement(By.Id("custom-wasm")).Text);
         Browser.Equal("WebAssembly", () => Browser.FindElement(By.Id("render-mode-wasm")).Text);
     }
 
@@ -1095,6 +1097,7 @@ public class InteractivityTest : ServerTestBase<BasicTestAppServerSiteFixture<Ra
         Navigate($"{ServerPathBase}/persist-state?auto=true&declarative=true");
 
         Browser.Equal("restored", () => Browser.FindElement(By.Id("auto")).Text);
+        Browser.Equal("42", () => Browser.FindElement(By.Id("custom-auto")).Text);
         Browser.Equal("WebAssembly", () => Browser.FindElement(By.Id("render-mode-auto")).Text);
     }
 
@@ -1156,6 +1159,7 @@ public class InteractivityTest : ServerTestBase<BasicTestAppServerSiteFixture<Ra
         Navigate($"{ServerPathBase}/persist-state?auto=true&declarative=true");
 
         Browser.Equal("restored", () => Browser.FindElement(By.Id("auto")).Text);
+        Browser.Equal("42", () => Browser.FindElement(By.Id("custom-auto")).Text);
         Browser.Equal("Server", () => Browser.FindElement(By.Id("render-mode-auto")).Text);
     }
 

+ 5 - 0
src/Components/test/testassets/Components.TestServer/RazorComponentEndpointsStartup.cs

@@ -8,10 +8,12 @@ using System.Web;
 using Components.TestServer.RazorComponents;
 using Components.TestServer.RazorComponents.Pages.Forms;
 using Components.TestServer.Services;
+using Microsoft.AspNetCore.Components;
 using Microsoft.AspNetCore.Components.Server.Circuits;
 using Microsoft.AspNetCore.Components.Web;
 using Microsoft.AspNetCore.Components.WebAssembly.Server;
 using Microsoft.AspNetCore.Mvc;
+using TestContentPackage;
 using TestContentPackage.Services;
 
 namespace TestServer;
@@ -64,6 +66,9 @@ public class RazorComponentEndpointsStartup<TRootComponent>
         services.AddScoped<InteractiveServerService>();
         services.AddScoped<InteractiveAutoService>();
 
+        // Register custom serializer for E2E testing of persistent component state serialization extensibility
+        services.AddSingleton<PersistentComponentStateSerializer<int>, CustomIntSerializer>();
+
         services.AddHttpContextAccessor();
         services.AddSingleton<AsyncOperationService>();
         services.AddCascadingAuthenticationState();

+ 5 - 0
src/Components/test/testassets/Components.WasmMinimal/Program.cs

@@ -4,8 +4,10 @@
 using System.Runtime.InteropServices.JavaScript;
 using System.Security.Claims;
 using Components.TestServer.Services;
+using Microsoft.AspNetCore.Components;
 using Microsoft.AspNetCore.Components.Web;
 using Microsoft.AspNetCore.Components.WebAssembly.Hosting;
+using TestContentPackage;
 using TestContentPackage.Services;
 
 var builder = WebAssemblyHostBuilder.CreateDefault(args);
@@ -14,6 +16,9 @@ builder.Services.AddSingleton<InteractiveWebAssemblyService>();
 builder.Services.AddSingleton<InteractiveAutoService>();
 builder.Services.AddSingleton<InteractiveServerService>();
 
+// Register custom serializer for persistent component state
+builder.Services.AddSingleton<PersistentComponentStateSerializer<int>, CustomIntSerializer>();
+
 builder.Services.AddCascadingAuthenticationState();
 
 builder.Services.AddAuthenticationStateDeserialization(options =>

+ 36 - 0
src/Components/test/testassets/TestContentPackage/CustomIntSerializer.cs

@@ -0,0 +1,36 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Text;
+using Microsoft.AspNetCore.Components;
+
+namespace TestContentPackage;
+
+/// <summary>
+/// A custom serializer for int values that uses a custom format to test serialization extensibility.
+/// This serializer prefixes integer values with "CUSTOM:" to clearly distinguish them from JSON serialization.
+/// </summary>
+public class CustomIntSerializer : PersistentComponentStateSerializer<int>
+{
+    public override void Persist(int value, IBufferWriter<byte> writer)
+    {
+        var customFormat = $"CUSTOM:{value}";
+        var bytes = Encoding.UTF8.GetBytes(customFormat);
+        writer.Write(bytes);
+    }
+
+    public override int Restore(ReadOnlySequence<byte> data)
+    {
+        var bytes = data.ToArray();
+        var text = Encoding.UTF8.GetString(bytes);
+        
+        if (text.StartsWith("CUSTOM:", StringComparison.Ordinal) && int.TryParse(text.Substring(7), out var value))
+        {
+            return value;
+        }
+        
+        // Fallback to direct parsing if format is unexpected
+        return int.TryParse(text, out var fallbackValue) ? fallbackValue : 0;
+    }
+}

+ 8 - 0
src/Components/test/testassets/TestContentPackage/DeclarativePersistStateComponent.razor

@@ -1,4 +1,5 @@
 <p>Application state is <span id="@KeyName">@Value</span></p>
+<p>Custom value is <span id="custom-@KeyName">@CustomValue</span></p>
 <p>Render mode: <span id="render-mode-@KeyName">@_renderMode</span></p>
 
 @code {
@@ -11,11 +12,18 @@
     [PersistentState]
     public string Value { get; set; }
 
+    [PersistentState]
+    public int CustomValue { get; set; }
+
     private string _renderMode = "SSR";
 
     protected override void OnInitialized()
     {
         Value ??= !RendererInfo.IsInteractive ? InitialValue : "not restored";
+        if (CustomValue == 0)
+        {
+            CustomValue = !RendererInfo.IsInteractive ? 42 : 0;
+        }
         _renderMode = OperatingSystem.IsBrowser() ? "WebAssembly" : "Server";
     }
 }

+ 1 - 1
src/Shared/PooledArrayBufferWriter.cs

@@ -93,7 +93,7 @@ internal sealed class PooledArrayBufferWriter<T> : IBufferWriter<T>, IDisposable
 
         ClearHelper();
         ArrayPool<T>.Shared.Return(_rentedBuffer);
-        _rentedBuffer = null;
+        _rentedBuffer = null!;
     }
 
     private void CheckIfDisposed()