ソースを参照

[automated] Merge branch 'release/8.0' => 'main' (#50792)

* Revert "Remove hardcoded System.Security.Cryptography.Xml version (#48029)" (#50723)

This reverts commit 42d14c4bab2afb8cddc1f9683aa8ff2e20a46edb.

* [Blazor] Prerendered state (#50742)

[Blazor] Adds support for persting prerendered state on Blazor Web applications.
* Persists state both for server and webassembly as necessary.
* Initializes the state when a given interactive runtime is initialized and renders the first set of components.
  * On WebAssembly, this is the first time the app starts.
  * On Server this happens every time a circuit starts.
* The state is available during the first render, until the components reach quiescence.

The approach we follow is different for server and webassembly:
* On Server, we support initializing the circuit with an empty set of descriptors and in that case, we delay initialization until the first `UpdateRootComponents` call is issued.
  * This is because it's hard to deal with the security constraints imposed by starting a new circuit multiple times, and its easier to handle them within UpdateRootComponents. We might switch this approach in the future to go through `StartCircuit` too.
* On WebAssembly, we query for the initial set of webassembly components when we are starting the runtime in a Blazor Web Scenario.
  * We do this because Blazor WebAssembly offers a programatic API to render root components at a given location defined by their selectors, so we need to make sure that those components can receive state at the same time the initial set of WebAssembly components added to the page.

There are a set of tests validating different behaviors with regards to enhanced navigation and streaming rendering, as well as making sure that auto mode can access the state on Server and WebAssembly, and that Server gets new state every time a circuit is opened.

* Make IEmailSender more customizable (#50301)

* Make IEmailSender more customizable

* Remove unnecessary metadata

* Add TUser parameter

* React to API review feedback

* Fix IdentitySample.DefaultUI

* Update branding to RTM (#50799)

---------

Co-authored-by: Igor Velikorossov <[email protected]>
Co-authored-by: Javier Calvarro Nelson <[email protected]>
Co-authored-by: Stephen Halter <[email protected]>
Co-authored-by: William Godbe <[email protected]>
dotnet-maestro-bot 2 年 前
コミット
2ce54c68b2
86 ファイル変更2748 行追加744 行削除
  1. 0 2
      eng/SourceBuildPrebuiltBaseline.xml
  2. 2 0
      eng/tools/RepoTasks/RepoTasks.csproj
  3. 7 0
      src/Components/Components/src/IPersistentComponentStateStore.cs
  4. 78 20
      src/Components/Components/src/Infrastructure/ComponentStatePersistenceManager.cs
  5. 13 0
      src/Components/Components/src/PersistComponentStateRegistration.cs
  6. 17 5
      src/Components/Components/src/PersistentComponentState.cs
  7. 5 5
      src/Components/Components/src/PersistingComponentStateSubscription.cs
  8. 4 1
      src/Components/Components/src/PublicAPI.Unshipped.txt
  9. 14 1
      src/Components/Components/src/RenderTree/Renderer.cs
  10. 21 13
      src/Components/Components/test/Lifetime/ComponentApplicationStateTest.cs
  11. 71 34
      src/Components/Components/test/Lifetime/ComponentStatePersistenceManagerTest.cs
  12. 4 0
      src/Components/Endpoints/src/RazorComponentEndpointInvoker.cs
  13. 24 3
      src/Components/Endpoints/src/Rendering/EndpointHtmlRenderer.Prerendering.cs
  14. 173 23
      src/Components/Endpoints/src/Rendering/EndpointHtmlRenderer.PrerenderingState.cs
  15. 10 9
      src/Components/Endpoints/src/Rendering/SSRRenderModeBoundary.cs
  16. 360 1
      src/Components/Endpoints/test/EndpointHtmlRendererTest.cs
  17. 13 3
      src/Components/Server/src/Circuits/CircuitFactory.cs
  18. 146 4
      src/Components/Server/src/Circuits/CircuitHost.cs
  19. 1 1
      src/Components/Server/src/Circuits/IServerComponentDeserializer.cs
  20. 6 90
      src/Components/Server/src/Circuits/RemoteRenderer.cs
  21. 96 0
      src/Components/Server/src/Circuits/ServerComponentDeserializer.cs
  22. 27 0
      src/Components/Server/src/ComponentHub.cs
  23. 295 1
      src/Components/Server/test/Circuits/CircuitHostTest.cs
  24. 6 0
      src/Components/Server/test/Circuits/ComponentHubTest.cs
  25. 0 295
      src/Components/Server/test/Circuits/RemoteRendererTest.cs
  26. 105 0
      src/Components/Server/test/Circuits/ServerComponentDeserializerTest.cs
  27. 2 1
      src/Components/Shared/src/DefaultAntiforgeryStateProvider.cs
  28. 0 0
      src/Components/Web.JS/dist/Release/blazor.server.js
  29. 0 0
      src/Components/Web.JS/dist/Release/blazor.web.js
  30. 8 2
      src/Components/Web.JS/src/Boot.Server.Common.ts
  31. 48 3
      src/Components/Web.JS/src/Boot.WebAssembly.Common.ts
  32. 3 0
      src/Components/Web.JS/src/GlobalExports.ts
  33. 13 1
      src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts
  34. 13 4
      src/Components/Web.JS/src/Services/ComponentDescriptorDiscovery.ts
  35. 22 4
      src/Components/Web.JS/src/Services/WebRootComponentManager.ts
  36. 0 1
      src/Components/Web/src/PublicAPI.Unshipped.txt
  37. 0 13
      src/Components/Web/src/WebRenderer.cs
  38. 28 5
      src/Components/WebAssembly/WebAssembly/src/Hosting/WebAssemblyHost.cs
  39. 22 85
      src/Components/WebAssembly/WebAssembly/src/Rendering/WebAssemblyRenderer.cs
  40. 108 0
      src/Components/WebAssembly/WebAssembly/src/Services/DefaultWebAssemblyJSRuntime.cs
  41. 28 0
      src/Components/WebAssembly/WebAssembly/src/Services/InternalJSImportMethods.cs
  42. 79 0
      src/Components/test/E2ETest/ServerRenderingTests/InteractivityTest.cs
  43. 2 0
      src/Components/test/E2ETest/Tests/SaveStateTest.cs
  44. 237 0
      src/Components/test/E2ETest/Tests/StatePersistenceTest.cs
  45. 2 1
      src/Components/test/testassets/BasicTestApp/PreserveStateService.cs
  46. 24 0
      src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistServerState.razor
  47. 33 0
      src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistStateComponents.razor
  48. 36 0
      src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/EndStreamingPage.razor
  49. 65 0
      src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/PageWithComponents.razor
  50. 14 0
      src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/PageWithoutComponents.razor
  51. 25 0
      src/Components/test/testassets/Components.WasmMinimal/Pages/PersistWebAssemblyState.razor
  52. 4 0
      src/Components/test/testassets/Components.WasmMinimal/Program.cs
  53. 40 0
      src/Components/test/testassets/TestContentPackage/PersistStateComponent.razor
  54. 49 0
      src/Components/test/testassets/TestContentPackage/PersistentComponents/NonStreamingComponentWithPersistentState.razor
  55. 77 0
      src/Components/test/testassets/TestContentPackage/PersistentComponents/StreamingComponentWithPersistentState.razor
  56. 0 0
      src/Components/test/testassets/TestContentPackage/Services/AsyncOperationService.cs
  57. 2 4
      src/Identity/Core/src/IEmailSender.cs
  58. 41 0
      src/Identity/Core/src/IEmailSenderOfT.cs
  59. 3 7
      src/Identity/Core/src/IdentityApiEndpointRouteBuilderExtensions.cs
  60. 1 0
      src/Identity/Core/src/IdentityBuilderExtensions.cs
  61. 1 1
      src/Identity/Core/src/Microsoft.AspNetCore.Identity.csproj
  62. 0 2
      src/Identity/Core/src/NoOpEmailSender.cs
  63. 9 0
      src/Identity/Core/src/PublicAPI.Unshipped.txt
  64. 0 5
      src/Identity/Extensions.Core/src/PublicAPI.Unshipped.txt
  65. 3 5
      src/Identity/UI/src/Areas/Identity/Pages/V4/Account/ExternalLogin.cshtml.cs
  66. 3 7
      src/Identity/UI/src/Areas/Identity/Pages/V4/Account/ForgotPassword.cshtml.cs
  67. 4 11
      src/Identity/UI/src/Areas/Identity/Pages/V4/Account/Manage/Email.cshtml.cs
  68. 3 5
      src/Identity/UI/src/Areas/Identity/Pages/V4/Account/Register.cshtml.cs
  69. 3 4
      src/Identity/UI/src/Areas/Identity/Pages/V4/Account/RegisterConfirmation.cshtml.cs
  70. 3 7
      src/Identity/UI/src/Areas/Identity/Pages/V4/Account/ResendEmailConfirmation.cshtml.cs
  71. 3 5
      src/Identity/UI/src/Areas/Identity/Pages/V5/Account/ExternalLogin.cshtml.cs
  72. 3 7
      src/Identity/UI/src/Areas/Identity/Pages/V5/Account/ForgotPassword.cshtml.cs
  73. 4 11
      src/Identity/UI/src/Areas/Identity/Pages/V5/Account/Manage/Email.cshtml.cs
  74. 3 5
      src/Identity/UI/src/Areas/Identity/Pages/V5/Account/Register.cshtml.cs
  75. 3 4
      src/Identity/UI/src/Areas/Identity/Pages/V5/Account/RegisterConfirmation.cshtml.cs
  76. 3 7
      src/Identity/UI/src/Areas/Identity/Pages/V5/Account/ResendEmailConfirmation.cshtml.cs
  77. 1 0
      src/Identity/UI/src/IdentityBuilderUIExtensions.cs
  78. 4 0
      src/Identity/UI/src/Microsoft.AspNetCore.Identity.UI.csproj
  79. 2 2
      src/Identity/UI/src/PublicAPI.Unshipped.txt
  80. 3 4
      src/Identity/samples/IdentitySample.DefaultUI/Areas/Identity/Pages/Account/Register.cshtml.cs
  81. 40 0
      src/Identity/test/Identity.FunctionalTests/MapIdentityApiTests.cs
  82. 79 8
      src/Mvc/Mvc.TagHelpers/test/PersistComponentStateTagHelperTest.cs
  83. 18 1
      src/Shared/Components/PrerenderComponentApplicationStore.cs
  84. 5 0
      src/Shared/Components/ProtectedPrerenderComponentApplicationStore.cs
  85. 20 0
      src/Shared/DefaultMessageEmailSender.cs
  86. 1 1
      src/Shared/E2ETesting/BrowserFixture.cs

+ 0 - 2
eng/SourceBuildPrebuiltBaseline.xml

@@ -31,8 +31,6 @@
     <UsagePattern IdentityGlob="System.Composition/*7.0.0*" />
     <UsagePattern IdentityGlob="System.Threading.Tasks.Extensions/*4.5.3*" />
 
-    <!-- Added to unblock dependency flow, needs review. -->
-    <UsagePattern IdentityGlob="System.Security.Cryptography.Xml/*6.0.0*" />
 
     <!-- These are coming in via runtime but the source-build infra isn't able to automatically pick up the right intermediate. -->
     <UsagePattern IdentityGlob="Microsoft.NET.ILLink.Tasks/*8.0.*" />

+ 2 - 0
eng/tools/RepoTasks/RepoTasks.csproj

@@ -26,6 +26,8 @@
     <PackageReference Include="Microsoft.Build.Framework" Version="$(MicrosoftBuildFrameworkVersion)" />
     <PackageReference Include="Microsoft.Build.Tasks.Core" Version="$(MicrosoftBuildTasksCoreVersion)" />
     <PackageReference Include="Microsoft.Build.Utilities.Core" Version="$(MicrosoftBuildUtilitiesCoreVersion)" />
+    <!-- Manually updated version from 6.0.0 to address CVE-2021-43877 -->
+    <PackageReference Include="System.Security.Cryptography.Xml" Version="$(RepoTasksSystemSecurityCryptographyXmlVersion)" />
   </ItemGroup>
 
   <ItemGroup Condition="'$(TargetFramework)' == 'net472'">

+ 7 - 0
src/Components/Components/src/IPersistentComponentStateStore.cs

@@ -20,4 +20,11 @@ public interface IPersistentComponentStateStore
     /// <param name="state">The serialized state to persist.</param>
     /// <returns>A <see cref="Task" /> that completes when the state is persisted to disk.</returns>
     Task PersistStateAsync(IReadOnlyDictionary<string, byte[]> state);
+
+    /// <summary>
+    /// Returns a value that indicates whether the store supports the given <see cref="IComponentRenderMode"/>.
+    /// </summary>
+    /// <param name="renderMode">The <see cref="IComponentRenderMode"/> in question.</param>
+    /// <returns><c>true</c> if the render mode is supported by the store, otherwise <c>false</c>.</returns>
+    bool SupportsRenderMode(IComponentRenderMode renderMode) => true;
 }

+ 78 - 20
src/Components/Components/src/Infrastructure/ComponentStatePersistenceManager.cs

@@ -11,17 +11,18 @@ namespace Microsoft.AspNetCore.Components.Infrastructure;
 /// </summary>
 public class ComponentStatePersistenceManager
 {
+    private readonly List<PersistComponentStateRegistration> _registeredCallbacks = new();
+    private readonly ILogger<ComponentStatePersistenceManager> _logger;
+
     private bool _stateIsPersisted;
-    private readonly List<Func<Task>> _pauseCallbacks = new();
     private readonly Dictionary<string, byte[]> _currentState = new(StringComparer.Ordinal);
-    private readonly ILogger<ComponentStatePersistenceManager> _logger;
 
     /// <summary>
     /// Initializes a new instance of <see cref="ComponentStatePersistenceManager"/>.
     /// </summary>
     public ComponentStatePersistenceManager(ILogger<ComponentStatePersistenceManager> logger)
     {
-        State = new PersistentComponentState(_currentState, _pauseCallbacks);
+        State = new PersistentComponentState(_currentState, _registeredCallbacks);
         _logger = logger;
     }
 
@@ -48,43 +49,100 @@ public class ComponentStatePersistenceManager
     /// <param name="renderer">The <see cref="Renderer"/> that components are being rendered.</param>
     /// <returns>A <see cref="Task"/> that will complete when the state has been restored.</returns>
     public Task PersistStateAsync(IPersistentComponentStateStore store, Renderer renderer)
-        => PersistStateAsync(store, renderer.Dispatcher);
-
-    /// <summary>
-    /// Persists the component application state into the given <see cref="IPersistentComponentStateStore"/>.
-    /// </summary>
-    /// <param name="store">The <see cref="IPersistentComponentStateStore"/> to restore the application state from.</param>
-    /// <param name="dispatcher">The <see cref="Dispatcher"/> corresponding to the components' renderer.</param>
-    /// <returns>A <see cref="Task"/> that will complete when the state has been restored.</returns>
-    public Task PersistStateAsync(IPersistentComponentStateStore store, Dispatcher dispatcher)
     {
         if (_stateIsPersisted)
         {
             throw new InvalidOperationException("State already persisted.");
         }
 
-        _stateIsPersisted = true;
-
-        return dispatcher.InvokeAsync(PauseAndPersistState);
+        return renderer.Dispatcher.InvokeAsync(PauseAndPersistState);
 
         async Task PauseAndPersistState()
         {
             State.PersistingState = true;
-            await PauseAsync();
+
+            if (store is IEnumerable<IPersistentComponentStateStore> compositeStore)
+            {
+                // We only need to do inference when there is more than one store. This is determined by
+                // the set of rendered components.
+                InferRenderModes(renderer);
+
+                // Iterate over each store and give it a chance to run against the existing declared
+                // render modes. After we've run through a store, we clear the current state so that
+                // the next store can start with a clean slate.
+                foreach (var store in compositeStore)
+                {
+                    await PersistState(store);
+                    _currentState.Clear();
+                }
+            }
+            else
+            {
+                await PersistState(store);
+            }
+
             State.PersistingState = false;
+            _stateIsPersisted = true;
+        }
 
+        async Task PersistState(IPersistentComponentStateStore store)
+        {
+            await PauseAsync(store);
             await store.PersistStateAsync(_currentState);
         }
     }
 
-    internal Task PauseAsync()
+    private void InferRenderModes(Renderer renderer)
+    {
+        for (var i = 0; i < _registeredCallbacks.Count; i++)
+        {
+            var registration = _registeredCallbacks[i];
+            if (registration.RenderMode != null)
+            {
+                // Explicitly set render mode, so nothing to do.
+                continue;
+            }
+
+            if (registration.Callback.Target is IComponent component)
+            {
+                var componentRenderMode = renderer.GetComponentRenderMode(component);
+                if (componentRenderMode != null)
+                {
+                    _registeredCallbacks[i] = new PersistComponentStateRegistration(registration.Callback, componentRenderMode);
+                }
+                else
+                {
+                    // If we can't find a render mode, it's an SSR only component and we don't need to
+                    // persist its state at all.
+                    _registeredCallbacks[i] = default;
+                }
+                continue;
+            }
+
+            throw new InvalidOperationException(
+                $"The registered callback {registration.Callback.Method.Name} must be associated with a component or define" +
+                $" an explicit render mode type during registration.");
+        }
+    }
+
+    internal Task PauseAsync(IPersistentComponentStateStore store)
     {
         List<Task>? pendingCallbackTasks = null;
 
-        for (var i = 0; i < _pauseCallbacks.Count; i++)
+        for (var i = 0; i < _registeredCallbacks.Count; i++)
         {
-            var callback = _pauseCallbacks[i];
-            var result = ExecuteCallback(callback, _logger);
+            var registration = _registeredCallbacks[i];
+
+            if (!store.SupportsRenderMode(registration.RenderMode!))
+            {
+                // The callback does not have an associated render mode and we are in a multi-store scenario.
+                // Otherwise, in a single store scenario, we just run the callback.
+                // If the registration callback is null, it's because it was associated with a component and we couldn't infer
+                // its render mode, which means is an SSR only component and we don't need to persist its state at all.
+                continue;
+            }
+
+            var result = ExecuteCallback(registration.Callback, _logger);
             if (!result.IsCompletedSuccessfully)
             {
                 pendingCallbackTasks ??= new();

+ 13 - 0
src/Components/Components/src/PersistComponentStateRegistration.cs

@@ -0,0 +1,13 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.AspNetCore.Components;
+
+internal readonly struct PersistComponentStateRegistration(
+    Func<Task> callback,
+    IComponentRenderMode? renderMode)
+{
+    public Func<Task> Callback { get; } = callback;
+
+    public IComponentRenderMode? RenderMode { get; } = renderMode;
+}

+ 17 - 5
src/Components/Components/src/PersistentComponentState.cs

@@ -15,11 +15,11 @@ public class PersistentComponentState
     private IDictionary<string, byte[]>? _existingState;
     private readonly IDictionary<string, byte[]> _currentState;
 
-    private readonly List<Func<Task>> _registeredCallbacks;
+    private readonly List<PersistComponentStateRegistration> _registeredCallbacks;
 
     internal PersistentComponentState(
-        IDictionary<string, byte[]> currentState,
-        List<Func<Task>> pauseCallbacks)
+        IDictionary<string , byte[]> currentState,
+        List<PersistComponentStateRegistration> pauseCallbacks)
     {
         _currentState = currentState;
         _registeredCallbacks = pauseCallbacks;
@@ -43,12 +43,24 @@ public class PersistentComponentState
     /// <param name="callback">The callback to invoke when the application is being paused.</param>
     /// <returns>A subscription that can be used to unregister the callback when disposed.</returns>
     public PersistingComponentStateSubscription RegisterOnPersisting(Func<Task> callback)
+        => RegisterOnPersisting(callback, null);
+
+    /// <summary>
+    /// Register a callback to persist the component state when the application is about to be paused.
+    /// Registered callbacks can use this opportunity to persist their state so that it can be retrieved when the application resumes.
+    /// </summary>
+    /// <param name="callback">The callback to invoke when the application is being paused.</param>
+    /// <param name="renderMode"></param>
+    /// <returns>A subscription that can be used to unregister the callback when disposed.</returns>
+    public PersistingComponentStateSubscription RegisterOnPersisting(Func<Task> callback, IComponentRenderMode? renderMode)
     {
         ArgumentNullException.ThrowIfNull(callback);
 
-        _registeredCallbacks.Add(callback);
+        var persistenceCallback = new PersistComponentStateRegistration(callback, renderMode);
+
+        _registeredCallbacks.Add(persistenceCallback);
 
-        return new PersistingComponentStateSubscription(_registeredCallbacks, callback);
+        return new PersistingComponentStateSubscription(_registeredCallbacks, persistenceCallback);
     }
 
     /// <summary>

+ 5 - 5
src/Components/Components/src/PersistingComponentStateSubscription.cs

@@ -11,10 +11,10 @@ namespace Microsoft.AspNetCore.Components;
 /// </summary>
 public readonly struct PersistingComponentStateSubscription : IDisposable
 {
-    private readonly List<Func<Task>>? _callbacks;
-    private readonly Func<Task>? _callback;
+    private readonly List<PersistComponentStateRegistration>? _callbacks;
+    private readonly PersistComponentStateRegistration? _callback;
 
-    internal PersistingComponentStateSubscription(List<Func<Task>> callbacks, Func<Task> callback)
+    internal PersistingComponentStateSubscription(List<PersistComponentStateRegistration> callbacks, PersistComponentStateRegistration callback)
     {
         _callbacks = callbacks;
         _callback = callback;
@@ -23,9 +23,9 @@ public readonly struct PersistingComponentStateSubscription : IDisposable
     /// <inheritdoc />
     public void Dispose()
     {
-        if (_callback != null)
+        if (_callback.HasValue)
         {
-            _callbacks?.Remove(_callback);
+            _callbacks?.Remove(_callback.Value);
         }
     }
 }

+ 4 - 1
src/Components/Components/src/PublicAPI.Unshipped.txt

@@ -16,11 +16,12 @@ Microsoft.AspNetCore.Components.CascadingValueSource<TValue>.NotifyChangedAsync(
 Microsoft.AspNetCore.Components.CascadingValueSource<TValue>.NotifyChangedAsync(TValue newValue) -> System.Threading.Tasks.Task!
 Microsoft.AspNetCore.Components.ComponentBase.DispatchExceptionAsync(System.Exception! exception) -> System.Threading.Tasks.Task!
 Microsoft.AspNetCore.Components.IComponentRenderMode
-Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager.PersistStateAsync(Microsoft.AspNetCore.Components.IPersistentComponentStateStore! store, Microsoft.AspNetCore.Components.Dispatcher! dispatcher) -> System.Threading.Tasks.Task!
 Microsoft.AspNetCore.Components.InjectAttribute.Key.get -> object?
 Microsoft.AspNetCore.Components.InjectAttribute.Key.init -> void
+Microsoft.AspNetCore.Components.IPersistentComponentStateStore.SupportsRenderMode(Microsoft.AspNetCore.Components.IComponentRenderMode! renderMode) -> bool
 Microsoft.AspNetCore.Components.ParameterView.ToDictionary() -> System.Collections.Generic.IReadOnlyDictionary<string!, object?>!
 *REMOVED*Microsoft.AspNetCore.Components.ParameterView.ToDictionary() -> System.Collections.Generic.IReadOnlyDictionary<string!, object!>!
+Microsoft.AspNetCore.Components.PersistentComponentState.RegisterOnPersisting(System.Func<System.Threading.Tasks.Task!>! callback, Microsoft.AspNetCore.Components.IComponentRenderMode? renderMode) -> Microsoft.AspNetCore.Components.PersistingComponentStateSubscription
 Microsoft.AspNetCore.Components.RenderHandle.DispatchExceptionAsync(System.Exception! exception) -> System.Threading.Tasks.Task!
 *REMOVED*Microsoft.AspNetCore.Components.NavigationManager.ToAbsoluteUri(string! relativeUri) -> System.Uri!
 Microsoft.AspNetCore.Components.NavigationManager.ToAbsoluteUri(string? relativeUri) -> System.Uri!
@@ -44,6 +45,7 @@ Microsoft.AspNetCore.Components.RenderTree.NamedEventChangeType
 Microsoft.AspNetCore.Components.RenderTree.NamedEventChangeType.Added = 0 -> Microsoft.AspNetCore.Components.RenderTree.NamedEventChangeType
 Microsoft.AspNetCore.Components.RenderTree.NamedEventChangeType.Removed = 1 -> Microsoft.AspNetCore.Components.RenderTree.NamedEventChangeType
 Microsoft.AspNetCore.Components.RenderTree.RenderBatch.NamedEventChanges.get -> Microsoft.AspNetCore.Components.RenderTree.ArrayRange<Microsoft.AspNetCore.Components.RenderTree.NamedEventChange>?
+Microsoft.AspNetCore.Components.RenderTree.Renderer.GetComponentState(Microsoft.AspNetCore.Components.IComponent! component) -> Microsoft.AspNetCore.Components.Rendering.ComponentState!
 Microsoft.AspNetCore.Components.RenderTree.RenderTreeFrame.ComponentFrameFlags.get -> Microsoft.AspNetCore.Components.RenderTree.ComponentFrameFlags
 Microsoft.AspNetCore.Components.RenderTree.RenderTreeFrameType.ComponentRenderMode = 9 -> Microsoft.AspNetCore.Components.RenderTree.RenderTreeFrameType
 Microsoft.AspNetCore.Components.RenderTree.RenderTreeFrameType.NamedEvent = 10 -> Microsoft.AspNetCore.Components.RenderTree.RenderTreeFrameType
@@ -101,6 +103,7 @@ virtual Microsoft.AspNetCore.Components.Rendering.ComponentState.DisposeAsync()
 virtual Microsoft.AspNetCore.Components.RenderTree.Renderer.AddPendingTask(Microsoft.AspNetCore.Components.Rendering.ComponentState? componentState, System.Threading.Tasks.Task! task) -> void
 virtual Microsoft.AspNetCore.Components.RenderTree.Renderer.CreateComponentState(int componentId, Microsoft.AspNetCore.Components.IComponent! component, Microsoft.AspNetCore.Components.Rendering.ComponentState? parentComponentState) -> Microsoft.AspNetCore.Components.Rendering.ComponentState!
 virtual Microsoft.AspNetCore.Components.RenderTree.Renderer.DispatchEventAsync(ulong eventHandlerId, Microsoft.AspNetCore.Components.RenderTree.EventFieldInfo? fieldInfo, System.EventArgs! eventArgs, bool waitForQuiescence) -> System.Threading.Tasks.Task!
+virtual Microsoft.AspNetCore.Components.RenderTree.Renderer.GetComponentRenderMode(Microsoft.AspNetCore.Components.IComponent! component) -> Microsoft.AspNetCore.Components.IComponentRenderMode?
 virtual Microsoft.AspNetCore.Components.RenderTree.Renderer.ResolveComponentForRenderMode(System.Type! componentType, int? parentComponentId, Microsoft.AspNetCore.Components.IComponentActivator! componentActivator, Microsoft.AspNetCore.Components.IComponentRenderMode! renderMode) -> Microsoft.AspNetCore.Components.IComponent!
 ~Microsoft.AspNetCore.Components.RenderTree.RenderTreeFrame.ComponentRenderMode.get -> Microsoft.AspNetCore.Components.IComponentRenderMode
 ~Microsoft.AspNetCore.Components.RenderTree.RenderTreeFrame.NamedEventAssignedName.get -> string

+ 14 - 1
src/Components/Components/src/RenderTree/Renderer.cs

@@ -133,7 +133,20 @@ public abstract partial class Renderer : IDisposable, IAsyncDisposable
     protected ComponentState GetComponentState(int componentId)
         => GetRequiredComponentState(componentId);
 
-    internal ComponentState GetComponentState(IComponent component)
+    /// <summary>
+    /// Gets the <see cref="IComponentRenderMode"/> for a given component if available.
+    /// </summary>
+    /// <param name="component">The component type</param>
+    /// <returns></returns>
+    protected internal virtual IComponentRenderMode? GetComponentRenderMode(IComponent component)
+        => null;
+
+    /// <summary>
+    /// Resolves the component state for a given <see cref="IComponent"/> instance.
+    /// </summary>
+    /// <param name="component">The <see cref="IComponent"/> instance</param>
+    /// <returns></returns>
+    protected internal ComponentState GetComponentState(IComponent component)
         => _componentStateByComponent.GetValueOrDefault(component);
 
     private async void RenderRootComponentsOnHotReload()

+ 21 - 13
src/Components/Components/test/Lifetime/ComponentApplicationStateTest.cs

@@ -11,7 +11,7 @@ public class ComponentApplicationStateTest
     public void InitializeExistingState_SetupsState()
     {
         // Arrange
-        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<Func<Task>>());
+        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<PersistComponentStateRegistration>());
         var existingState = new Dictionary<string, byte[]>
         {
             ["MyState"] = JsonSerializer.SerializeToUtf8Bytes(new byte[] { 1, 2, 3, 4 })
@@ -29,7 +29,7 @@ public class ComponentApplicationStateTest
     public void InitializeExistingState_ThrowsIfAlreadyInitialized()
     {
         // Arrange
-        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<Func<Task>>());
+        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<PersistComponentStateRegistration>());
         var existingState = new Dictionary<string, byte[]>
         {
             ["MyState"] = new byte[] { 1, 2, 3, 4 }
@@ -45,7 +45,7 @@ public class ComponentApplicationStateTest
     public void TryRetrieveState_ReturnsStateWhenItExists()
     {
         // Arrange
-        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<Func<Task>>());
+        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<PersistComponentStateRegistration>());
         var existingState = new Dictionary<string, byte[]>
         {
             ["MyState"] = JsonSerializer.SerializeToUtf8Bytes(new byte[] { 1, 2, 3, 4 })
@@ -65,8 +65,10 @@ public class ComponentApplicationStateTest
     {
         // Arrange
         var currentState = new Dictionary<string, byte[]>();
-        var applicationState = new PersistentComponentState(currentState, new List<Func<Task>>());
-        applicationState.PersistingState = true;
+        var applicationState = new PersistentComponentState(currentState, new List<PersistComponentStateRegistration>())
+        {
+            PersistingState = true
+        };
         var myState = new byte[] { 1, 2, 3, 4 };
 
         // Act
@@ -82,8 +84,10 @@ public class ComponentApplicationStateTest
     {
         // Arrange
         var currentState = new Dictionary<string, byte[]>();
-        var applicationState = new PersistentComponentState(currentState, new List<Func<Task>>());
-        applicationState.PersistingState = true;
+        var applicationState = new PersistentComponentState(currentState, new List<PersistComponentStateRegistration>())
+        {
+            PersistingState = true
+        };
         var myState = new byte[] { 1, 2, 3, 4 };
 
         applicationState.PersistAsJson("MyState", myState);
@@ -97,8 +101,10 @@ public class ComponentApplicationStateTest
     {
         // Arrange
         var currentState = new Dictionary<string, byte[]>();
-        var applicationState = new PersistentComponentState(currentState, new List<Func<Task>>());
-        applicationState.PersistingState = true;
+        var applicationState = new PersistentComponentState(currentState, new List<PersistComponentStateRegistration>())
+        {
+            PersistingState = true
+        };
         var myState = new byte[] { 1, 2, 3, 4 };
 
         // Act
@@ -114,8 +120,10 @@ public class ComponentApplicationStateTest
     {
         // Arrange
         var currentState = new Dictionary<string, byte[]>();
-        var applicationState = new PersistentComponentState(currentState, new List<Func<Task>>());
-        applicationState.PersistingState = true;
+        var applicationState = new PersistentComponentState(currentState, new List<PersistComponentStateRegistration>())
+        {
+            PersistingState = true
+        };
 
         // Act
         applicationState.PersistAsJson<byte[]>("MyState", null);
@@ -132,7 +140,7 @@ public class ComponentApplicationStateTest
         var myState = new byte[] { 1, 2, 3, 4 };
         var serialized = JsonSerializer.SerializeToUtf8Bytes(myState);
         var existingState = new Dictionary<string, byte[]>() { ["MyState"] = serialized };
-        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<Func<Task>>());
+        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<PersistComponentStateRegistration>());
 
         applicationState.InitializeExistingState(existingState);
 
@@ -150,7 +158,7 @@ public class ComponentApplicationStateTest
         // Arrange
         var serialized = JsonSerializer.SerializeToUtf8Bytes<byte[]>(null);
         var existingState = new Dictionary<string, byte[]>() { ["MyState"] = serialized };
-        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<Func<Task>>());
+        var applicationState = new PersistentComponentState(new Dictionary<string, byte[]>(), new List<PersistComponentStateRegistration>());
 
         applicationState.InitializeExistingState(existingState);
 

+ 71 - 34
src/Components/Components/test/Lifetime/ComponentApplicationLifetimeTest.cs → src/Components/Components/test/Lifetime/ComponentStatePersistenceManagerTest.cs

@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Buffers;
+using System.Collections;
 using System.Text.Json;
 using Microsoft.AspNetCore.Components.Infrastructure;
 using Microsoft.AspNetCore.Components.RenderTree;
@@ -12,7 +13,7 @@ using Microsoft.Extensions.Logging.Testing;
 
 namespace Microsoft.AspNetCore.Components;
 
-public class ComponentApplicationLifetimeTest
+public class ComponentStatePersistenceManagerTest
 {
     [Fact]
     public async Task RestoreStateAsync_InitializesStateWithDataFromTheProvidedStore()
@@ -41,7 +42,7 @@ public class ComponentApplicationLifetimeTest
         // Arrange
         var state = new Dictionary<string, byte[]>
         {
-            ["MyState"] = new byte[] { 0, 1, 2, 3, 4 }
+            ["MyState"] = [0, 1, 2, 3, 4]
         };
         var store = new TestStore(state);
         var lifetime = new ComponentStatePersistenceManager(NullLogger<ComponentStatePersistenceManager>.Instance);
@@ -52,6 +53,28 @@ public class ComponentApplicationLifetimeTest
         await Assert.ThrowsAsync<InvalidOperationException>(() => lifetime.RestoreStateAsync(store));
     }
 
+    [Fact]
+    public async Task PersistStateAsync_ThrowsWhenCallbackRenerModeCannotBeInferred()
+    {
+        // Arrange
+        var state = new Dictionary<string, byte[]>();
+        var store = new CompositeTestStore(state);
+        var lifetime = new ComponentStatePersistenceManager(NullLogger<ComponentStatePersistenceManager>.Instance);
+
+        var renderer = new TestRenderer();
+        var data = new byte[] { 1, 2, 3, 4 };
+
+        lifetime.State.RegisterOnPersisting(() =>
+        {
+            lifetime.State.PersistAsJson("MyState", new byte[] { 1, 2, 3, 4 });
+            return Task.CompletedTask;
+        });
+
+        // Act
+        // Assert
+        await Assert.ThrowsAsync<InvalidOperationException>(() => lifetime.PersistStateAsync(store, renderer));
+    }
+
     [Fact]
     public async Task PersistStateAsync_SavesPersistedStateToTheStore()
     {
@@ -67,7 +90,7 @@ public class ComponentApplicationLifetimeTest
         {
             lifetime.State.PersistAsJson("MyState", new byte[] { 1, 2, 3, 4 });
             return Task.CompletedTask;
-        });
+        }, new TestRenderMode());
 
         // Act
         await lifetime.PersistStateAsync(store, renderer);
@@ -88,7 +111,7 @@ public class ComponentApplicationLifetimeTest
         var data = new byte[] { 1, 2, 3, 4 };
         var invoked = false;
 
-        lifetime.State.RegisterOnPersisting(() => { invoked = true; return default; });
+        lifetime.State.RegisterOnPersisting(() => { invoked = true; return default; }, new TestRenderMode());
 
         // Act
         await lifetime.PersistStateAsync(store, renderer);
@@ -111,8 +134,8 @@ public class ComponentApplicationLifetimeTest
         var tcs = new TaskCompletionSource();
         var tcs2 = new TaskCompletionSource();
 
-        lifetime.State.RegisterOnPersisting(async () => { sequence.Add(1); await tcs.Task; sequence.Add(3); });
-        lifetime.State.RegisterOnPersisting(async () => { sequence.Add(2); await tcs2.Task; sequence.Add(4); });
+        lifetime.State.RegisterOnPersisting(async () => { sequence.Add(1); await tcs.Task; sequence.Add(3); }, new TestRenderMode());
+        lifetime.State.RegisterOnPersisting(async () => { sequence.Add(2); await tcs2.Task; sequence.Add(4); }, new TestRenderMode());
 
         // Act
         var persistTask = lifetime.PersistStateAsync(store, renderer);
@@ -170,8 +193,8 @@ public class ComponentApplicationLifetimeTest
         var data = new byte[] { 1, 2, 3, 4 };
         var invoked = false;
 
-        lifetime.State.RegisterOnPersisting(() => throw new InvalidOperationException());
-        lifetime.State.RegisterOnPersisting(() => { invoked = true; return Task.CompletedTask; });
+        lifetime.State.RegisterOnPersisting(() => throw new InvalidOperationException(), new TestRenderMode());
+        lifetime.State.RegisterOnPersisting(() => { invoked = true; return Task.CompletedTask; }, new TestRenderMode());
 
         // Act
         await lifetime.PersistStateAsync(store, renderer);
@@ -196,8 +219,8 @@ public class ComponentApplicationLifetimeTest
         var invoked = false;
         var tcs = new TaskCompletionSource();
 
-        lifetime.State.RegisterOnPersisting(async () => { await tcs.Task; throw new InvalidOperationException(); });
-        lifetime.State.RegisterOnPersisting(() => { invoked = true; return Task.CompletedTask; });
+        lifetime.State.RegisterOnPersisting(async () => { await tcs.Task; throw new InvalidOperationException(); }, new TestRenderMode());
+        lifetime.State.RegisterOnPersisting(() => { invoked = true; return Task.CompletedTask; }, new TestRenderMode());
 
         // Act
         var persistTask = lifetime.PersistStateAsync(store, renderer);
@@ -211,30 +234,6 @@ public class ComponentApplicationLifetimeTest
         Assert.Equal(LogLevel.Error, log.LogLevel);
     }
 
-    [Fact]
-    public async Task PersistStateAsync_ThrowsWhenDeveloperTriesToPersistStateMultipleTimes()
-    {
-        // Arrange
-        var state = new Dictionary<string, byte[]>();
-        var store = new TestStore(state);
-        var lifetime = new ComponentStatePersistenceManager(NullLogger<ComponentStatePersistenceManager>.Instance);
-
-        var renderer = new TestRenderer();
-        var data = new byte[] { 1, 2, 3, 4 };
-
-        lifetime.State.RegisterOnPersisting(() =>
-        {
-            lifetime.State.PersistAsJson<byte[]>("MyState", new byte[] { 1, 2, 3, 4 });
-            return Task.CompletedTask;
-        });
-
-        // Act
-        await lifetime.PersistStateAsync(store, renderer);
-
-        // Assert
-        await Assert.ThrowsAsync<InvalidOperationException>(() => lifetime.PersistStateAsync(store, renderer));
-    }
-
     private class TestRenderer : Renderer
     {
         public TestRenderer() : base(new ServiceCollection().BuildServiceProvider(), NullLoggerFactory.Instance)
@@ -277,4 +276,42 @@ public class ComponentApplicationLifetimeTest
             return Task.CompletedTask;
         }
     }
+
+    private class CompositeTestStore : IPersistentComponentStateStore,  IEnumerable<IPersistentComponentStateStore>
+    {
+        public CompositeTestStore(IDictionary<string, byte[]> initialState)
+        {
+            State = initialState;
+        }
+
+        public IDictionary<string, byte[]> State { get; set; }
+
+        public IEnumerator<IPersistentComponentStateStore> GetEnumerator()
+        {
+            yield return new TestStore(State);
+            yield return new TestStore(State);
+        }
+
+        public Task<IDictionary<string, byte[]>> GetPersistedStateAsync()
+        {
+            return Task.FromResult(State);
+        }
+
+        public Task PersistStateAsync(IReadOnlyDictionary<string, byte[]> state)
+        {
+            // We copy the data here because it's no longer available after this call completes.
+            State = state.ToDictionary(k => k.Key, v => v.Value);
+            return Task.CompletedTask;
+        }
+
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return GetEnumerator();
+        }
+    }
+
+    private class TestRenderMode : IComponentRenderMode
+    {
+
+    }
 }

+ 4 - 0
src/Components/Endpoints/src/RazorComponentEndpointInvoker.cs

@@ -121,6 +121,10 @@ internal partial class RazorComponentEndpointInvoker : IRazorComponentEndpointIn
             await _renderer.SendStreamingUpdatesAsync(context, quiesceTask, bufferWriter);
         }
 
+        // Emit comment containing state.
+        var componentStateHtmlContent = await _renderer.PrerenderPersistedStateAsync(context);
+        componentStateHtmlContent.WriteTo(bufferWriter, HtmlEncoder.Default);
+
         // Invoke FlushAsync to ensure any buffered content is asynchronously written to the underlying
         // response asynchronously. In the absence of this line, the buffer gets synchronously written to the
         // response as part of the Dispose which has a perf impact.

+ 24 - 3
src/Components/Endpoints/src/Rendering/EndpointHtmlRenderer.Prerendering.cs

@@ -3,6 +3,7 @@
 
 using System.Diagnostics.CodeAnalysis;
 using System.Text.Encodings.Web;
+using Microsoft.AspNetCore.Components.Rendering;
 using Microsoft.AspNetCore.Components.Web.HtmlRendering;
 using Microsoft.AspNetCore.Html;
 using Microsoft.AspNetCore.Http;
@@ -33,19 +34,39 @@ internal partial class EndpointHtmlRenderer
         }
     }
 
+    protected override IComponentRenderMode? GetComponentRenderMode(IComponent component)
+    {
+        var componentState = GetComponentState(component);
+        var ssrRenderBoundary = GetClosestRenderModeBoundary(componentState);
+
+        if (ssrRenderBoundary is null)
+        {
+            return null;
+        }
+
+        return ssrRenderBoundary.RenderMode;
+    }
+
     private SSRRenderModeBoundary? GetClosestRenderModeBoundary(int componentId)
     {
         var componentState = GetComponentState(componentId);
+        return GetClosestRenderModeBoundary(componentState);
+    }
+
+    private static SSRRenderModeBoundary? GetClosestRenderModeBoundary(ComponentState componentState)
+    {
+        var currentComponentState = componentState;
+
         do
         {
-            if (componentState.Component is SSRRenderModeBoundary boundary)
+            if (currentComponentState.Component is SSRRenderModeBoundary boundary)
             {
                 return boundary;
             }
 
-            componentState = componentState.ParentComponentState;
+            currentComponentState = currentComponentState.ParentComponentState;
         }
-        while (componentState is not null);
+        while (currentComponentState is not null);
 
         return null;
     }

+ 173 - 23
src/Components/Endpoints/src/Rendering/EndpointHtmlRenderer.PrerenderingState.cs

@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Collections;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Components.Infrastructure;
 using Microsoft.AspNetCore.Components.Web;
@@ -15,6 +16,98 @@ internal partial class EndpointHtmlRenderer
 {
     private static readonly object InvokedRenderModesKey = new object();
 
+    public async ValueTask<IHtmlContent> PrerenderPersistedStateAsync(HttpContext httpContext)
+    {
+        SetHttpContext(httpContext);
+
+        var manager = _httpContext.RequestServices.GetRequiredService<ComponentStatePersistenceManager>();
+
+        var renderModesMetadata = httpContext.GetEndpoint()?.Metadata.GetMetadata<ConfiguredRenderModesMetadata>();
+
+        IPersistentComponentStateStore? store = null;
+
+        // There is configured render modes metadata, use this to determine where to persist state if possible
+        if (renderModesMetadata != null)
+        {
+            // No render modes are configured, do not persist state
+            if (renderModesMetadata.ConfiguredRenderModes.Length == 0)
+            {
+                return ComponentStateHtmlContent.Empty;
+            }
+
+            // Single render mode, no need to perform inference. Any component that tried to render an
+            // incompatible render mode would have failed at this point.
+            if (renderModesMetadata.ConfiguredRenderModes.Length == 1)
+            {
+                store = renderModesMetadata.ConfiguredRenderModes[0] switch
+                {
+                    InteractiveServerRenderMode => new ProtectedPrerenderComponentApplicationStore(_httpContext.RequestServices.GetRequiredService<IDataProtectionProvider>()),
+                    InteractiveWebAssemblyRenderMode => new PrerenderComponentApplicationStore(),
+                    _ => throw new InvalidOperationException("Invalid configured render mode."),
+                };
+            }
+        }
+
+        if (store != null)
+        {
+            await manager.PersistStateAsync(store, this);
+            return store switch
+            {
+                ProtectedPrerenderComponentApplicationStore protectedStore => new ComponentStateHtmlContent(protectedStore, null),
+                PrerenderComponentApplicationStore prerenderStore => new ComponentStateHtmlContent(null, prerenderStore),
+                _ => throw new InvalidOperationException("Invalid store."),
+            };
+        }
+        else
+        {
+            // We were not able to resolve a store from the configured render modes metadata, we need to capture
+            // all possible destinations for the state and persist it in all of them.
+            var serverStore = new ProtectedPrerenderComponentApplicationStore(_httpContext.RequestServices.GetRequiredService<IDataProtectionProvider>());
+            var webAssemblyStore = new PrerenderComponentApplicationStore();
+
+            // The persistence state manager checks if the store implements
+            // IEnumerable<IPersistentComponentStateStore> and if so, it invokes PersistStateAsync on each store
+            // for each of the render mode callbacks defined.
+            // We pass in a composite store with fake stores for each render mode that only take care of
+            // creating a copy of the state for each render mode.
+            // Then, we copy the state from the auto store to the server and webassembly stores and persist
+            // the real state for server and webassembly render modes.
+            // This makes sure that:
+            // 1. The persistence state manager is agnostic to the render modes.
+            // 2. The callbacks are run only once, even if the state ends up persisted in multiple locations.
+            var server = new CopyOnlyStore<InteractiveServerRenderMode>();
+            var auto = new CopyOnlyStore<InteractiveAutoRenderMode>();
+            var webAssembly = new CopyOnlyStore<InteractiveWebAssemblyRenderMode>();
+            store = new CompositeStore(server, auto, webAssembly);
+
+            await manager.PersistStateAsync(store, this);
+
+            foreach (var kvp in auto.Saved)
+            {
+                server.Saved.Add(kvp.Key, kvp.Value);
+                webAssembly.Saved.Add(kvp.Key, kvp.Value);
+            }
+
+            // Persist state only if there is state to persist
+            var saveServerTask = server.Saved.Count > 0
+                ? serverStore.PersistStateAsync(server.Saved)
+                : Task.CompletedTask;
+
+            var saveWebAssemblyTask = webAssembly.Saved.Count > 0
+                ? webAssemblyStore.PersistStateAsync(webAssembly.Saved)
+                : Task.CompletedTask;
+
+            await Task.WhenAll(
+                saveServerTask,
+                saveWebAssemblyTask);
+
+            // Do not return any HTML content if there is no state to persist for a given mode.
+            return new ComponentStateHtmlContent(
+                server.Saved.Count > 0 ? serverStore : null,
+                webAssembly.Saved.Count > 0 ? webAssemblyStore : null);
+        }
+    }
+
     public async ValueTask<IHtmlContent> PrerenderPersistedStateAsync(HttpContext httpContext, PersistedStateSerializationMode serializationMode)
     {
         SetHttpContext(httpContext);
@@ -40,21 +133,25 @@ internal partial class EndpointHtmlRenderer
             }
         }
 
+        var manager = _httpContext.RequestServices.GetRequiredService<ComponentStatePersistenceManager>();
+
         // Now given the mode, we obtain a particular store for that mode
-        var store = serializationMode switch
+        // and persist the state and return the HTML content
+        switch (serializationMode)
         {
-            PersistedStateSerializationMode.Server =>
-                new ProtectedPrerenderComponentApplicationStore(_httpContext.RequestServices.GetRequiredService<IDataProtectionProvider>()),
-            PersistedStateSerializationMode.WebAssembly =>
-                new PrerenderComponentApplicationStore(),
-            _ =>
-                throw new InvalidOperationException("Invalid persistence mode.")
-        };
-
-        // Finally, persist the state and return the HTML content
-        var manager = _httpContext.RequestServices.GetRequiredService<ComponentStatePersistenceManager>();
-        await manager.PersistStateAsync(store, Dispatcher);
-        return new ComponentStateHtmlContent(store);
+            case PersistedStateSerializationMode.Server:
+                var protectedStore = new ProtectedPrerenderComponentApplicationStore(_httpContext.RequestServices.GetRequiredService<IDataProtectionProvider>());
+                await manager.PersistStateAsync(protectedStore, this);
+                return new ComponentStateHtmlContent(protectedStore, null);
+
+            case PersistedStateSerializationMode.WebAssembly:
+                var store = new PrerenderComponentApplicationStore();
+                await manager.PersistStateAsync(store, this);
+                return new ComponentStateHtmlContent(null, store);
+
+            default:
+                throw new InvalidOperationException("Invalid persistence mode.");
+        }
     }
 
     // Internal for test only
@@ -101,27 +198,80 @@ internal partial class EndpointHtmlRenderer
             : InvokedRenderModes.Mode.None;
     }
 
-    private sealed class ComponentStateHtmlContent : IHtmlContent
+    internal sealed class ComponentStateHtmlContent : IHtmlContent
     {
-        private PrerenderComponentApplicationStore? _store;
+        public static ComponentStateHtmlContent Empty { get; } = new(null, null);
+
+        internal PrerenderComponentApplicationStore? ServerStore { get; }
 
-        public static ComponentStateHtmlContent Empty { get; }
-            = new ComponentStateHtmlContent(null);
+        internal PrerenderComponentApplicationStore? WebAssemblyStore { get; }
 
-        public ComponentStateHtmlContent(PrerenderComponentApplicationStore? store)
+        public ComponentStateHtmlContent(PrerenderComponentApplicationStore? serverStore, PrerenderComponentApplicationStore? webAssemblyStore)
         {
-            _store = store;
+            WebAssemblyStore = webAssemblyStore;
+            ServerStore = serverStore;
         }
 
         public void WriteTo(TextWriter writer, HtmlEncoder encoder)
         {
-            if (_store != null)
+            if (ServerStore is not null && ServerStore.PersistedState is not null)
             {
-                writer.Write("<!--Blazor-Component-State:");
-                writer.Write(_store.PersistedState);
+                writer.Write("<!--Blazor-Server-Component-State:");
+                writer.Write(ServerStore.PersistedState);
                 writer.Write("-->");
-                _store = null;
             }
+
+            if (WebAssemblyStore is not null && WebAssemblyStore.PersistedState is not null)
+            {
+                writer.Write("<!--Blazor-WebAssembly-Component-State:");
+                writer.Write(WebAssemblyStore.PersistedState);
+                writer.Write("-->");
+            }
+        }
+    }
+
+    internal class CompositeStore : IPersistentComponentStateStore, IEnumerable<IPersistentComponentStateStore>
+    {
+        public CompositeStore(
+            CopyOnlyStore<InteractiveServerRenderMode> server,
+            CopyOnlyStore<InteractiveAutoRenderMode> auto,
+            CopyOnlyStore<InteractiveWebAssemblyRenderMode> webassembly)
+        {
+            Server = server;
+            Auto = auto;
+            Webassembly = webassembly;
+        }
+
+        public CopyOnlyStore<InteractiveServerRenderMode> Server { get; }
+        public CopyOnlyStore<InteractiveAutoRenderMode> Auto { get; }
+        public CopyOnlyStore<InteractiveWebAssemblyRenderMode> Webassembly { get; }
+
+        public IEnumerator<IPersistentComponentStateStore> GetEnumerator()
+        {
+            yield return Server;
+            yield return Auto;
+            yield return Webassembly;
+        }
+
+        public Task<IDictionary<string, byte[]>> GetPersistedStateAsync() => throw new NotImplementedException();
+
+        public Task PersistStateAsync(IReadOnlyDictionary<string, byte[]> state) => Task.CompletedTask;
+
+        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+    }
+
+    internal class CopyOnlyStore<T> : IPersistentComponentStateStore where T : IComponentRenderMode
+    {
+        public Dictionary<string, byte[]> Saved { get; private set; } = new();
+
+        public Task<IDictionary<string, byte[]>> GetPersistedStateAsync() => throw new NotImplementedException();
+
+        public Task PersistStateAsync(IReadOnlyDictionary<string, byte[]> state)
+        {
+            Saved = new Dictionary<string, byte[]>(state);
+            return Task.CompletedTask;
         }
+
+        public bool SupportsRenderMode(IComponentRenderMode renderMode) => renderMode is T;
     }
 }

+ 10 - 9
src/Components/Endpoints/src/Rendering/SSRRenderModeBoundary.cs

@@ -26,12 +26,13 @@ internal class SSRRenderModeBoundary : IComponent
 
     [DynamicallyAccessedMembers(Component)]
     private readonly Type _componentType;
-    private readonly IComponentRenderMode _renderMode;
     private readonly bool _prerender;
     private RenderHandle _renderHandle;
     private IReadOnlyDictionary<string, object?>? _latestParameters;
     private string? _markerKey;
 
+    public IComponentRenderMode RenderMode { get; }
+
     public SSRRenderModeBoundary(
         HttpContext httpContext,
         [DynamicallyAccessedMembers(Component)] Type componentType,
@@ -40,7 +41,7 @@ internal class SSRRenderModeBoundary : IComponent
         AssertRenderModeIsConfigured(httpContext, componentType, renderMode);
 
         _componentType = componentType;
-        _renderMode = renderMode;
+        RenderMode = renderMode;
         _prerender = renderMode switch
         {
             InteractiveServerRenderMode mode => mode.Prerender,
@@ -76,7 +77,7 @@ internal class SSRRenderModeBoundary : IComponent
         }
     }
 
-    private static void AssertRenderModeIsConfigured<TRequiredMode>(Type componentType, IComponentRenderMode specifiedMode, IComponentRenderMode[] configuredModes, string expectedCall) where TRequiredMode: IComponentRenderMode
+    private static void AssertRenderModeIsConfigured<TRequiredMode>(Type componentType, IComponentRenderMode specifiedMode, IComponentRenderMode[] configuredModes, string expectedCall) where TRequiredMode : IComponentRenderMode
     {
         foreach (var configuredMode in configuredModes)
         {
@@ -126,7 +127,7 @@ internal class SSRRenderModeBoundary : IComponent
                 var valueType = value.GetType();
                 if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(RenderFragment<>))
                 {
-                    throw new InvalidOperationException($"Cannot pass RenderFragment<T> parameter '{name}' to component '{_componentType.Name}' with rendermode '{_renderMode.GetType().Name}'. Templated content can't be passed across a rendermode boundary, because it is arbitrary code and cannot be serialized.");
+                    throw new InvalidOperationException($"Cannot pass RenderFragment<T> parameter '{name}' to component '{_componentType.Name}' with rendermode '{RenderMode.GetType().Name}'. Templated content can't be passed across a rendermode boundary, because it is arbitrary code and cannot be serialized.");
                 }
                 else
                 {
@@ -135,7 +136,7 @@ internal class SSRRenderModeBoundary : IComponent
                     // somehow without actually emitting its result directly, wait for quiescence, and then prerender
                     // the output into a separate buffer so we can serialize it in a special way.
                     // A prototype implementation is at https://github.com/dotnet/aspnetcore/commit/ed330ff5b143974d9060828a760ad486b1d386ac
-                    throw new InvalidOperationException($"Cannot pass the parameter '{name}' to component '{_componentType.Name}' with rendermode '{_renderMode.GetType().Name}'. This is because the parameter is of the delegate type '{value.GetType()}', which is arbitrary code and cannot be serialized.");
+                    throw new InvalidOperationException($"Cannot pass the parameter '{name}' to component '{_componentType.Name}' with rendermode '{RenderMode.GetType().Name}'. This is because the parameter is of the delegate type '{value.GetType()}', which is arbitrary code and cannot be serialized.");
                 }
             }
         }
@@ -163,15 +164,15 @@ internal class SSRRenderModeBoundary : IComponent
             ? ParameterView.Empty
             : ParameterView.FromDictionary((IDictionary<string, object?>)_latestParameters);
 
-        var marker = _renderMode switch
+        var marker = RenderMode switch
         {
             InteractiveServerRenderMode server => ComponentMarker.Create(ComponentMarker.ServerMarkerType, server.Prerender, _markerKey),
             InteractiveWebAssemblyRenderMode webAssembly => ComponentMarker.Create(ComponentMarker.WebAssemblyMarkerType, webAssembly.Prerender, _markerKey),
             InteractiveAutoRenderMode auto => ComponentMarker.Create(ComponentMarker.AutoMarkerType, auto.Prerender, _markerKey),
-            _ => throw new UnreachableException($"Unknown render mode {_renderMode.GetType().FullName}"),
+            _ => throw new UnreachableException($"Unknown render mode {RenderMode.GetType().FullName}"),
         };
 
-        if (_renderMode is InteractiveServerRenderMode or InteractiveAutoRenderMode)
+        if (RenderMode is InteractiveServerRenderMode or InteractiveAutoRenderMode)
         {
             // Lazy because we don't actually want to require a whole chain of services including Data Protection
             // to be required unless you actually use Server render mode.
@@ -181,7 +182,7 @@ internal class SSRRenderModeBoundary : IComponent
             serverComponentSerializer.SerializeInvocation(ref marker, invocationId, _componentType, parameters);
         }
 
-        if (_renderMode is InteractiveWebAssemblyRenderMode or InteractiveAutoRenderMode)
+        if (RenderMode is InteractiveWebAssemblyRenderMode or InteractiveAutoRenderMode)
         {
             WebAssemblyComponentSerializer.SerializeInvocation(ref marker, _componentType, parameters);
         }

+ 360 - 1
src/Components/Endpoints/test/EndpointHtmlRendererTest.cs

@@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Components.Endpoints.Tests.TestComponents;
 using Microsoft.AspNetCore.Components.Forms;
 using Microsoft.AspNetCore.Components.Forms.Mapping;
 using Microsoft.AspNetCore.Components.Infrastructure;
+using Microsoft.AspNetCore.Components.Reflection;
 using Microsoft.AspNetCore.Components.Rendering;
 using Microsoft.AspNetCore.Components.Test.Helpers;
 using Microsoft.AspNetCore.Components.Web;
@@ -1177,6 +1178,326 @@ public class EndpointHtmlRendererTest
         Assert.Equal("<h1>This is InteractiveWithInteractiveChild</h1>\n\n<p>Hello from InteractiveGreetingServer!</p>", prerenderedContent.Replace("\r\n", "\n"));
     }
 
+    [Fact]
+    public async Task PrerenderedState_EmptyWhenNoDeclaredRenderModes()
+    {
+        var declaredRenderModesMetadata = new ConfiguredRenderModesMetadata([]);
+        var endpoint = new Endpoint((context) => Task.CompletedTask, new EndpointMetadataCollection(declaredRenderModesMetadata),
+            "TestEndpoint");
+
+        var httpContext = GetHttpContext();
+        httpContext.SetEndpoint(endpoint);
+        var content = await renderer.PrerenderPersistedStateAsync(httpContext);
+
+        Assert.Equal(EndpointHtmlRenderer.ComponentStateHtmlContent.Empty, content);
+    }
+
+    public static TheoryData<IComponentRenderMode> SingleComponentRenderModeData => new TheoryData<IComponentRenderMode>
+    {
+        RenderMode.InteractiveServer,
+        RenderMode.InteractiveWebAssembly
+    };
+
+    [Theory]
+    [MemberData(nameof(SingleComponentRenderModeData))]
+    public async Task PrerenderedState_SelectsSingleStoreCorrectly(IComponentRenderMode renderMode)
+    {
+        var declaredRenderModesMetadata = new ConfiguredRenderModesMetadata([renderMode]);
+        var endpoint = new Endpoint((context) => Task.CompletedTask, new EndpointMetadataCollection(declaredRenderModesMetadata),
+            "TestEndpoint");
+
+        var httpContext = GetHttpContext();
+        httpContext.SetEndpoint(endpoint);
+        var content = await renderer.PrerenderPersistedStateAsync(httpContext);
+
+        Assert.NotNull(content);
+        var stateContent = Assert.IsType<EndpointHtmlRenderer.ComponentStateHtmlContent>(content);
+        switch (renderMode)
+        {
+            case InteractiveServerRenderMode:
+                Assert.NotNull(stateContent.ServerStore);
+                Assert.Null(stateContent.ServerStore.PersistedState);
+                Assert.Null(stateContent.WebAssemblyStore);
+                break;
+            case InteractiveWebAssemblyRenderMode:
+                Assert.NotNull(stateContent.WebAssemblyStore);
+                Assert.Null(stateContent.WebAssemblyStore.PersistedState);
+                Assert.Null(stateContent.ServerStore);
+                break;
+            default:
+                throw new InvalidOperationException($"Unexpected render mode: {renderMode}");
+        }
+    }
+
+    [Fact]
+    public async Task PrerenderedState_MultipleStoresCorrectly()
+    {
+        var declaredRenderModesMetadata = new ConfiguredRenderModesMetadata([RenderMode.InteractiveServer, RenderMode.InteractiveWebAssembly]);
+        var endpoint = new Endpoint((context) => Task.CompletedTask, new EndpointMetadataCollection(declaredRenderModesMetadata),
+            "TestEndpoint");
+
+        var httpContext = GetHttpContext();
+        httpContext.SetEndpoint(endpoint);
+        var content = await renderer.PrerenderPersistedStateAsync(httpContext);
+
+        Assert.NotNull(content);
+        var stateContent = Assert.IsType<EndpointHtmlRenderer.ComponentStateHtmlContent>(content);
+        Assert.Null(stateContent.ServerStore);
+        Assert.Null(stateContent.WebAssemblyStore);
+    }
+
+    [Theory]
+    [InlineData("server")]
+    [InlineData("wasm")]
+    [InlineData("auto")]
+    public async Task PrerenderedState_PersistToStores_OnlyWhenContentIsAvailable(string renderMode)
+    {
+        IComponentRenderMode persistenceMode = renderMode switch
+        {
+            "server" => RenderMode.InteractiveServer,
+            "wasm" => RenderMode.InteractiveWebAssembly,
+            "auto" => RenderMode.InteractiveAuto,
+            _ => throw new InvalidOperationException($"Unexpected render mode: {renderMode}"),
+        };
+
+        var declaredRenderModesMetadata = new ConfiguredRenderModesMetadata([RenderMode.InteractiveServer, RenderMode.InteractiveWebAssembly]);
+        var endpoint = new Endpoint((context) => Task.CompletedTask, new EndpointMetadataCollection(declaredRenderModesMetadata),
+            "TestEndpoint");
+
+        var httpContext = GetHttpContext();
+        httpContext.SetEndpoint(endpoint);
+        var state = httpContext.RequestServices.GetRequiredService<PersistentComponentState>();
+
+        state.RegisterOnPersisting(() =>
+        {
+            state.PersistAsJson(renderMode, "persisted");
+            return Task.CompletedTask;
+        }, persistenceMode);
+
+        var content = await renderer.PrerenderPersistedStateAsync(httpContext);
+
+        Assert.NotNull(content);
+        var stateContent = Assert.IsType<EndpointHtmlRenderer.ComponentStateHtmlContent>(content);
+        switch (persistenceMode)
+        {
+            case InteractiveServerRenderMode:
+                Assert.NotNull(stateContent.ServerStore);
+                Assert.NotNull(stateContent.ServerStore.PersistedState);
+                Assert.Null(stateContent.WebAssemblyStore);
+                break;
+            case InteractiveWebAssemblyRenderMode:
+                Assert.NotNull(stateContent.WebAssemblyStore);
+                Assert.NotNull(stateContent.WebAssemblyStore.PersistedState);
+                Assert.Null(stateContent.ServerStore);
+                break;
+            case InteractiveAutoRenderMode:
+                Assert.NotNull(stateContent.ServerStore);
+                Assert.NotNull(stateContent.ServerStore.PersistedState);
+                Assert.NotNull(stateContent.WebAssemblyStore);
+                Assert.NotNull(stateContent.WebAssemblyStore.PersistedState);
+                break;
+            default:
+                break;
+        }
+    }
+
+    [Theory]
+    [InlineData("server")]
+    [InlineData("wasm")]
+    public async Task PrerenderedState_PersistToStores_DoesNotNeedToInferRenderMode_ForSingleRenderMode(string declaredRenderMode)
+    {
+        IComponentRenderMode configuredMode = declaredRenderMode switch
+        {
+            "server" => RenderMode.InteractiveServer,
+            "wasm" => RenderMode.InteractiveWebAssembly,
+            "auto" => RenderMode.InteractiveAuto,
+            _ => throw new InvalidOperationException($"Unexpected render mode: {declaredRenderMode}"),
+        };
+
+        var declaredRenderModesMetadata = new ConfiguredRenderModesMetadata([configuredMode]);
+        var endpoint = new Endpoint((context) => Task.CompletedTask, new EndpointMetadataCollection(declaredRenderModesMetadata),
+            "TestEndpoint");
+
+        var httpContext = GetHttpContext();
+        httpContext.SetEndpoint(endpoint);
+        var state = httpContext.RequestServices.GetRequiredService<PersistentComponentState>();
+
+        state.RegisterOnPersisting(() =>
+        {
+            state.PersistAsJson("key", "persisted");
+            return Task.CompletedTask;
+        });
+
+        var content = await renderer.PrerenderPersistedStateAsync(httpContext);
+
+        Assert.NotNull(content);
+        var stateContent = Assert.IsType<EndpointHtmlRenderer.ComponentStateHtmlContent>(content);
+        switch (configuredMode)
+        {
+            case InteractiveServerRenderMode:
+                Assert.NotNull(stateContent.ServerStore);
+                Assert.NotNull(stateContent.ServerStore.PersistedState);
+                Assert.Null(stateContent.WebAssemblyStore);
+                break;
+            case InteractiveWebAssemblyRenderMode:
+                Assert.NotNull(stateContent.WebAssemblyStore);
+                Assert.NotNull(stateContent.WebAssemblyStore.PersistedState);
+                Assert.Null(stateContent.ServerStore);
+                break;
+            default:
+                break;
+        }
+    }
+
+    [Fact]
+    public async Task PrerenderedState_Throws_WhenItCanInfer_CallbackRenderMode_ForMultipleRenderModes()
+    {
+        var declaredRenderModesMetadata = new ConfiguredRenderModesMetadata([RenderMode.InteractiveServer, RenderMode.InteractiveWebAssembly]);
+        var endpoint = new Endpoint((context) => Task.CompletedTask, new EndpointMetadataCollection(declaredRenderModesMetadata),
+            "TestEndpoint");
+
+        var httpContext = GetHttpContext();
+        httpContext.SetEndpoint(endpoint);
+        var state = httpContext.RequestServices.GetRequiredService<PersistentComponentState>();
+
+        state.RegisterOnPersisting(() =>
+        {
+            state.PersistAsJson("key", "persisted");
+            return Task.CompletedTask;
+        });
+
+        await Assert.ThrowsAsync<InvalidOperationException>(async () => await renderer.PrerenderPersistedStateAsync(httpContext));
+    }
+
+    [Theory]
+    [InlineData("server")]
+    [InlineData("auto")]
+    [InlineData("wasm")]
+    public async Task PrerenderedState_InfersCallbackRenderMode_ForMultipleRenderModes(string renderMode)
+    {
+        IComponentRenderMode persistenceMode = renderMode switch
+        {
+            "server" => RenderMode.InteractiveServer,
+            "wasm" => RenderMode.InteractiveWebAssembly,
+            "auto" => RenderMode.InteractiveAuto,
+            _ => throw new InvalidOperationException($"Unexpected render mode: {renderMode}"),
+        };
+        var declaredRenderModesMetadata = new ConfiguredRenderModesMetadata([RenderMode.InteractiveServer, RenderMode.InteractiveWebAssembly]);
+        var endpoint = new Endpoint((context) => Task.CompletedTask, new EndpointMetadataCollection(declaredRenderModesMetadata),
+            "TestEndpoint");
+
+        var httpContext = GetHttpContext();
+        httpContext.SetEndpoint(endpoint);
+        var state = httpContext.RequestServices.GetRequiredService<PersistentComponentState>();
+
+        var ssrBoundary = new SSRRenderModeBoundary(httpContext, typeof(PersistenceComponent), persistenceMode);
+        var id = renderer.AssignRootComponentId(ssrBoundary);
+
+        await renderer.Dispatcher.InvokeAsync(() => renderer.RenderRootComponentAsync(id, ParameterView.Empty));
+
+        var content = await renderer.PrerenderPersistedStateAsync(httpContext);
+        Assert.NotNull(content);
+        var stateContent = Assert.IsType<EndpointHtmlRenderer.ComponentStateHtmlContent>(content);
+        switch (persistenceMode)
+        {
+            case InteractiveServerRenderMode:
+                Assert.NotNull(stateContent.ServerStore);
+                Assert.NotNull(stateContent.ServerStore.PersistedState);
+                Assert.Null(stateContent.WebAssemblyStore);
+                break;
+            case InteractiveWebAssemblyRenderMode:
+                Assert.NotNull(stateContent.WebAssemblyStore);
+                Assert.NotNull(stateContent.WebAssemblyStore.PersistedState);
+                Assert.Null(stateContent.ServerStore);
+                break;
+            case InteractiveAutoRenderMode:
+                Assert.NotNull(stateContent.ServerStore);
+                Assert.NotNull(stateContent.ServerStore.PersistedState);
+                Assert.NotNull(stateContent.WebAssemblyStore);
+                Assert.NotNull(stateContent.WebAssemblyStore.PersistedState);
+                break;
+            default:
+                break;
+        }
+    }
+
+    [Theory]
+    [InlineData("server", "server", true)]
+    [InlineData("auto", "server", true)]
+    [InlineData("auto", "wasm", true)]
+    [InlineData("wasm", "wasm", true)]
+    // Note that when an incompatible explicit render mode is specified we don't serialize the data.
+    [InlineData("server", "wasm", false)]
+    [InlineData("wasm", "server", false)]
+    public async Task PrerenderedState_ExplicitRenderModes_AreRespected(string renderMode, string declared, bool persisted)
+    {
+        IComponentRenderMode persistenceMode = renderMode switch
+        {
+            "server" => RenderMode.InteractiveServer,
+            "wasm" => RenderMode.InteractiveWebAssembly,
+            "auto" => RenderMode.InteractiveAuto,
+            _ => throw new InvalidOperationException($"Unexpected render mode: {renderMode}"),
+        };
+
+        IComponentRenderMode configuredMode = declared switch
+        {
+            "server" => RenderMode.InteractiveServer,
+            "wasm" => RenderMode.InteractiveWebAssembly,
+            "auto" => RenderMode.InteractiveAuto,
+            _ => throw new InvalidOperationException($"Unexpected render mode: {declared}"),
+        };
+
+        var declaredRenderModesMetadata = new ConfiguredRenderModesMetadata([configuredMode]);
+        var endpoint = new Endpoint((context) => Task.CompletedTask, new EndpointMetadataCollection(declaredRenderModesMetadata),
+            "TestEndpoint");
+
+        var httpContext = GetHttpContext();
+        httpContext.SetEndpoint(endpoint);
+        var state = httpContext.RequestServices.GetRequiredService<PersistentComponentState>();
+
+        var ssrBoundary = new SSRRenderModeBoundary(httpContext, typeof(PersistenceComponent), configuredMode);
+        var id = renderer.AssignRootComponentId(ssrBoundary);
+        await renderer.Dispatcher.InvokeAsync(() => renderer.RenderRootComponentAsync(
+            id,
+            ParameterView.FromDictionary(new Dictionary<string, object>
+            {
+                ["Mode"] = renderMode,
+            })));
+
+        var content = await renderer.PrerenderPersistedStateAsync(httpContext);
+        Assert.NotNull(content);
+        var stateContent = Assert.IsType<EndpointHtmlRenderer.ComponentStateHtmlContent>(content);
+        switch (configuredMode)
+        {
+            case InteractiveServerRenderMode:
+                if (persisted)
+                {
+                    Assert.NotNull(stateContent.ServerStore);
+                    Assert.NotNull(stateContent.ServerStore.PersistedState);
+                }
+                else
+                {
+                    Assert.Null(stateContent.ServerStore.PersistedState);
+                }
+                Assert.Null(stateContent.WebAssemblyStore);
+                break;
+            case InteractiveWebAssemblyRenderMode:
+                if (persisted)
+                {
+                    Assert.NotNull(stateContent.WebAssemblyStore);
+                    Assert.NotNull(stateContent.WebAssemblyStore.PersistedState);
+                }
+                else
+                {
+                    Assert.Null(stateContent.WebAssemblyStore.PersistedState);
+                }
+                Assert.Null(stateContent.ServerStore);
+                break;
+            default:
+                break;
+        }
+    }
+
     private class NamedEventHandlerComponent : ComponentBase
     {
         [Parameter]
@@ -1230,7 +1551,7 @@ public class EndpointHtmlRendererTest
             builder.OpenElement(0, "form");
             builder.AddAttribute(1, "onsubmit", !hasRendered
                 ? () => { Message = "Received call to original handler"; }
-                : () => { Message = "Received call to updated handler"; });
+            : () => { Message = "Received call to updated handler"; });
             builder.AddNamedEvent("onsubmit", "default");
             builder.CloseElement();
         }
@@ -1266,6 +1587,44 @@ public class EndpointHtmlRendererTest
             => _renderFragment(builder);
     }
 
+    class PersistenceComponent : IComponent
+    {
+        [Inject] public PersistentComponentState State { get; set; }
+
+        [Parameter] public string Mode { get; set; }
+
+        private Task PersistState()
+        {
+            State.PersistAsJson("key", "value");
+            return Task.CompletedTask;
+        }
+
+        public void Attach(RenderHandle renderHandle)
+        {
+        }
+
+        public Task SetParametersAsync(ParameterView parameters)
+        {
+            ComponentProperties.SetProperties(parameters, this);
+            switch (Mode)
+            {
+                case "server":
+                    State.RegisterOnPersisting(PersistState, RenderMode.InteractiveServer);
+                    break;
+                case "wasm":
+                    State.RegisterOnPersisting(PersistState, RenderMode.InteractiveWebAssembly);
+                    break;
+                case "auto":
+                    State.RegisterOnPersisting(PersistState, RenderMode.InteractiveAuto);
+                    break;
+                default:
+                    State.RegisterOnPersisting(PersistState);
+                    break;
+            }
+            return Task.CompletedTask;
+        }
+    }
+
     private static string HtmlContentToString(IHtmlAsyncContent result)
     {
         var writer = new StringWriter();

+ 13 - 3
src/Components/Server/src/Circuits/CircuitFactory.cs

@@ -61,8 +61,14 @@ internal sealed partial class CircuitFactory : ICircuitFactory
             navigationManager.Initialize(baseUri, uri);
         }
 
-        var appLifetime = scope.ServiceProvider.GetRequiredService<ComponentStatePersistenceManager>();
-        await appLifetime.RestoreStateAsync(store);
+        if (components.Count > 0)
+        {
+            // Skip initializing the state if there are no components.
+            // This is the case on Blazor Web scenarios, which will initialize the state
+            // when the first set of components is provided via an UpdateRootComponents call.
+            var appLifetime = scope.ServiceProvider.GetRequiredService<ComponentStatePersistenceManager>();
+            await appLifetime.RestoreStateAsync(store);
+        }
 
         var serverComponentDeserializer = scope.ServiceProvider.GetRequiredService<IServerComponentDeserializer>();
         var jsComponentInterop = new CircuitJSComponentInterop(_options);
@@ -76,7 +82,11 @@ internal sealed partial class CircuitFactory : ICircuitFactory
             jsRuntime,
             jsComponentInterop);
 
-        var circuitHandlers = scope.ServiceProvider.GetServices<CircuitHandler>()
+        // In Blazor Server we have already restored the app state, so we can get the handlers from DI.
+        // In Blazor Web the state is provided in the first call to UpdateRootComponents, so we need to
+        // delay creating the handlers until then. Otherwise, a handler would be able to access the state
+        // in the constructor for Blazor Server, but not in Blazor Web.
+        var circuitHandlers = components.Count == 0 ? [] : scope.ServiceProvider.GetServices<CircuitHandler>()
             .OrderBy(h => h.Order)
             .ToArray();
 

+ 146 - 4
src/Components/Server/src/Circuits/CircuitHost.cs

@@ -2,8 +2,10 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Globalization;
+using System.Linq;
 using System.Security.Claims;
 using Microsoft.AspNetCore.Components.Authorization;
+using Microsoft.AspNetCore.Components.Infrastructure;
 using Microsoft.AspNetCore.SignalR;
 using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Logging;
@@ -18,11 +20,12 @@ internal partial class CircuitHost : IAsyncDisposable
 {
     private readonly AsyncServiceScope _scope;
     private readonly CircuitOptions _options;
-    private readonly CircuitHandler[] _circuitHandlers;
     private readonly RemoteNavigationManager _navigationManager;
     private readonly ILogger _logger;
     private readonly Func<Func<Task>, Task> _dispatchInboundActivity;
+    private CircuitHandler[] _circuitHandlers;
     private bool _initialized;
+    private bool _isFirstUpdate = true;
     private bool _disposed;
 
     // This event is fired when there's an unrecoverable exception coming from the circuit, and
@@ -111,8 +114,15 @@ internal partial class CircuitHost : IAsyncDisposable
             {
                 _initialized = true; // We're ready to accept incoming JSInterop calls from here on
 
-                await OnCircuitOpenedAsync(cancellationToken);
-                await OnConnectionUpAsync(cancellationToken);
+                // We only run the handlers in case we are in a Blazor Server scenario, which renders
+                // the components inmediately during start.
+                // On Blazor Web scenarios we delay running these handlers until the first UpdateRootComponents call
+                // We do this so that the handlers can have access to the restored application state.
+                if (Descriptors.Count > 0)
+                {
+                    await OnCircuitOpenedAsync(cancellationToken);
+                    await OnConnectionUpAsync(cancellationToken);
+                }
 
                 // Here, we add each root component but don't await the returned tasks so that the
                 // components can be processed in parallel.
@@ -130,7 +140,20 @@ internal partial class CircuitHost : IAsyncDisposable
                 // At this point all components have successfully produced an initial render and we can clear the contents of the component
                 // application state store. This ensures the memory that was not used during the initial render of these components gets
                 // reclaimed since no-one else is holding on to it any longer.
-                store.ExistingState.Clear();
+                // This is also important because otherwise components will keep reusing the existing state after
+                // the initial render instead of initializing their state from the original sources like the Db or a
+                // web service, preventing UI updates.
+                if (Descriptors.Count > 0)
+                {
+                    store.ExistingState.Clear();
+                }
+
+                // This variable is used to track that this is the first time we are updating components.
+                // In Blazor Web scenarios the app will send an initial empty list of descriptors,
+                // so we want to make sure that we allow setting up the state in that case.
+                // In Blazor Server the initial set of descriptors is provided via the call to Start, so
+                // we want to make sure we don't take any state afterwards.
+                _isFirstUpdate = Descriptors.Count == 0;
 
                 Log.InitializationSucceeded(_logger);
             }
@@ -702,6 +725,113 @@ internal partial class CircuitHost : IAsyncDisposable
         }
     }
 
+    internal Task UpdateRootComponents(
+        (RootComponentOperation, ComponentDescriptor?)[] operations,
+        ProtectedPrerenderComponentApplicationStore store,
+        IServerComponentDeserializer serverComponentDeserializer,
+        CancellationToken cancellation)
+    {
+        Log.UpdateRootComponentsStarted(_logger);
+
+        return Renderer.Dispatcher.InvokeAsync(async () =>
+        {
+            var shouldClearStore = false;
+            Task[]? pendingTasks = null;
+            try
+            {
+                if (Descriptors.Count > 0)
+                {
+                    // Block updating components if they were provided during StartCircuit. This keeps
+                    // the footprint for Blazor Server closer to what it was before.
+                    throw new InvalidOperationException("UpdateRootComponents is not supported when components have" +
+                        " been provided during circuit start up.");
+                }
+                if (_isFirstUpdate)
+                {
+                    _isFirstUpdate = false;
+                    if (store != null)
+                    {
+                        shouldClearStore = true;
+                        // We only do this if we have no root components. Otherwise, the state would have been
+                        // provided during the start up process
+                        var appLifetime = _scope.ServiceProvider.GetRequiredService<ComponentStatePersistenceManager>();
+                        await appLifetime.RestoreStateAsync(store);
+                    }
+
+                    // Retrieve the circuit handlers at this point.
+                    _circuitHandlers = [.. _scope.ServiceProvider.GetServices<CircuitHandler>().OrderBy(h => h.Order)];
+                    await OnCircuitOpenedAsync(cancellation);
+                    await OnConnectionUpAsync(cancellation);
+
+                    for (var i = 0; i < operations.Length; i++)
+                    {
+                        var operation = operations[i];
+                        if (operation.Item1.Type != RootComponentOperationType.Add)
+                        {
+                            throw new InvalidOperationException($"The first set of update operations must always be of type {nameof(RootComponentOperationType.Add)}");
+                        }
+                    }
+
+                    pendingTasks = new Task[operations.Length];
+                }
+
+                for (var i = 0; i < operations.Length;i++)
+                {
+                    var (operation, descriptor) = operations[i];
+                    switch (operation.Type)
+                    {
+                        case RootComponentOperationType.Add:
+                            var task = Renderer.AddComponentAsync(descriptor.ComponentType, descriptor.Parameters, operation.SelectorId.Value.ToString(CultureInfo.InvariantCulture));
+                            if (pendingTasks != null)
+                            {
+                                pendingTasks[i] = task;
+                            }
+                            break;
+                        case RootComponentOperationType.Update:
+                            var componentType = Renderer.GetExistingComponentType(operation.ComponentId.Value);
+                            if (descriptor.ComponentType != componentType)
+                            {
+                                Log.InvalidComponentTypeForUpdate(_logger, message: "Component type mismatch.");
+                                throw new InvalidOperationException($"Incorrect type for descriptor '{descriptor.ComponentType.FullName}'");
+                            }
+
+                            // We don't need to await component updates as any unhandled exception will be reported and terminate the circuit.
+                            _ = Renderer.UpdateRootComponentAsync(operation.ComponentId.Value, descriptor.Parameters);
+
+                            break;
+                        case RootComponentOperationType.Remove:
+                            Renderer.RemoveExistingRootComponent(operation.ComponentId.Value);
+                            break;
+                    }
+                }
+
+                if (pendingTasks != null)
+                {
+                    await Task.WhenAll(pendingTasks);
+                }
+
+                Log.UpdateRootComponentsSucceeded(_logger);
+            }
+            catch (Exception ex)
+            {
+                // Report errors asynchronously. UpdateRootComponents is designed not to throw.
+                Log.UpdateRootComponentsFailed(_logger, ex);
+                UnhandledException?.Invoke(this, new UnhandledExceptionEventArgs(ex, isTerminating: false));
+                await TryNotifyClientErrorAsync(Client, GetClientErrorMessage(ex), ex);
+            }
+            finally
+            {
+                if (shouldClearStore)
+                {
+                    // At this point all components have successfully produced an initial render and we can clear the contents of the component
+                    // application state store. This ensures the memory that was not used during the initial render of these components gets
+                    // reclaimed since no-one else is holding on to it any longer.
+                    store.ExistingState.Clear();
+                }
+            }
+        });
+    }
+
     private static partial class Log
     {
         // 100s used for lifecycle stuff
@@ -740,6 +870,15 @@ internal partial class CircuitHost : IAsyncDisposable
         [LoggerMessage(110, LogLevel.Error, "Unhandled error invoking circuit handler type {handlerType}.{handlerMethod}: {Message}", EventName = "CircuitHandlerFailed")]
         private static partial void CircuitHandlerFailed(ILogger logger, Type handlerType, string handlerMethod, string message, Exception exception);
 
+        [LoggerMessage(111, LogLevel.Debug, "Update root components started.", EventName = nameof(UpdateRootComponentsStarted))]
+        public static partial void UpdateRootComponentsStarted(ILogger logger);
+
+        [LoggerMessage(112, LogLevel.Debug, "Update root components succeeded.", EventName = nameof(UpdateRootComponentsSucceeded))]
+        public static partial void UpdateRootComponentsSucceeded(ILogger logger);
+
+        [LoggerMessage(113, LogLevel.Debug, "Update root components failed.", EventName = nameof(UpdateRootComponentsFailed))]
+        public static partial void UpdateRootComponentsFailed(ILogger logger, Exception exception);
+
         public static void CircuitHandlerFailed(ILogger logger, CircuitHandler handler, string handlerMethod, Exception exception)
         {
             CircuitHandlerFailed(
@@ -765,6 +904,9 @@ internal partial class CircuitHost : IAsyncDisposable
         [LoggerMessage(115, LogLevel.Debug, "An exception occurred on the circuit host '{CircuitId}' while the client is disconnected.", EventName = "UnhandledExceptionClientDisconnected")]
         public static partial void UnhandledExceptionClientDisconnected(ILogger logger, CircuitId circuitId, Exception exception);
 
+        [LoggerMessage(116, LogLevel.Debug, "The root component operation of type 'Update' was invalid: {Message}", EventName = nameof(InvalidComponentTypeForUpdate))]
+        public static partial void InvalidComponentTypeForUpdate(ILogger logger, string message);
+
         [LoggerMessage(200, LogLevel.Debug, "Failed to parse the event data when trying to dispatch an event.", EventName = "DispatchEventFailedToParseEventData")]
         public static partial void DispatchEventFailedToParseEventData(ILogger logger, Exception ex);
 

+ 1 - 1
src/Components/Server/src/Circuits/IServerComponentDeserializer.cs

@@ -10,6 +10,6 @@ internal interface IServerComponentDeserializer
     bool TryDeserializeComponentDescriptorCollection(
         string serializedComponentRecords,
         out List<ComponentDescriptor> descriptors);
-
     bool TryDeserializeSingleComponentDescriptor(ComponentMarker record, [NotNullWhen(true)] out ComponentDescriptor? result);
+    bool TryDeserializeRootComponentOperations(string serializedComponentOperations, out (RootComponentOperation, ComponentDescriptor?)[] operationsWithDescriptors);
 }

+ 6 - 90
src/Components/Server/src/Circuits/RemoteRenderer.cs

@@ -3,9 +3,7 @@
 
 using System.Collections.Concurrent;
 using System.Diagnostics.CodeAnalysis;
-using System.Globalization;
 using System.Linq;
-using System.Text.Json;
 using Microsoft.AspNetCore.Components.RenderTree;
 using Microsoft.AspNetCore.Components.Web;
 using Microsoft.AspNetCore.SignalR;
@@ -72,93 +70,14 @@ internal partial class RemoteRenderer : WebRenderer
         _ = CaptureAsyncExceptions(attachComponentTask);
     }
 
-    protected override void UpdateRootComponents(string operationsJson)
-    {
-        var operations = JsonSerializer.Deserialize<IEnumerable<RootComponentOperation>>(
-            operationsJson,
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        foreach (var operation in operations)
-        {
-            switch (operation.Type)
-            {
-                case RootComponentOperationType.Add:
-                    AddRootComponent(operation);
-                    break;
-                case RootComponentOperationType.Update:
-                    UpdateRootComponent(operation);
-                    break;
-                case RootComponentOperationType.Remove:
-                    RemoveRootComponent(operation);
-                    break;
-            }
-        }
-
-        return;
-
-        void AddRootComponent(RootComponentOperation operation)
-        {
-            if (operation.SelectorId is not { } selectorId)
-            {
-                Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Missing selector ID.");
-                return;
-            }
+    internal Task UpdateRootComponentAsync(int componentId, ParameterView initialParameters) =>
+        RenderRootComponentAsync(componentId, initialParameters);
 
-            if (operation.Marker is not { } marker)
-            {
-                Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Missing marker.");
-                return;
-            }
-
-            if (!_serverComponentDeserializer.TryDeserializeSingleComponentDescriptor(marker, out var descriptor))
-            {
-                throw new InvalidOperationException("Failed to deserialize a component descriptor when adding a new root component.");
-            }
+    internal void RemoveExistingRootComponent(int componentId) =>
+        RemoveRootComponent(componentId);
 
-            _ = AddComponentAsync(descriptor.ComponentType, descriptor.Parameters, selectorId.ToString(CultureInfo.InvariantCulture));
-        }
-
-        void UpdateRootComponent(RootComponentOperation operation)
-        {
-            if (operation.ComponentId is not { } componentId)
-            {
-                Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Missing component ID.");
-                return;
-            }
-
-            if (operation.Marker is not { } marker)
-            {
-                Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Missing marker.");
-                return;
-            }
-
-            var componentState = GetComponentState(componentId);
-
-            if (!_serverComponentDeserializer.TryDeserializeSingleComponentDescriptor(marker, out var descriptor))
-            {
-                throw new InvalidOperationException("Failed to deserialize a component descriptor when updating an existing root component.");
-            }
-
-            if (descriptor.ComponentType != componentState.Component.GetType())
-            {
-                Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Component type mismatch.");
-                return;
-            }
-
-            _ = RenderRootComponentAsync(componentId, descriptor.Parameters);
-        }
-
-        void RemoveRootComponent(RootComponentOperation operation)
-        {
-            if (operation.ComponentId is not { } componentId)
-            {
-                Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Missing component ID.");
-                return;
-            }
-
-            this.RemoveRootComponent(componentId);
-        }
-    }
+    internal Type GetExistingComponentType(int componentId) =>
+        GetComponentState(componentId).Component.GetType();
 
     protected override void ProcessPendingRender()
     {
@@ -483,9 +402,6 @@ internal partial class RemoteRenderer : WebRenderer
 
         [LoggerMessage(107, LogLevel.Debug, "The queue of unacknowledged render batches is full.", EventName = "FullUnacknowledgedRenderBatchesQueue")]
         public static partial void FullUnacknowledgedRenderBatchesQueue(ILogger logger);
-
-        [LoggerMessage(108, LogLevel.Debug, "The root component operation of type '{OperationType}' was invalid: {Message}", EventName = "InvalidRootComponentOperation")]
-        public static partial void InvalidRootComponentOperation(ILogger logger, RootComponentOperationType operationType, string message);
     }
 }
 

+ 96 - 0
src/Components/Server/src/Circuits/ServerComponentDeserializer.cs

@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Buffers;
 using System.Diagnostics.CodeAnalysis;
 using System.Text;
 using System.Text.Json;
@@ -270,6 +271,95 @@ internal sealed partial class ServerComponentDeserializer : IServerComponentDese
         return (componentDescriptor, serverComponent);
     }
 
+    public bool TryDeserializeRootComponentOperations(string serializedComponentOperations, out (RootComponentOperation, ComponentDescriptor?)[] operations)
+    {
+        int[]? seenComponentIdsStorage = null;
+        try
+        {
+            var result = JsonSerializer.Deserialize<RootComponentOperation[]>(
+                serializedComponentOperations,
+                ServerComponentSerializationSettings.JsonSerializationOptions);
+
+            operations = new (RootComponentOperation, ComponentDescriptor?)[result.Length];
+
+            Span<int> seenComponentIds = result.Length <= 128
+                ? stackalloc int[result.Length]
+                : (seenComponentIdsStorage = ArrayPool<int>.Shared.Rent(result.Length)).AsSpan(0, result.Length);
+            var currentComponentIdIndex = 0;
+            for (var i = 0; i < result.Length; i++)
+            {
+                var operation = result[i];
+                if (operation.Type == RootComponentOperationType.Remove ||
+                    operation.Type == RootComponentOperationType.Update)
+                {
+                    if (operation.ComponentId == null)
+                    {
+                        Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Missing component ID.");
+                        operations = null;
+                        return false;
+                    }
+
+                    if (seenComponentIds[0..currentComponentIdIndex]
+                        .Contains(operation.ComponentId.Value))
+                    {
+                        Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Duplicate component ID.");
+                        operations = null;
+                        return false;
+                    }
+
+                    seenComponentIds[currentComponentIdIndex++] = operation.ComponentId.Value;
+                }
+
+                if (operation.Type == RootComponentOperationType.Remove)
+                {
+                    operations[i] = (operation, null);
+                    continue;
+                }
+
+                if (operation.Type == RootComponentOperationType.Add)
+                {
+                    if (operation.SelectorId is not { } selectorId)
+                    {
+                        Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Missing selector ID.");
+                        operations = null;
+                        return false;
+                    }
+                }
+
+                if (operation.Marker == null)
+                {
+                    Log.InvalidRootComponentOperation(_logger, operation.Type, message: "Missing marker.");
+                    operations = null;
+                    return false;
+                }
+
+                if (!TryDeserializeSingleComponentDescriptor(operation.Marker.Value, out var descriptor))
+                {
+                    operations = null;
+                    return false;
+                }
+
+                operations[i] = (operation, descriptor);
+            }
+
+            return true;
+
+        }
+        catch (Exception ex)
+        {
+            Log.FailedToProcessRootComponentOperations(_logger, ex);
+            operations = null;
+            return false;
+        }
+        finally
+        {
+            if (seenComponentIdsStorage != null)
+            {
+                ArrayPool<int>.Shared.Return(seenComponentIdsStorage);
+            }
+        }
+    }
+
     private static partial class Log
     {
         [LoggerMessage(1, LogLevel.Debug, "Failed to deserialize the component descriptor.", EventName = "FailedToDeserializeDescriptor")]
@@ -301,5 +391,11 @@ internal sealed partial class ServerComponentDeserializer : IServerComponentDese
 
         [LoggerMessage(10, LogLevel.Debug, "The descriptor with sequence '{sequence}' was already used for the current invocationId '{invocationId}'.", EventName = "ReusedDescriptorSequence")]
         public static partial void ReusedDescriptorSequence(ILogger<ServerComponentDeserializer> logger, int sequence, string invocationId);
+
+        [LoggerMessage(11, LogLevel.Debug, "The root component operation of type '{OperationType}' was invalid: {Message}", EventName = "InvalidRootComponentOperation")]
+        public static partial void InvalidRootComponentOperation(ILogger logger, RootComponentOperationType operationType, string message);
+
+        [LoggerMessage(12, LogLevel.Debug, "Failed to parse root component operations", EventName = nameof(FailedToProcessRootComponentOperations))]
+        public static partial void FailedToProcessRootComponentOperations(ILogger logger, Exception exception);
     }
 }

+ 27 - 0
src/Components/Server/src/ComponentHub.cs

@@ -160,6 +160,33 @@ internal sealed partial class ComponentHub : Hub
         }
     }
 
+    public async Task UpdateRootComponents(string serializedComponentOperations, string applicationState)
+    {
+        var circuitHost = await GetActiveCircuitAsync();
+        if (circuitHost == null)
+        {
+            return;
+        }
+
+        if (!_serverComponentSerializer.TryDeserializeRootComponentOperations(
+            serializedComponentOperations,
+            out var operations))
+        {
+            // There was an error, so kill the circuit.
+            await _circuitRegistry.TerminateAsync(circuitHost.CircuitId);
+            await NotifyClientError(Clients.Caller, "The list of component operations is not valid.");
+            Context.Abort();
+
+            return;
+        }
+
+        var store = !string.IsNullOrEmpty(applicationState) ?
+            new ProtectedPrerenderComponentApplicationStore(applicationState, _dataProtectionProvider) :
+            new ProtectedPrerenderComponentApplicationStore(_dataProtectionProvider);
+
+        _ = circuitHost.UpdateRootComponents(operations, store, _serverComponentSerializer, Context.ConnectionAborted);
+    }
+
     public async ValueTask<bool> ConnectCircuit(string circuitIdSecret)
     {
         // TryParseCircuitId will not throw.

+ 295 - 1
src/Components/Server/test/Circuits/CircuitHostTest.cs

@@ -3,6 +3,9 @@
 
 using System.Diagnostics.CodeAnalysis;
 using System.Reflection;
+using System.Text.Json;
+using Microsoft.AspNetCore.Components.Endpoints;
+using Microsoft.AspNetCore.Components.Rendering;
 using Microsoft.AspNetCore.DataProtection;
 using Microsoft.AspNetCore.SignalR;
 using Microsoft.Extensions.DependencyInjection;
@@ -15,6 +18,9 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits;
 
 public class CircuitHostTest
 {
+    private readonly IDataProtectionProvider _ephemeralDataProtectionProvider = new EphemeralDataProtectionProvider();
+    private readonly ServerComponentInvocationSequence _invocationSequence = new();
+
     [Fact]
     public async Task DisposeAsync_DisposesResources()
     {
@@ -252,7 +258,7 @@ public class CircuitHostTest
             .Returns(tcs.Task)
             .Verifiable();
 
-        var circuitHost = TestCircuitHost.Create(handlers: new[] { handler.Object });
+        var circuitHost = TestCircuitHost.Create(handlers: new[] { handler.Object }, descriptors: [new ComponentDescriptor() ]);
         circuitHost.UnhandledException += (sender, errorInfo) =>
         {
             Assert.Same(circuitHost, sender);
@@ -405,6 +411,194 @@ public class CircuitHostTest
         Assert.True(wasHandlerFuncInvoked);
     }
 
+    [Fact]
+    public async Task UpdateRootComponents_CanAddNewRootComponent()
+    {
+        // Arrange
+        var circuitHost = TestCircuitHost.Create(
+            remoteRenderer: GetRemoteRenderer(),
+            serviceScope: new ServiceCollection().BuildServiceProvider().CreateAsyncScope());
+        var expectedMessage = "Hello, world!";
+        Dictionary<string, object> parameters = new()
+        {
+            [nameof(DynamicallyAddedComponent.Message)] = expectedMessage,
+        };
+        var operation = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Add,
+            SelectorId = 1,
+            Marker = CreateMarker(typeof(DynamicallyAddedComponent), parameters),
+        };
+        var descriptor = new ComponentDescriptor()
+        {
+            ComponentType = typeof(DynamicallyAddedComponent),
+            Parameters = ParameterView.FromDictionary(parameters),
+            Sequence = 0,
+        };
+
+        // Act
+        await circuitHost.UpdateRootComponents(
+            [(operation, descriptor)], null, CreateDeserializer(), CancellationToken.None);
+
+        // Assert
+        var componentState = ((TestRemoteRenderer)circuitHost.Renderer).GetTestComponentState(0);
+        var component = Assert.IsType<DynamicallyAddedComponent>(componentState.Component);
+        Assert.Equal(expectedMessage, component.Message);
+    }
+
+    [Fact]
+    public async Task UpdateRootComponents_CanUpdateExistingRootComponent()
+    {
+        // Arrange
+        var circuitHost = TestCircuitHost.Create(
+            remoteRenderer: GetRemoteRenderer(),
+            serviceScope: new ServiceCollection().BuildServiceProvider().CreateAsyncScope());
+        var expectedMessage = "Updated message";
+
+        Dictionary<string, object> parameters = new()
+        {
+            [nameof(DynamicallyAddedComponent.Message)] = expectedMessage,
+        };
+        await AddComponent<DynamicallyAddedComponent>(circuitHost, parameters);
+
+        var operation = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Update,
+            ComponentId = 0,
+            Marker = CreateMarker(typeof(DynamicallyAddedComponent), new()
+            {
+                [nameof(DynamicallyAddedComponent.Message)] = expectedMessage,
+            }),
+        };
+        var descriptor = new ComponentDescriptor()
+        {
+            ComponentType = typeof(DynamicallyAddedComponent),
+            Parameters = ParameterView.FromDictionary(new Dictionary<string, object>()),
+            Sequence = 0,
+        };
+
+        // Act
+        await circuitHost.UpdateRootComponents([(operation, descriptor)], null, CreateDeserializer(), CancellationToken.None);
+
+        // Assert
+        var componentState = ((TestRemoteRenderer)circuitHost.Renderer).GetTestComponentState(0);
+        var component = Assert.IsType<DynamicallyAddedComponent>(componentState.Component);
+        Assert.Equal(expectedMessage, component.Message);
+    }
+
+    [Fact]
+    public async Task UpdateRootComponents_DoesNotUpdateExistingRootComponent_WhenDescriptorComponentTypeDoesNotMatchRootComponentType()
+    {
+        // Arrange
+        var circuitHost = TestCircuitHost.Create(
+            remoteRenderer: GetRemoteRenderer(),
+            serviceScope: new ServiceCollection().BuildServiceProvider().CreateAsyncScope());
+
+        // Arrange
+        var expectedMessage = "Existing message";
+        await AddComponent<DynamicallyAddedComponent>(circuitHost, new Dictionary<string, object>()
+        {
+            [nameof(DynamicallyAddedComponent.Message)] = expectedMessage,
+        });
+
+        await AddComponent<TestComponent>(circuitHost, []);
+
+        Dictionary<string, object> parameters = new()
+        {
+            [nameof(DynamicallyAddedComponent.Message)] = "Updated message",
+        };
+        var operation = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Update,
+            ComponentId = 0,
+            Marker = CreateMarker(typeof(TestComponent) /* Note the incorrect component type */, parameters),
+        };
+        var descriptor = new ComponentDescriptor()
+        {
+            ComponentType = typeof(TestComponent),
+            Parameters = ParameterView.FromDictionary(parameters),
+            Sequence = 0,
+        };
+        var operationsJson = JsonSerializer.Serialize(
+            new[] { operation },
+            ServerComponentSerializationSettings.JsonSerializationOptions);
+
+        // Act
+        var evt = Assert.Raises<UnhandledExceptionEventArgs>(
+            handler => circuitHost.UnhandledException += new UnhandledExceptionEventHandler(handler),
+            handler => circuitHost.UnhandledException -= new UnhandledExceptionEventHandler(handler),
+            () => circuitHost.UpdateRootComponents(
+                [(operation, descriptor)], null, CreateDeserializer(), CancellationToken.None));
+
+        // Assert
+        var componentState = ((TestRemoteRenderer)circuitHost.Renderer).GetTestComponentState(0);
+        var component = Assert.IsType<DynamicallyAddedComponent>(componentState.Component);
+        Assert.Equal(expectedMessage, component.Message);
+
+        Assert.NotNull(evt);
+        var exception = Assert.IsType<InvalidOperationException>(evt.Arguments.ExceptionObject);
+    }
+
+    [Fact]
+    public async Task UpdateRootComponents_CanRemoveExistingRootComponent()
+    {
+        // Arrange
+        var circuitHost = TestCircuitHost.Create(
+            remoteRenderer: GetRemoteRenderer(),
+            serviceScope: new ServiceCollection().BuildServiceProvider().CreateAsyncScope());
+        var expectedMessage = "Updated message";
+
+        Dictionary<string, object> parameters = new()
+        {
+            [nameof(DynamicallyAddedComponent.Message)] = expectedMessage,
+        };
+        await AddComponent<DynamicallyAddedComponent>(circuitHost, parameters);
+
+        var operation = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Remove,
+            ComponentId = 0,
+        };
+
+        // Act
+        await circuitHost.UpdateRootComponents([(operation, null)], null, CreateDeserializer(), CancellationToken.None);
+
+        // Assert
+        Assert.Throws<ArgumentException>(() =>
+            ((TestRemoteRenderer)circuitHost.Renderer).GetTestComponentState(0));
+    }
+
+    private async Task AddComponent<TComponent>(CircuitHost circuitHost, Dictionary<string, object> parameters)
+    where TComponent : IComponent
+    {
+        var addOperation = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Add,
+            SelectorId = 1,
+            Marker = CreateMarker(typeof(TComponent), parameters),
+        };
+        var addDescriptor = new ComponentDescriptor()
+        {
+            ComponentType = typeof(TComponent),
+            Parameters = ParameterView.FromDictionary(parameters),
+            Sequence = 0,
+        };
+
+        // Add component
+        await circuitHost.UpdateRootComponents(
+            [(addOperation, addDescriptor)], null, CreateDeserializer(), CancellationToken.None);
+    }
+
+    private ProtectedPrerenderComponentApplicationStore CreateStore()
+    {
+        return new ProtectedPrerenderComponentApplicationStore(_ephemeralDataProtectionProvider);
+    }
+
+    private ServerComponentDeserializer CreateDeserializer()
+    {
+        return new ServerComponentDeserializer(_ephemeralDataProtectionProvider, NullLogger<ServerComponentDeserializer>.Instance, new RootComponentTypeCache(), new ComponentParameterDeserializer(NullLogger<ComponentParameterDeserializer>.Instance, new ComponentParametersTypeCache()));
+    }
+
     private static TestRemoteRenderer GetRemoteRenderer()
     {
         var serviceCollection = new ServiceCollection();
@@ -434,6 +628,18 @@ public class CircuitHostTest
             .Verifiable();
     }
 
+    private ComponentMarker CreateMarker(Type type, Dictionary<string, object> parameters = null)
+    {
+        var serializer = new ServerComponentSerializer(_ephemeralDataProtectionProvider);
+        var marker = ComponentMarker.Create(ComponentMarker.ServerMarkerType, false, null);
+        serializer.SerializeInvocation(
+            ref marker,
+            _invocationSequence,
+            type,
+            parameters is null ? ParameterView.Empty : ParameterView.FromDictionary(parameters));
+        return marker;
+    }
+
     private class TestRemoteRenderer : RemoteRenderer
     {
         public TestRemoteRenderer(IServiceProvider serviceProvider, IClientProxy client)
@@ -449,6 +655,9 @@ public class CircuitHostTest
         {
         }
 
+        public ComponentState GetTestComponentState(int id)
+            => base.GetComponentState(id);
+
         protected override void Dispose(bool disposing)
         {
             base.Dispose(disposing);
@@ -580,10 +789,95 @@ public class CircuitHostTest
             return true;
         }
 
+        public bool TryDeserializeRootComponentOperations(string serializedComponentOperations, out (RootComponentOperation, ComponentDescriptor)[] operationsWithDescriptors)
+        {
+            operationsWithDescriptors= default;
+            return true;
+        }
+
         public bool TryDeserializeSingleComponentDescriptor(ComponentMarker record, [NotNullWhen(true)] out ComponentDescriptor result)
         {
             result = default;
             return true;
         }
     }
+
+    private class DynamicallyAddedComponent : IComponent, IDisposable
+    {
+        private readonly TaskCompletionSource _disposeTcs = new();
+        private RenderHandle _renderHandle;
+
+        [Parameter]
+        public string Message { get; set; } = "Default message";
+
+        private void Render(RenderTreeBuilder builder)
+        {
+            builder.AddContent(0, Message);
+        }
+
+        public void Attach(RenderHandle renderHandle)
+        {
+            _renderHandle = renderHandle;
+        }
+
+        public Task SetParametersAsync(ParameterView parameters)
+        {
+            if (parameters.TryGetValue<string>(nameof(Message), out var message))
+            {
+                Message = message;
+            }
+
+            TriggerRender();
+            return Task.CompletedTask;
+        }
+
+        public void TriggerRender()
+        {
+            var task = _renderHandle.Dispatcher.InvokeAsync(() => _renderHandle.Render(Render));
+            Assert.True(task.IsCompletedSuccessfully);
+        }
+
+        public Task WaitForDisposeAsync()
+            => _disposeTcs.Task;
+
+        public void Dispose()
+        {
+            _disposeTcs.SetResult();
+        }
+    }
+
+    private class TestComponent() : IComponent, IHandleAfterRender
+    {
+        private RenderHandle _renderHandle;
+        private readonly RenderFragment _renderFragment = (builder) =>
+        {
+            builder.OpenElement(0, "my element");
+            builder.AddContent(1, "some text");
+            builder.CloseElement();
+        };
+
+        public TestComponent(RenderFragment renderFragment) : this() => _renderFragment = renderFragment;
+
+        public Action OnAfterRenderComplete { get; set; }
+
+        public void Attach(RenderHandle renderHandle) => _renderHandle = renderHandle;
+
+        public Task OnAfterRenderAsync()
+        {
+            OnAfterRenderComplete?.Invoke();
+            return Task.CompletedTask;
+        }
+
+        public Task SetParametersAsync(ParameterView parameters)
+        {
+            TriggerRender();
+            return Task.CompletedTask;
+        }
+
+        public void TriggerRender()
+        {
+            var task = _renderHandle.Dispatcher.InvokeAsync(() => _renderHandle.Render(_renderFragment));
+            Assert.True(task.IsCompletedSuccessfully);
+        }
+    }
 }

+ 6 - 0
src/Components/Server/test/Circuits/ComponentHubTest.cs

@@ -170,6 +170,12 @@ public class ComponentHubTest
             return true;
         }
 
+        public bool TryDeserializeRootComponentOperations(string serializedComponentOperations, out (RootComponentOperation, ComponentDescriptor)[] operationsWithDescriptors)
+        {
+            operationsWithDescriptors = default;
+            return true;
+        }
+
         public bool TryDeserializeSingleComponentDescriptor(ComponentMarker record, [NotNullWhen(true)] out ComponentDescriptor result)
         {
             result = default;

+ 0 - 295
src/Components/Server/test/Circuits/RemoteRendererTest.cs

@@ -25,7 +25,6 @@ public class RemoteRendererTest
     private static readonly TimeSpan Timeout = Debugger.IsAttached ? System.Threading.Timeout.InfiniteTimeSpan : TimeSpan.FromSeconds(10);
 
     private readonly IDataProtectionProvider _ephemeralDataProtectionProvider = new EphemeralDataProtectionProvider();
-    private readonly ServerComponentInvocationSequence _invocationSequence = new();
 
     [Fact]
     public void WritesAreBufferedWhenTheClientIsOffline()
@@ -426,239 +425,6 @@ public class RemoteRendererTest
             exception.Message);
     }
 
-    [Fact]
-    public async Task UpdateRootComponents_CanAddNewRootComponent()
-    {
-        // Arrange
-        var serviceProvider = CreateServiceProvider();
-        var renderer = GetRemoteRenderer(serviceProvider);
-        var expectedMessage = "Hello, world!";
-        var operation = new RootComponentOperation
-        {
-            Type = RootComponentOperationType.Add,
-            SelectorId = 1,
-            Marker = CreateMarker(typeof(DynamicallyAddedComponent), new()
-            {
-                [nameof(DynamicallyAddedComponent.Message)] = expectedMessage,
-            }),
-        };
-        var operationsJson = JsonSerializer.Serialize(
-            new[] { operation },
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        // Act
-        await renderer.Dispatcher.InvokeAsync(() => renderer.UpdateRootComponents(operationsJson));
-        var componentState = renderer.GetComponentState(0);
-
-        // Assert
-        var component = Assert.IsType<DynamicallyAddedComponent>(componentState.Component);
-        Assert.Equal(expectedMessage, component.Message);
-    }
-
-    [Fact]
-    public async Task UpdateRootComponents_DoesNotAddNewRootComponent_WhenSelectorIdIsMissing()
-    {
-        // Arrange
-        var serviceProvider = CreateServiceProvider();
-        var renderer = GetRemoteRenderer(serviceProvider);
-        var operation = new RootComponentOperation
-        {
-            Type = RootComponentOperationType.Add,
-            Marker = CreateMarker(typeof(DynamicallyAddedComponent)),
-        };
-        var operationsJson = JsonSerializer.Serialize(
-            new[] { operation },
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        // Act
-        await renderer.Dispatcher.InvokeAsync(() => renderer.UpdateRootComponents(operationsJson));
-        renderer.UpdateRootComponents(operationsJson);
-
-        // Assert
-        var ex = Assert.Throws<ArgumentException>(() => renderer.GetComponentState(0));
-        Assert.StartsWith("The renderer does not have a component with ID", ex.Message);
-    }
-
-    [Fact]
-    public async Task UpdateRootComponents_Throws_WhenAddingComponentFromInvalidDescriptor()
-    {
-        // Arrange
-        var serviceProvider = CreateServiceProvider();
-        var renderer = GetRemoteRenderer(serviceProvider);
-        var operation = new RootComponentOperation
-        {
-            Type = RootComponentOperationType.Add,
-            SelectorId = 1,
-            Marker = new ComponentMarker()
-            {
-                Descriptor = "some random text",
-            },
-        };
-        var operationsJson = JsonSerializer.Serialize(
-            new[] { operation },
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        // Act
-        var task = renderer.Dispatcher.InvokeAsync(() => renderer.UpdateRootComponents(operationsJson));
-
-        // Assert
-        var ex = await Assert.ThrowsAsync<InvalidOperationException>(async () => await task);
-        Assert.StartsWith("Failed to deserialize a component descriptor when adding", ex.Message);
-    }
-
-    [Fact]
-    public async Task UpdateRootComponents_CanUpdateExistingRootComponent()
-    {
-        // Arrange
-        var serviceProvider = CreateServiceProvider();
-        var renderer = GetRemoteRenderer(serviceProvider);
-        var component = new DynamicallyAddedComponent()
-        {
-            Message = "Existing message",
-        };
-        var expectedMessage = "Updated message";
-        var componentId = renderer.AssignRootComponentId(component);
-        var operation = new RootComponentOperation
-        {
-            Type = RootComponentOperationType.Update,
-            ComponentId = componentId,
-            Marker = CreateMarker(typeof(DynamicallyAddedComponent), new()
-            {
-                [nameof(DynamicallyAddedComponent.Message)] = expectedMessage,
-            }),
-        };
-        var operationsJson = JsonSerializer.Serialize(
-            new[] { operation },
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        // Act
-        await renderer.Dispatcher.InvokeAsync(() => renderer.UpdateRootComponents(operationsJson));
-
-        // Assert
-        Assert.Equal(expectedMessage, component.Message);
-    }
-
-    [Fact]
-    public async Task UpdateRootComponents_DoesNotUpdateExistingRootComponent_WhenComponentIdIsMissing()
-    {
-        // Arrange
-        var serviceProvider = CreateServiceProvider();
-        var renderer = GetRemoteRenderer(serviceProvider);
-        var expectedMessage = "Existing message";
-        var component = new DynamicallyAddedComponent()
-        {
-            Message = expectedMessage,
-        };
-        var componentId = renderer.AssignRootComponentId(component);
-        var operation = new RootComponentOperation
-        {
-            Type = RootComponentOperationType.Update,
-            Marker = CreateMarker(typeof(DynamicallyAddedComponent), new()
-            {
-                [nameof(DynamicallyAddedComponent.Message)] = "Some other message",
-            }),
-        };
-        var operationsJson = JsonSerializer.Serialize(
-            new[] { operation },
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        // Act
-        await renderer.Dispatcher.InvokeAsync(() => renderer.UpdateRootComponents(operationsJson));
-
-        // Assert
-        Assert.Equal(expectedMessage, component.Message);
-    }
-
-    [Fact]
-    public async Task UpdateRootComponents_DoesNotUpdateExistingRootComponent_WhenDescriptorComponentTypeDoesNotMatchRootComponentType()
-    {
-        // Arrange
-        var serviceProvider = CreateServiceProvider();
-        var renderer = GetRemoteRenderer(serviceProvider);
-        var expectedMessage = "Existing message";
-        var component1 = new DynamicallyAddedComponent()
-        {
-            Message = expectedMessage,
-        };
-        var component2 = new TestComponent();
-        var component1Id = renderer.AssignRootComponentId(component1);
-        var component2Id = renderer.AssignRootComponentId(component2);
-        var operation = new RootComponentOperation
-        {
-            Type = RootComponentOperationType.Update,
-            ComponentId = component1Id,
-            Marker = CreateMarker(typeof(TestComponent) /* Note the incorrect component type */, new()
-            {
-                [nameof(DynamicallyAddedComponent.Message)] = "Updated message",
-            }),
-        };
-        var operationsJson = JsonSerializer.Serialize(
-            new[] { operation },
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        // Act
-        await renderer.Dispatcher.InvokeAsync(() => renderer.UpdateRootComponents(operationsJson));
-
-        // Assert
-        Assert.Equal(expectedMessage, component1.Message);
-    }
-
-    [Fact]
-    public async Task UpdateRootComponents_Throws_WhenUpdatingComponentFromInvalidDescriptor()
-    {
-        // Arrange
-        var serviceProvider = CreateServiceProvider();
-        var renderer = GetRemoteRenderer(serviceProvider);
-        var component = new DynamicallyAddedComponent()
-        {
-            Message = "Existing message",
-        };
-        var componentId = renderer.AssignRootComponentId(component);
-        var operation = new RootComponentOperation
-        {
-            Type = RootComponentOperationType.Update,
-            ComponentId = componentId,
-            Marker = new()
-            {
-                Descriptor = "some random text",
-            },
-        };
-        var operationsJson = JsonSerializer.Serialize(
-            new[] { operation },
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        // Act
-        var task = renderer.Dispatcher.InvokeAsync(() => renderer.UpdateRootComponents(operationsJson));
-
-        // Assert
-        var ex = await Assert.ThrowsAsync<InvalidOperationException>(async () => await task);
-        Assert.StartsWith("Failed to deserialize a component descriptor when updating", ex.Message);
-    }
-
-    [Fact]
-    public async Task UpdateRootComponents_CanRemoveExistingRootComponent()
-    {
-        // Arrange
-        var serviceProvider = CreateServiceProvider();
-        var renderer = GetRemoteRenderer(serviceProvider);
-        var component = new DynamicallyAddedComponent();
-        var componentId = renderer.AssignRootComponentId(component);
-        var operation = new RootComponentOperation
-        {
-            Type = RootComponentOperationType.Remove,
-            ComponentId = componentId,
-        };
-        var operationsJson = JsonSerializer.Serialize(
-            new[] { operation },
-            ServerComponentSerializationSettings.JsonSerializationOptions);
-
-        // Act
-        await renderer.Dispatcher.InvokeAsync(() => renderer.UpdateRootComponents(operationsJson));
-
-        // Assert
-        await component.WaitForDisposeAsync().WaitAsync(Timeout); // Will timeout and throw if not disposed
-    }
-
     private IServiceProvider CreateServiceProvider()
     {
         var serviceCollection = new ServiceCollection();
@@ -685,18 +451,6 @@ public class RemoteRendererTest
             NullLogger.Instance);
     }
 
-    private ComponentMarker CreateMarker(Type type, Dictionary<string, object> parameters = null)
-    {
-        var serializer = new ServerComponentSerializer(_ephemeralDataProtectionProvider);
-        var marker = ComponentMarker.Create(ComponentMarker.ServerMarkerType, false, null);
-        serializer.SerializeInvocation(
-            ref marker,
-            _invocationSequence,
-            type,
-            parameters is null ? ParameterView.Empty : ParameterView.FromDictionary(parameters));
-        return marker;
-    }
-
     private class TestRemoteRenderer : RemoteRenderer
     {
         public TestRemoteRenderer(IServiceProvider serviceProvider, ILoggerFactory loggerFactory, CircuitOptions options, CircuitClientProxy client, IServerComponentDeserializer serverComponentDeserializer, ILogger logger)
@@ -715,11 +469,6 @@ public class RemoteRendererTest
         {
         }
 
-        public new void UpdateRootComponents(string operationsJson)
-        {
-            base.UpdateRootComponents(operationsJson);
-        }
-
         public new ComponentState GetComponentState(int componentId)
         {
             return base.GetComponentState(componentId);
@@ -811,48 +560,4 @@ public class RemoteRendererTest
             Component.TriggerRender();
         }
     }
-
-    private class DynamicallyAddedComponent : IComponent, IDisposable
-    {
-        private readonly TaskCompletionSource _disposeTcs = new();
-        private RenderHandle _renderHandle;
-
-        [Parameter]
-        public string Message { get; set; } = "Default message";
-
-        private void Render(RenderTreeBuilder builder)
-        {
-            builder.AddContent(0, Message);
-        }
-
-        public void Attach(RenderHandle renderHandle)
-        {
-            _renderHandle = renderHandle;
-        }
-
-        public Task SetParametersAsync(ParameterView parameters)
-        {
-            if (parameters.TryGetValue<string>(nameof(Message), out var message))
-            {
-                Message = message;
-            }
-
-            TriggerRender();
-            return Task.CompletedTask;
-        }
-
-        public void TriggerRender()
-        {
-            var task = _renderHandle.Dispatcher.InvokeAsync(() => _renderHandle.Render(Render));
-            Assert.True(task.IsCompletedSuccessfully);
-        }
-
-        public Task WaitForDisposeAsync()
-            => _disposeTcs.Task;
-
-        public void Dispose()
-        {
-            _disposeTcs.SetResult();
-        }
-    }
 }

+ 105 - 0
src/Components/Server/test/Circuits/ServerComponentDeserializerTest.cs

@@ -394,6 +394,93 @@ public class ServerComponentDeserializerTest
         Assert.False(serverComponentDeserializer.TryDeserializeSingleComponentDescriptor(firstInvocationMarkers[0], out _));
     }
 
+    [Fact]
+    public void UpdateRootComponents_TryDeserializeRootComponentOperationsReturnsFalse_WhenAddOperationIsMissingSelectorId()
+    {
+        // Arrange
+        var operation = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Add,
+            SelectorId = 1,
+            Marker = new ComponentMarker()
+            {
+                Descriptor = "some random text",
+            },
+        };
+        var operationsJson = JsonSerializer.Serialize(
+            new[] { operation },
+            ServerComponentSerializationSettings.JsonSerializationOptions);
+        var deserializer = CreateServerComponentDeserializer();
+
+        // Act
+        var result = deserializer.TryDeserializeRootComponentOperations(operationsJson, out var parsed);
+
+        // Assert
+        Assert.False(result);
+        Assert.Null(parsed);
+    }
+
+    [Fact]
+    public void UpdateRootComponents_TryDeserializeRootComponentOperationsReturnsFalse_WhenComponentIdIsMissing()
+    {
+        // Arrange
+        var operation = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Update,
+            Marker = CreateMarker(typeof(DynamicallyAddedComponent), new()
+            {
+                ["Message"] = "Some other message",
+            }),
+        };
+        var operationsJson = JsonSerializer.Serialize(
+            new[] { operation },
+            ServerComponentSerializationSettings.JsonSerializationOptions);
+
+        var deserializer = CreateServerComponentDeserializer();
+
+        // Act
+        var result = deserializer.TryDeserializeRootComponentOperations(operationsJson, out var parsed);
+
+        // Assert
+        Assert.False(result);
+        Assert.Null(parsed);
+    }
+
+    [Fact]
+    public void UpdateRootComponents_TryDeserializeRootComponentOperationsReturnsFalse_WhenComponentIdIsRepeated()
+    {
+        // Arrange
+        var operation = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Update,
+            ComponentId = 1,
+            Marker = CreateMarker(typeof(DynamicallyAddedComponent), new()
+            {
+                ["Message"] = "Some other message",
+            }),
+        };
+
+        var other = new RootComponentOperation
+        {
+            Type = RootComponentOperationType.Remove,
+            ComponentId = 1,
+            Marker = CreateMarker(typeof(DynamicallyAddedComponent)),
+        };
+
+        var operationsJson = JsonSerializer.Serialize(
+            new[] { operation, other },
+            ServerComponentSerializationSettings.JsonSerializationOptions);
+
+        var deserializer = CreateServerComponentDeserializer();
+
+        // Act
+        var result = deserializer.TryDeserializeRootComponentOperations(operationsJson, out var parsed);
+
+        // Assert
+        Assert.False(result);
+        Assert.Null(parsed);
+    }
+
     private string SerializeComponent(string assembly, string type) =>
         JsonSerializer.Serialize(
             new ServerComponent(0, assembly, type, Array.Empty<ComponentParameter>(), Array.Empty<object>(), Guid.NewGuid()),
@@ -411,6 +498,18 @@ public class ServerComponentDeserializerTest
     private string SerializeMarkers(ComponentMarker[] markers) =>
         JsonSerializer.Serialize(markers, ServerComponentSerializationSettings.JsonSerializationOptions);
 
+    private ComponentMarker CreateMarker(Type type, Dictionary<string, object> parameters = null)
+    {
+        var serializer = new ServerComponentSerializer(_ephemeralDataProtectionProvider);
+        var marker = ComponentMarker.Create(ComponentMarker.ServerMarkerType, false, null);
+        serializer.SerializeInvocation(
+            ref marker,
+            _invocationSequence,
+            type,
+            parameters is null ? ParameterView.Empty : ParameterView.FromDictionary(parameters));
+        return marker;
+    }
+
     private ComponentMarker[] CreateMarkers(params Type[] types)
     {
         var serializer = new ServerComponentSerializer(_ephemeralDataProtectionProvider);
@@ -466,4 +565,10 @@ public class ServerComponentDeserializerTest
 
         public Task SetParametersAsync(ParameterView parameters) => throw new NotImplementedException();
     }
+
+    private class DynamicallyAddedComponent : IComponent
+    {
+        public void Attach(RenderHandle renderHandle) => throw new NotImplementedException();
+        public Task SetParametersAsync(ParameterView parameters) => throw new NotImplementedException();
+    }
 }

+ 2 - 1
src/Components/Shared/src/DefaultAntiforgeryStateProvider.cs

@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Diagnostics.CodeAnalysis;
+using Microsoft.AspNetCore.Components.Web;
 
 namespace Microsoft.AspNetCore.Components.Forms;
 
@@ -25,7 +26,7 @@ internal class DefaultAntiforgeryStateProvider : AntiforgeryStateProvider, IDisp
         {
             state.PersistAsJson(PersistenceKey, _currentToken);
             return Task.CompletedTask;
-        });
+        }, new InteractiveAutoRenderMode());
 
         state.TryTakeFromJson(PersistenceKey, out _currentToken);
     }

ファイルの差分が大きいため隠しています
+ 0 - 0
src/Components/Web.JS/dist/Release/blazor.server.js


ファイルの差分が大きいため隠しています
+ 0 - 0
src/Components/Web.JS/dist/Release/blazor.web.js


+ 8 - 2
src/Components/Web.JS/src/Boot.Server.Common.ts

@@ -7,7 +7,7 @@ import { LogLevel } from './Platform/Logging/Logger';
 import { CircuitManager } from './Platform/Circuits/CircuitManager';
 import { resolveOptions, CircuitStartOptions } from './Platform/Circuits/CircuitStartOptions';
 import { DefaultReconnectionHandler } from './Platform/Circuits/DefaultReconnectionHandler';
-import { discoverPersistedState, ServerComponentDescriptor } from './Services/ComponentDescriptorDiscovery';
+import { discoverServerPersistedState, ServerComponentDescriptor } from './Services/ComponentDescriptorDiscovery';
 import { fetchAndInvokeInitializers } from './JSInitializers/JSInitializers.Server';
 import { RootComponentManager } from './Services/RootComponentManager';
 
@@ -31,7 +31,7 @@ export async function startServer(components: RootComponentManager<ServerCompone
   }
 
   started = true;
-  appState = discoverPersistedState(document) || '';
+  appState = discoverServerPersistedState(document) || '';
   logger = new ConsoleLogger(options.logLevel);
   circuit = new CircuitManager(components, appState, options, logger);
 
@@ -65,6 +65,7 @@ export async function startServer(components: RootComponentManager<ServerCompone
   Blazor._internal.sendJSDataStream = (data: ArrayBufferView | Blob, streamId: number, chunkSize: number) => circuit.sendJsDataStream(data, streamId, chunkSize);
 
   const jsInitializer = await fetchAndInvokeInitializers(options);
+
   const circuitStarted = await circuit.start();
   if (!circuitStarted) {
     logger.log(LogLevel.Error, 'Failed to start the circuit.');
@@ -97,6 +98,7 @@ export function startCircuit(): Promise<boolean> {
 
   if (circuit.isDisposedOrDisposing()) {
     // If the current circuit is no longer available, create a new one.
+    appState = discoverServerPersistedState(document) || '';
     circuit = new CircuitManager(circuit.getRootComponentManager(), appState, options, logger);
   }
 
@@ -116,3 +118,7 @@ export async function disposeCircuit() {
 export function isCircuitAvailable(): boolean {
   return !circuit.isDisposedOrDisposing();
 }
+
+export function updateServerRootComponents(operations: string): Promise<void>|undefined {
+  return circuit.updateRootComponents(operations);
+}

+ 48 - 3
src/Components/Web.JS/src/Boot.WebAssembly.Common.ts

@@ -11,7 +11,7 @@ import { SharedMemoryRenderBatch } from './Rendering/RenderBatch/SharedMemoryRen
 import { PlatformApi, Pointer } from './Platform/Platform';
 import { WebAssemblyStartOptions } from './Platform/WebAssemblyStartOptions';
 import { addDispatchEventMiddleware } from './Rendering/WebRendererInteropMethods';
-import { WebAssemblyComponentDescriptor, discoverPersistedState } from './Services/ComponentDescriptorDiscovery';
+import { WebAssemblyComponentDescriptor, discoverWebAssemblyPersistedState } from './Services/ComponentDescriptorDiscovery';
 import { receiveDotNetDataStream } from './StreamingInterop';
 import { WebAssemblyComponentAttacher } from './Platform/WebAssemblyComponentAttacher';
 import { MonoConfig } from 'dotnet';
@@ -21,12 +21,32 @@ let options: Partial<WebAssemblyStartOptions> | undefined;
 let platformLoadPromise: Promise<void> | undefined;
 let loadedWebAssemblyPlatform = false;
 let started = false;
+let firstUpdate = true;
+let waitForRootComponents = false;
 
 let resolveBootConfigPromise: (value: MonoConfig) => void;
 const bootConfigPromise = new Promise<MonoConfig>(resolve => {
   resolveBootConfigPromise = resolve;
 });
 
+let resolveInitialUpdatePromise: (value: string) => void;
+const initialUpdatePromise = new Promise<string>(resolve => {
+  resolveInitialUpdatePromise = resolve;
+});
+
+export function resolveInitialUpdate(value: string): void {
+  resolveInitialUpdatePromise(value);
+  firstUpdate = false;
+}
+
+export function isFirstUpdate() {
+  return firstUpdate;
+}
+
+export function setWaitForRootComponents(): void {
+  waitForRootComponents = true;
+}
+
 export function setWebAssemblyOptions(webAssemblyOptions?: Partial<WebAssemblyStartOptions>) {
   if (options) {
     throw new Error('WebAssembly options have already been configured.');
@@ -121,7 +141,12 @@ export async function startWebAssembly(components: RootComponentManager<WebAssem
     getParameterValues: (id) => componentAttacher.getParameterValues(id) || '',
   };
 
-  Blazor._internal.getPersistedState = () => discoverPersistedState(document) || '';
+  Blazor._internal.getPersistedState = () => discoverWebAssemblyPersistedState(document) || '';
+
+  Blazor._internal.getInitialComponentsUpdate = () => initialUpdatePromise;
+
+  Blazor._internal.updateRootComponents = (operations: string) =>
+    Blazor._internal.dotNetExports?.UpdateRootComponentsCore(operations);
 
   Blazor._internal.attachRootComponentToElement = (selector, componentId, rendererId: any) => {
     const element = componentAttacher.resolveRegisteredElement(selector, componentId);
@@ -157,7 +182,15 @@ export function waitForBootConfigLoaded(): Promise<MonoConfig> {
 
 export function loadWebAssemblyPlatformIfNotStarted(): Promise<void> {
   platformLoadPromise ??= (async () => {
-    await monoPlatform.load(options ?? {}, resolveBootConfigPromise);
+    const finalOptions = options ?? {};
+    const existingConfig = options?.configureRuntime;
+    finalOptions.configureRuntime = (config) => {
+      existingConfig?.(config);
+      if (waitForRootComponents) {
+        config.withEnvironmentVariable('__BLAZOR_WEBASSEMBLY_WAIT_FOR_ROOT_COMPONENTS', 'true');
+      }
+    };
+    await monoPlatform.load(finalOptions, resolveBootConfigPromise);
     loadedWebAssemblyPlatform = true;
   })();
   return platformLoadPromise;
@@ -210,6 +243,18 @@ function invokeJSFromDotNet(callInfo: Pointer, arg0: any, arg1: any, arg2: any):
   }
 }
 
+export function updateWebAssemblyRootComponents(operations: string): void {
+  if (!started) {
+    throw new Error('Blazor WebAssembly has not started.');
+  }
+
+  if (!Blazor._internal.updateRootComponents) {
+    throw new Error('Blazor WebAssembly has not initialized.');
+  }
+
+  Blazor._internal.updateRootComponents(operations);
+}
+
 function invokeJSJson(identifier: string, targetInstanceId: number, resultType: number, argsJson: string, asyncHandle: number): string | null {
   if (asyncHandle !== 0) {
     dispatcher.beginInvokeJSFromDotNet(asyncHandle, identifier, argsJson, resultType, targetInstanceId);

+ 3 - 0
src/Components/Web.JS/src/GlobalExports.ts

@@ -57,6 +57,8 @@ export interface IBlazor {
     endInvokeDotNetFromJS?: (callId: string, success: boolean, resultJsonOrErrorMessage: string) => void;
     receiveByteArray?: (id: number, data: Uint8Array) => void;
     getPersistedState?: () => string;
+    getInitialComponentsUpdate?: () => Promise<string>;
+    updateRootComponents?: (operations: string) => void;
     attachRootComponentToElement?: (arg0: any, arg1: any, arg2: any, arg3: any) => void;
     registeredComponents?: {
       getRegisteredComponentsCount: () => number;
@@ -84,6 +86,7 @@ export interface IBlazor {
       EndInvokeJS: (argsJson: string) => void;
       BeginInvokeDotNet: (callId: string | null, assemblyNameOrDotNetObjectId: string, methodIdentifier: string, argsJson: string) => void;
       ReceiveByteArrayFromJS: (id: number, data: Uint8Array) => void;
+      UpdateRootComponentsCore: (operationsJson: string) => void;
     }
 
     // APIs invoked by hot reload

+ 13 - 1
src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts

@@ -22,7 +22,7 @@ import { sendJSDataStream } from './CircuitStreamingInterop';
 export class CircuitManager implements DotNet.DotNetCallDispatcher {
   private readonly _componentManager: RootComponentManager<ServerComponentDescriptor>;
 
-  private readonly _applicationState: string;
+  private _applicationState: string;
 
   private readonly _options: CircuitStartOptions;
 
@@ -38,6 +38,8 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher {
 
   private _startPromise?: Promise<boolean>;
 
+  private _firstUpdate = true;
+
   private _renderingFailed = false;
 
   private _disposePromise?: Promise<void>;
@@ -71,6 +73,16 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher {
     return this._startPromise;
   }
 
+  public updateRootComponents(operations: string): Promise<void> | undefined {
+    if (this._firstUpdate) {
+      // Only send the application state on the first update.
+      this._firstUpdate = false;
+      return this._connection?.send('UpdateRootComponents', operations, this._applicationState);
+    } else {
+      return this._connection?.send('UpdateRootComponents', operations, '');
+    }
+  }
+
   private async startCore(): Promise<boolean> {
     this._connection = await this.startConnection();
 

+ 13 - 4
src/Components/Web.JS/src/Services/ComponentDescriptorDiscovery.ts

@@ -12,12 +12,21 @@ export function discoverComponents(root: Node, type: 'webassembly' | 'server' |
   }
 }
 
-const blazorStateCommentRegularExpression = /^\s*Blazor-Component-State:(?<state>[a-zA-Z0-9+/=]+)$/;
+const blazorServerStateCommentRegularExpression = /^\s*Blazor-Server-Component-State:(?<state>[a-zA-Z0-9+/=]+)$/;
+const blazorWebAssemblyStateCommentRegularExpression = /^\s*Blazor-WebAssembly-Component-State:(?<state>[a-zA-Z0-9+/=]+)$/;
 
-export function discoverPersistedState(node: Node): string | null | undefined {
+export function discoverServerPersistedState(node: Node): string | null | undefined {
+  return discoverPersistedState(node, blazorServerStateCommentRegularExpression);
+}
+
+export function discoverWebAssemblyPersistedState(node: Node): string | null | undefined {
+  return discoverPersistedState(node, blazorWebAssemblyStateCommentRegularExpression);
+}
+
+function discoverPersistedState(node: Node, comment: RegExp): string | null | undefined {
   if (node.nodeType === Node.COMMENT_NODE) {
     const content = node.textContent || '';
-    const parsedState = blazorStateCommentRegularExpression.exec(content);
+    const parsedState = comment.exec(content);
     const value = parsedState && parsedState.groups && parsedState.groups['state'];
     if (value){
       node.parentNode?.removeChild(node);
@@ -32,7 +41,7 @@ export function discoverPersistedState(node: Node): string | null | undefined {
   const nodes = node.childNodes;
   for (let index = 0; index < nodes.length; index++) {
     const candidate = nodes[index];
-    const result = discoverPersistedState(candidate);
+    const result = discoverPersistedState(candidate, comment);
     if (result){
       return result;
     }

+ 22 - 4
src/Components/Web.JS/src/Services/WebRootComponentManager.ts

@@ -2,11 +2,11 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 import { ComponentDescriptor, ComponentMarker, descriptorToMarker } from './ComponentDescriptorDiscovery';
-import { isRendererAttached, registerRendererAttachedListener, updateRootComponents } from '../Rendering/WebRendererInteropMethods';
+import { isRendererAttached, registerRendererAttachedListener } from '../Rendering/WebRendererInteropMethods';
 import { WebRendererId } from '../Rendering/WebRendererId';
 import { DescriptorHandler } from '../Rendering/DomMerging/DomSync';
-import { disposeCircuit, hasStartedServer, isCircuitAvailable, startCircuit, startServer } from '../Boot.Server.Common';
-import { hasLoadedWebAssemblyPlatform, hasStartedLoadingWebAssemblyPlatform, hasStartedWebAssembly, loadWebAssemblyPlatformIfNotStarted, startWebAssembly, waitForBootConfigLoaded } from '../Boot.WebAssembly.Common';
+import { disposeCircuit, hasStartedServer, isCircuitAvailable, startCircuit, startServer, updateServerRootComponents } from '../Boot.Server.Common';
+import { hasLoadedWebAssemblyPlatform, hasStartedLoadingWebAssemblyPlatform, hasStartedWebAssembly, isFirstUpdate, loadWebAssemblyPlatformIfNotStarted, resolveInitialUpdate, setWaitForRootComponents, startWebAssembly, updateWebAssemblyRootComponents, waitForBootConfigLoaded } from '../Boot.WebAssembly.Common';
 import { MonoConfig } from 'dotnet';
 import { RootComponentManager } from './RootComponentManager';
 import { Blazor } from '../GlobalExports';
@@ -110,6 +110,8 @@ export class WebRootComponentManager implements DescriptorHandler, RootComponent
       return;
     }
 
+    setWaitForRootComponents();
+
     const loadWebAssemblyPromise = loadWebAssemblyPlatformIfNotStarted();
 
     // If WebAssembly resources can't be loaded within some time limit,
@@ -263,12 +265,24 @@ export class WebRootComponentManager implements DescriptorHandler, RootComponent
 
     for (const [rendererId, operations] of operationsByRendererId) {
       const operationsJson = JSON.stringify(operations);
-      updateRootComponents(rendererId, operationsJson);
+      if (rendererId === WebRendererId.Server) {
+        updateServerRootComponents(operationsJson);
+      } else {
+        this.updateWebAssemblyRootComponents(operationsJson);
+      }
     }
 
     this.circuitMayHaveNoRootComponents();
   }
 
+  private updateWebAssemblyRootComponents(operationsJson: string) {
+    if (isFirstUpdate()) {
+      resolveInitialUpdate(operationsJson);
+    } else {
+      updateWebAssemblyRootComponents(operationsJson);
+    }
+  }
+
   private resolveRendererIdForDescriptor(descriptor: ComponentDescriptor): WebRendererId | null {
     const resolvedType = descriptor.type === 'auto' ? this.getAutoRenderMode() : descriptor.type;
     switch (resolvedType) {
@@ -345,6 +359,10 @@ export class WebRootComponentManager implements DescriptorHandler, RootComponent
       // updates.
     } else {
       this.unregisterComponent(component);
+      if (component.assignedRendererId !== undefined && component.interactiveComponentId !== undefined) {
+        const renderer = getRendererer(component.assignedRendererId);
+        renderer?.disposeComponent(component.interactiveComponentId);
+      }
 
       if (component.interactiveComponentId !== undefined) {
         // We have an interactive component for this marker, so we'll remove it.

+ 0 - 1
src/Components/Web/src/PublicAPI.Unshipped.txt

@@ -123,6 +123,5 @@ static Microsoft.AspNetCore.Components.Web.RenderMode.InteractiveWebAssembly.get
 virtual Microsoft.AspNetCore.Components.HtmlRendering.Infrastructure.StaticHtmlRenderer.RenderChildComponent(System.IO.TextWriter! output, ref Microsoft.AspNetCore.Components.RenderTree.RenderTreeFrame componentFrame) -> void
 virtual Microsoft.AspNetCore.Components.HtmlRendering.Infrastructure.StaticHtmlRenderer.WriteComponentHtml(int componentId, System.IO.TextWriter! output) -> void
 virtual Microsoft.AspNetCore.Components.RenderTree.WebRenderer.GetWebRendererId() -> int
-virtual Microsoft.AspNetCore.Components.RenderTree.WebRenderer.UpdateRootComponents(string! operationsJson) -> void
 Microsoft.AspNetCore.Components.Web.Virtualization.Virtualize<TItem>.EmptyContent.get -> Microsoft.AspNetCore.Components.RenderFragment?
 Microsoft.AspNetCore.Components.Web.Virtualization.Virtualize<TItem>.EmptyContent.set -> void

+ 0 - 13
src/Components/Web/src/WebRenderer.cs

@@ -85,14 +85,6 @@ public abstract class WebRenderer : Renderer
         return componentId;
     }
 
-    /// <summary>
-    /// Performs the specified operations on the renderer's root components.
-    /// </summary>
-    /// <param name="operationsJson">A JSON-serialized list of operations to perform on the renderer's root components.</param>
-    protected virtual void UpdateRootComponents(string operationsJson)
-    {
-    }
-
     /// <summary>
     /// Called by the framework to give a location for the specified root component in the browser DOM.
     /// </summary>
@@ -123,7 +115,6 @@ public abstract class WebRenderer : Renderer
         private readonly JSComponentInterop _jsComponentInterop;
 
         [DynamicDependency(nameof(DispatchEventAsync))]
-        [DynamicDependency(nameof(UpdateRootComponents))]
         public WebRendererInteropMethods(WebRenderer renderer, JsonSerializerOptions jsonOptions, JSComponentInterop jsComponentInterop)
         {
             _renderer = renderer;
@@ -141,10 +132,6 @@ public abstract class WebRenderer : Renderer
                 webEventData.EventArgs);
         }
 
-        [JSInvokable]
-        public void UpdateRootComponents(string operationsJson)
-            => _renderer.UpdateRootComponents(operationsJson);
-
         [JSInvokable] // Linker preserves this if you call RootComponents.Add
         public int AddRootComponent(string identifier, string domElementSelector)
             => _jsComponentInterop.AddRootComponent(identifier, domElementSelector);

+ 28 - 5
src/Components/WebAssembly/WebAssembly/src/Hosting/WebAssemblyHost.cs

@@ -151,24 +151,47 @@ public sealed class WebAssemblyHost : IAsyncDisposable
 
             WebAssemblyNavigationManager.Instance.CreateLogger(loggerFactory);
 
+            RootComponentMapping[] mappings = [];
+            if (Environment.GetEnvironmentVariable("__BLAZOR_WEBASSEMBLY_WAIT_FOR_ROOT_COMPONENTS") == "true")
+            {
+                // In Blazor web, we wait for the JS side to tell us about the components available
+                // before we render the initial set of components. Any additional update goes through
+                // UpdateRootComponents.
+                // We do it this way to ensure that the persistent component state is only used the first time
+                // the wasm runtime is initalized and is done in the same way for both webassembly and blazor
+                // web.
+                mappings = await InternalJSImportMethods.GetInitialComponentUpdate();
+            }
+
             var initializationTcs = new TaskCompletionSource();
-            WebAssemblyCallQueue.Schedule((_rootComponents, _renderer, initializationTcs), static async state =>
+            WebAssemblyCallQueue.Schedule((_rootComponents, _renderer, initializationTcs), async state =>
             {
                 var (rootComponents, renderer, initializationTcs) = state;
-
                 try
                 {
                     // Here, we add each root component but don't await the returned tasks so that the
                     // components can be processed in parallel.
                     var count = rootComponents.Count;
-                    var pendingRenders = new Task[count];
+                    var pendingRenders = new List<Task>(count + mappings.Length);
                     for (var i = 0; i < count; i++)
                     {
                         var rootComponent = rootComponents[i];
-                        pendingRenders[i] = renderer.AddComponentAsync(
+                        pendingRenders.Add(renderer.AddComponentAsync(
                             rootComponent.ComponentType,
                             rootComponent.Parameters,
-                            rootComponent.Selector);
+                            rootComponent.Selector));
+                    }
+
+                    if (mappings != null)
+                    {
+                        for (var i = 0; i < mappings.Length; i++)
+                        {
+                            var rootComponent = mappings[i];
+                            pendingRenders.Add(renderer.AddComponentAsync(
+                                rootComponent.ComponentType,
+                                rootComponent.Parameters,
+                                rootComponent.Selector));
+                        }
                     }
 
                     // Now we wait for all components to finish rendering.

+ 22 - 85
src/Components/WebAssembly/WebAssembly/src/Rendering/WebAssemblyRenderer.cs

@@ -5,7 +5,6 @@ using System.Diagnostics.CodeAnalysis;
 using System.Globalization;
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices.JavaScript;
-using System.Text.Json;
 using Microsoft.AspNetCore.Components.RenderTree;
 using Microsoft.AspNetCore.Components.Web;
 using Microsoft.AspNetCore.Components.Web.Infrastructure;
@@ -23,7 +22,6 @@ namespace Microsoft.AspNetCore.Components.WebAssembly.Rendering;
 /// </summary>
 internal sealed partial class WebAssemblyRenderer : WebRenderer
 {
-    private readonly RootComponentTypeCache _rootComponentCache = new();
     private readonly ILogger _logger;
 
     public WebAssemblyRenderer(IServiceProvider serviceProvider, ILoggerFactory loggerFactory, JSComponentInterop jsComponentInterop)
@@ -32,6 +30,28 @@ internal sealed partial class WebAssemblyRenderer : WebRenderer
         _logger = loggerFactory.CreateLogger<WebAssemblyRenderer>();
 
         ElementReferenceContext = DefaultWebAssemblyJSRuntime.Instance.ElementReferenceContext;
+        DefaultWebAssemblyJSRuntime.Instance.OnUpdateRootComponents += OnUpdateRootComponents;
+    }
+
+    [UnconditionalSuppressMessage("Trimming", "IL2067", Justification = "These are root components which belong to the user and are in assemblies that don't get trimmed.")]
+    private void OnUpdateRootComponents(OperationDescriptor[] operations)
+    {
+        for (var i = 0; i < operations.Length; i++)
+        {
+            var (operation, componentType, parameters) = operations[i];
+            switch (operation.Type)
+            {
+                case RootComponentOperationType.Add:
+                    _ = AddComponentAsync(componentType!, parameters, operation.SelectorId!.Value.ToString(CultureInfo.InvariantCulture));
+                    break;
+                case RootComponentOperationType.Update:
+                    _ = RenderRootComponentAsync(operation.ComponentId!.Value, parameters);
+                    break;
+                case RootComponentOperationType.Remove:
+                    RemoveRootComponent(operation.ComponentId!.Value);
+                    break;
+            }
+        }
     }
 
     public override Dispatcher Dispatcher => NullDispatcher.Instance;
@@ -53,89 +73,6 @@ internal sealed partial class WebAssemblyRenderer : WebRenderer
             RendererId);
     }
 
-    [DynamicDependency(JsonSerialized, typeof(RootComponentOperation))]
-    [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "The correct members will be preserved by the above DynamicDependency")]
-    protected override void UpdateRootComponents(string operationsJson)
-    {
-        var operations = JsonSerializer.Deserialize<IEnumerable<RootComponentOperation>>(
-            operationsJson,
-            WebAssemblyComponentSerializationSettings.JsonSerializationOptions)!;
-
-        foreach (var operation in operations)
-        {
-            switch (operation.Type)
-            {
-                case RootComponentOperationType.Add:
-                    AddRootComponent(operation);
-                    break;
-                case RootComponentOperationType.Update:
-                    UpdateRootComponent(operation);
-                    break;
-                case RootComponentOperationType.Remove:
-                    RemoveRootComponent(operation);
-                    break;
-            }
-        }
-
-        return;
-
-        [UnconditionalSuppressMessage("Trimming", "IL2072", Justification = "Root components are expected to be defined in assemblies that do not get trimmed.")]
-        void AddRootComponent(RootComponentOperation operation)
-        {
-            if (operation.SelectorId is not { } selectorId)
-            {
-                throw new InvalidOperationException($"The component operation of type '{operation.Type}' requires a '{nameof(operation.SelectorId)}' to be specified.");
-            }
-
-            if (operation.Marker is not { } marker)
-            {
-                throw new InvalidOperationException($"The component operation of type '{operation.Type}' requires a '{nameof(operation.Marker)}' to be specified.");
-            }
-
-            var componentType = _rootComponentCache.GetRootComponent(marker.Assembly!, marker.TypeName!)
-                ?? throw new InvalidOperationException($"Root component type '{marker.TypeName}' could not be found in the assembly '{marker.Assembly}'.");
-
-            var parameters = DeserializeComponentParameters(marker);
-            _ = AddComponentAsync(componentType, parameters, selectorId.ToString(CultureInfo.InvariantCulture));
-        }
-
-        void UpdateRootComponent(RootComponentOperation operation)
-        {
-            if (operation.ComponentId is not { } componentId)
-            {
-                throw new InvalidOperationException($"The component operation of type '{operation.Type}' requires a '{nameof(operation.ComponentId)}' to be specified.");
-            }
-
-            if (operation.Marker is not { } marker)
-            {
-                throw new InvalidOperationException($"The component operation of type '{operation.Type}' requires a '{nameof(operation.Marker)}' to be specified.");
-            }
-
-            var parameters = DeserializeComponentParameters(marker);
-            _ = RenderRootComponentAsync(componentId, parameters);
-        }
-
-        void RemoveRootComponent(RootComponentOperation operation)
-        {
-            if (operation.ComponentId is not { } componentId)
-            {
-                throw new InvalidOperationException($"The component operation of type '{operation.Type}' requires a '{nameof(operation.ComponentId)}' to be specified.");
-            }
-
-            this.RemoveRootComponent(componentId);
-        }
-
-        static ParameterView DeserializeComponentParameters(ComponentMarker marker)
-        {
-            var definitions = WebAssemblyComponentParameterDeserializer.GetParameterDefinitions(marker.ParameterDefinitions!);
-            var values = WebAssemblyComponentParameterDeserializer.GetParameterValues(marker.ParameterValues!);
-            var componentDeserializer = WebAssemblyComponentParameterDeserializer.Instance;
-            var parameters = componentDeserializer.DeserializeParameters(definitions, values);
-
-            return parameters;
-        }
-    }
-
     /// <inheritdoc />
     protected override void Dispose(bool disposing)
     {

+ 108 - 0
src/Components/WebAssembly/WebAssembly/src/Services/DefaultWebAssemblyJSRuntime.cs

@@ -10,19 +10,24 @@ using Microsoft.AspNetCore.Components.WebAssembly.Hosting;
 using Microsoft.JSInterop;
 using Microsoft.JSInterop.Infrastructure;
 using Microsoft.JSInterop.WebAssembly;
+using static Microsoft.AspNetCore.Internal.LinkerFlags;
 
 namespace Microsoft.AspNetCore.Components.WebAssembly.Services;
 
 internal sealed partial class DefaultWebAssemblyJSRuntime : WebAssemblyJSRuntime
 {
+    private readonly RootComponentTypeCache _rootComponentCache = new();
     internal static readonly DefaultWebAssemblyJSRuntime Instance = new();
 
     public ElementReferenceContext ElementReferenceContext { get; }
 
+    public event Action<OperationDescriptor[]>? OnUpdateRootComponents;
+
     [DynamicDependency(nameof(InvokeDotNet))]
     [DynamicDependency(nameof(EndInvokeJS))]
     [DynamicDependency(nameof(BeginInvokeDotNet))]
     [DynamicDependency(nameof(ReceiveByteArrayFromJS))]
+    [DynamicDependency(nameof(UpdateRootComponentsCore))]
     private DefaultWebAssemblyJSRuntime()
     {
         ElementReferenceContext = new WebElementReferenceContext(this);
@@ -84,6 +89,83 @@ internal sealed partial class DefaultWebAssemblyJSRuntime : WebAssemblyJSRuntime
         });
     }
 
+    [SupportedOSPlatform("browser")]
+    [JSExport]
+    public static void UpdateRootComponentsCore(string operationsJson)
+    {
+        try
+        {
+            var operations = DeserializeOperations(operationsJson);
+            Instance.OnUpdateRootComponents?.Invoke(operations);
+        }
+        catch (Exception ex)
+        {
+            Console.Error.WriteLine($"Error deserializing root component operations: {ex}");
+        }
+    }
+
+    [DynamicDependency(JsonSerialized, typeof(RootComponentOperation))]
+    [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "The correct members will be preserved by the above DynamicDependency")]
+    [SuppressMessage("Trimming", "IL2072:Target parameter argument does not satisfy 'DynamicallyAccessedMembersAttribute' in call to target method. The return value of the source method does not have matching annotations.", Justification = "Types in that cache are components from the user assembly which are never trimmed.")]
+    internal static OperationDescriptor[] DeserializeOperations(string operationsJson)
+    {
+        var deserialized = JsonSerializer.Deserialize<RootComponentOperation[]>(
+            operationsJson,
+            WebAssemblyComponentSerializationSettings.JsonSerializationOptions)!;
+
+        var operations = new OperationDescriptor[deserialized.Length];
+
+        for (var i = 0; i < deserialized.Length; i++)
+        {
+            var operation = deserialized[i];
+            if (operation.Type == RootComponentOperationType.Remove ||
+                operation.Type == RootComponentOperationType.Update)
+            {
+                if (operation.ComponentId == null)
+                {
+                    throw new InvalidOperationException($"The component operation of type '{operation.Type}' requires a '{nameof(operation.ComponentId)}' to be specified.");
+                }
+            }
+
+            if (operation.Type == RootComponentOperationType.Remove)
+            {
+                operations[i] = new(operation, null, ParameterView.Empty);
+                continue;
+            }
+
+            if (operation.Marker == null)
+            {
+                throw new InvalidOperationException($"The component operation of type '{operation.Type}' requires a '{nameof(operation.Marker)}' to be specified.");
+            }
+
+            Type? componentType = null;
+            if (operation.Type == RootComponentOperationType.Add)
+            {
+                if (operation.SelectorId == null)
+                {
+                    throw new InvalidOperationException($"The component operation of type '{operation.Type}' requires a '{nameof(operation.SelectorId)}' to be specified.");
+                }
+                componentType = Instance._rootComponentCache.GetRootComponent(operation.Marker!.Value.Assembly!, operation.Marker.Value.TypeName!)
+                ?? throw new InvalidOperationException($"Root component type '{operation.Marker.Value.TypeName}' could not be found in the assembly '{operation.Marker.Value.Assembly}'.");
+            }
+
+            var parameters = DeserializeComponentParameters(operation.Marker.Value);
+            operations[i] = new(operation, componentType, parameters);
+        }
+
+        return operations;
+    }
+
+    static ParameterView DeserializeComponentParameters(ComponentMarker marker)
+    {
+        var definitions = WebAssemblyComponentParameterDeserializer.GetParameterDefinitions(marker.ParameterDefinitions!);
+        var values = WebAssemblyComponentParameterDeserializer.GetParameterValues(marker.ParameterValues!);
+        var componentDeserializer = WebAssemblyComponentParameterDeserializer.Instance;
+        var parameters = componentDeserializer.DeserializeParameters(definitions, values);
+
+        return parameters;
+    }
+
     [JSExport]
     [SupportedOSPlatform("browser")]
     private static void ReceiveByteArrayFromJS(int id, byte[] data)
@@ -101,3 +183,29 @@ internal sealed partial class DefaultWebAssemblyJSRuntime : WebAssemblyJSRuntime
         return TransmitDataStreamToJS.TransmitStreamAsync(this, "Blazor._internal.receiveWebAssemblyDotNetDataStream", streamId, dotNetStreamReference);
     }
 }
+
+internal readonly struct OperationDescriptor
+{
+    public OperationDescriptor(
+        RootComponentOperation operation,
+        Type? componentType,
+        ParameterView parameters)
+    {
+        Operation = operation;
+        ComponentType = componentType;
+        Parameters = parameters;
+    }
+
+    public RootComponentOperation Operation { get; }
+
+    public Type? ComponentType { get; }
+
+    public ParameterView Parameters { get; }
+
+    public void Deconstruct(out RootComponentOperation operation, out Type? componentType, out ParameterView parameters)
+    {
+        operation = Operation;
+        componentType = ComponentType;
+        parameters = Parameters;
+    }
+}

+ 28 - 0
src/Components/WebAssembly/WebAssembly/src/Services/InternalJSImportMethods.cs

@@ -1,6 +1,8 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
 using System.Runtime.InteropServices.JavaScript;
 using Microsoft.AspNetCore.Components.Web;
 using Microsoft.AspNetCore.Components.WebAssembly.Hosting;
@@ -16,6 +18,29 @@ internal partial class InternalJSImportMethods : IInternalJSImportMethods
     public string GetPersistedState()
         => GetPersistedStateCore();
 
+    [UnconditionalSuppressMessage("Trimming", "IL2067", Justification = "These are root components which belong to the user and are in assemblies that don't get trimmed.")]
+    public static async Task<RootComponentMapping[]> GetInitialComponentUpdate()
+    {
+        var components = await InternalJSImportMethods.GetInitialUpdateCore();
+        var operations = DefaultWebAssemblyJSRuntime.DeserializeOperations(components);
+        var mappings = new RootComponentMapping[operations.Length];
+
+        for (var i = 0; i < operations.Length; i++)
+        {
+            var (operation, component, parameters) = operations[i];
+            if (operation.Type != RootComponentOperationType.Add)
+            {
+                throw new InvalidOperationException("All initial operations must be additions.");
+            }
+            mappings[i] = new RootComponentMapping(
+                component!,
+                operation.SelectorId!.Value.ToString(CultureInfo.InvariantCulture),
+                parameters);
+        }
+
+        return mappings;
+    }
+
     public string GetApplicationEnvironment()
         => GetApplicationEnvironmentCore();
 
@@ -52,6 +77,9 @@ internal partial class InternalJSImportMethods : IInternalJSImportMethods
     [JSImport("Blazor._internal.getPersistedState", "blazor-internal")]
     private static partial string GetPersistedStateCore();
 
+    [JSImport("Blazor._internal.getInitialComponentsUpdate", "blazor-internal")]
+    private static partial Task<string> GetInitialUpdateCore();
+
     [JSImport("Blazor._internal.getApplicationEnvironment", "blazor-internal")]
     private static partial string GetApplicationEnvironmentCore();
 

+ 79 - 0
src/Components/test/E2ETest/ServerRenderingTests/InteractivityTest.cs

@@ -847,6 +847,85 @@ public class InteractivityTest : ServerTestBase<BasicTestAppServerSiteFixture<Ra
         AssertBrowserLogContainsMessage("Error: The circuit");
     }
 
+    [Fact]
+    public void CanPersistPrerenderedState_Server()
+    {
+        Navigate($"{ServerPathBase}/persist-state?server=true");
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("server")).Text);
+        Browser.Equal("Server", () => Browser.FindElement(By.Id("render-mode-server")).Text);
+    }
+
+    [Fact]
+    public void CanPersistPrerenderedState_WebAssembly()
+    {
+        Navigate($"{ServerPathBase}/persist-state?wasm=true");
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("wasm")).Text);
+        Browser.Equal("WebAssembly", () => Browser.FindElement(By.Id("render-mode-wasm")).Text);
+    }
+
+    [Fact]
+    public void CanPersistPrerenderedState_Auto_PersistsOnWebAssembly()
+    {
+        Navigate($"{ServerPathBase}/persist-state?auto=true");
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("auto")).Text);
+        Browser.Equal("WebAssembly", () => Browser.FindElement(By.Id("render-mode-auto")).Text);
+    }
+
+    [Fact]
+    public void CanPersistPrerenderedState_Auto_PersistsOnServer()
+    {
+        Navigate(ServerPathBase);
+        Browser.Equal("Hello", () => Browser.Exists(By.TagName("h1")).Text);
+        BlockWebAssemblyResourceLoad();
+
+        Navigate($"{ServerPathBase}/persist-state?auto=true");
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("auto")).Text);
+        Browser.Equal("Server", () => Browser.FindElement(By.Id("render-mode-auto")).Text);
+    }
+
+    [Fact]
+    public void CanPersistState_AllRenderModesAtTheSameTime()
+    {
+        Navigate($"{ServerPathBase}/persist-state?server=true&wasm=true&auto=true");
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("server")).Text);
+        Browser.Equal("Server", () => Browser.FindElement(By.Id("render-mode-server")).Text);
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("wasm")).Text);
+        Browser.Equal("WebAssembly", () => Browser.FindElement(By.Id("render-mode-wasm")).Text);
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("auto")).Text);
+        Browser.Equal("WebAssembly", () => Browser.FindElement(By.Id("render-mode-auto")).Text);
+    }
+
+    [Fact]
+    public void CanPersistPrerenderedState_ServerPrerenderedStateAvailableOnlyOnFirstRender()
+    {
+        Navigate($"{ServerPathBase}/persist-server-state");
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("server")).Text);
+
+        Browser.Click(By.Id("destroy-and-recreate"));
+
+        Browser.Equal("not restored", () => Browser.FindElement(By.Id("server")).Text);
+    }
+
+    [Fact]
+    public void CanPersistPrerenderedState_WebAssemblyPrerenderedStateAvailableOnlyOnFirstRender()
+    {
+        Navigate($"{ServerPathBase}/persist-wasm-state");
+
+        Browser.Equal("restored", () => Browser.FindElement(By.Id("wasm")).Text);
+
+        Browser.Click(By.Id("destroy-and-recreate"));
+
+        Browser.Equal("not restored", () => Browser.FindElement(By.Id("wasm")).Text);
+    }
+
     private void BlockWebAssemblyResourceLoad()
     {
         ((IJavaScriptExecutor)Browser).ExecuteScript("sessionStorage.setItem('block-load-boot-resource', 'true')");

+ 2 - 0
src/Components/test/E2ETest/Tests/SaveStateTest.cs

@@ -10,6 +10,8 @@ using Xunit.Abstractions;
 
 namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests;
 
+// These tests are for Blazor Server and Webassembly implementations
+// For Blazor Web, check StatePersistenceTest.cs
 public class SaveStateTest : ServerTestBase<AspNetSiteServerFixture>
 {
     public SaveStateTest(

+ 237 - 0
src/Components/test/E2ETest/Tests/StatePersistenceTest.cs

@@ -0,0 +1,237 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Components.TestServer.RazorComponents;
+using Microsoft.AspNetCore.Components.E2ETest.Infrastructure;
+using Microsoft.AspNetCore.Components.E2ETest.Infrastructure.ServerFixtures;
+using Microsoft.AspNetCore.Components.E2ETests.ServerRenderingTests;
+using Microsoft.AspNetCore.Components.Web;
+using Microsoft.AspNetCore.E2ETesting;
+using OpenQA.Selenium;
+using TestServer;
+using Xunit.Abstractions;
+
+namespace Microsoft.AspNetCore.Components.E2ETests.Tests;
+
+// These tests are for Blazor Web implementation
+// For Blazor Server and Webassembly, check SaveStateTest.cs
+public class StatePersistenceTest : ServerTestBase<BasicTestAppServerSiteFixture<RazorComponentEndpointsStartup<App>>>
+{
+    static int _nextStreamingIdContext;
+
+    public StatePersistenceTest(
+        BrowserFixture browserFixture,
+        BasicTestAppServerSiteFixture<RazorComponentEndpointsStartup<App>> serverFixture,
+        ITestOutputHelper output)
+        : base(browserFixture, serverFixture, output)
+    {
+    }
+
+    // Separate contexts to ensure that caches and other state don't interfere across tests.
+    public override Task InitializeAsync()
+        => InitializeAsync(BrowserFixture.StreamingContext + _nextStreamingIdContext++);
+
+    // Validates that we can use persisted state across server, webasembly, and auto modes, with and without
+    // streaming rendering.
+    // For streaming rendering, we validate that the state is captured and restored after streaming completes.
+    // For enhanced navigation we validate that the state is captured at the time components are rendered for
+    // the first time on the page.
+    // For auto mode, we validate that the state is captured and restored for both server and wasm runtimes.
+    // In each case, we validate that the state is available until the initial set of components first render reaches quiescence. Similar to how it works for Server and WebAssembly.
+    // For server we validate that the state is provided every time a circuit is initialized.
+    [Theory]
+    [InlineData(true, typeof(InteractiveServerRenderMode), (string)null)]
+    [InlineData(true, typeof(InteractiveServerRenderMode), "ServerStreaming")]
+    [InlineData(true, typeof(InteractiveWebAssemblyRenderMode), (string)null)]
+    [InlineData(true, typeof(InteractiveWebAssemblyRenderMode), "WebAssemblyStreaming")]
+    [InlineData(true, typeof(InteractiveAutoRenderMode), (string)null)]
+    [InlineData(true, typeof(InteractiveAutoRenderMode), "AutoStreaming")]
+    [InlineData(false, typeof(InteractiveServerRenderMode), (string)null)]
+    [InlineData(false, typeof(InteractiveServerRenderMode), "ServerStreaming")]
+    [InlineData(false, typeof(InteractiveWebAssemblyRenderMode), (string)null)]
+    [InlineData(false, typeof(InteractiveWebAssemblyRenderMode), "WebAssemblyStreaming")]
+    [InlineData(false, typeof(InteractiveAutoRenderMode), (string)null)]
+    [InlineData(false, typeof(InteractiveAutoRenderMode), "AutoStreaming")]
+    public void CanRenderComponentWithPersistedState(bool suppressEnhancedNavigation, Type renderMode, string streaming)
+    {
+        var mode = renderMode switch
+        {
+            var t when t == typeof(InteractiveServerRenderMode) => "server",
+            var t when t == typeof(InteractiveWebAssemblyRenderMode) => "wasm",
+            var t when t == typeof(InteractiveAutoRenderMode) => "auto",
+            _ => throw new ArgumentException($"Unknown render mode: {renderMode.Name}")
+        };
+
+        if (!suppressEnhancedNavigation)
+        {
+            // Navigate to a page without components first to make sure that we exercise rendering components
+            // with enhanced navigation on.
+            if (streaming == null)
+            {
+                Navigate($"subdir/persistent-state/page-no-components?render-mode={mode}&suppress-autostart");
+            }
+            else
+            {
+                Navigate($"subdir/persistent-state/page-no-components?render-mode={mode}&streaming-id={streaming}&suppress-autostart");
+            }
+            if (mode == "auto")
+            {
+                BlockWebAssemblyResourceLoad();
+            }
+            Browser.Click(By.Id("call-blazor-start"));
+            Browser.Click(By.Id("page-with-components-link"));
+        }
+        else
+        {
+            SuppressEnhancedNavigation(true);
+        }
+
+        if (mode != "auto")
+        {
+            RenderComponentsWithPersistentStateAndValidate(suppressEnhancedNavigation, mode, renderMode, streaming);
+        }
+        else
+        {
+            if (suppressEnhancedNavigation)
+            {
+                BlockWebAssemblyResourceLoad();
+            }
+            // For auto mode, validate that the state is persisted for both runtimes and is able
+            // to be loaded on server and wasm.
+            RenderComponentsWithPersistentStateAndValidate(suppressEnhancedNavigation, mode, renderMode, streaming, interactiveRuntime: "server");
+
+            UnblockWebAssemblyResourceLoad();
+            Browser.Navigate().Refresh();
+
+            RenderComponentsWithPersistentStateAndValidate(suppressEnhancedNavigation, mode, renderMode, streaming, interactiveRuntime: "wasm");
+        }
+    }
+
+    [Theory]
+    [InlineData((string)null)]
+    [InlineData("ServerStreaming")]
+    public async Task StateIsProvidedEveryTimeACircuitGetsCreated(string streaming)
+    {
+        var mode = "server";
+        if (streaming == null)
+        {
+            Navigate($"subdir/persistent-state/page-no-components?render-mode={mode}");
+        }
+        else
+        {
+            Navigate($"subdir/persistent-state/page-no-components?render-mode={mode}&streaming-id={streaming}");
+        }
+        Browser.Click(By.Id("page-with-components-link"));
+
+        RenderComponentsWithPersistentStateAndValidate(suppresEnhancedNavigation: false, mode, typeof(InteractiveServerRenderMode), streaming);
+        Browser.Click(By.Id("page-no-components-link"));
+        // Ensure that the circuit is gone.
+        await Task.Delay(1000);
+        Browser.Click(By.Id("page-with-components-link-and-state"));
+        RenderComponentsWithPersistentStateAndValidate(suppresEnhancedNavigation: false, mode, typeof(InteractiveServerRenderMode), streaming, stateValue: "other");
+    }
+
+    private void BlockWebAssemblyResourceLoad()
+    {
+        ((IJavaScriptExecutor)Browser).ExecuteScript("sessionStorage.setItem('block-load-boot-resource', 'true')");
+
+        // Clear caches so that we can block the resource load
+        ((IJavaScriptExecutor)Browser).ExecuteScript("caches.keys().then(keys => keys.forEach(key => caches.delete(key)))");
+    }
+
+    private void UnblockWebAssemblyResourceLoad()
+    {
+        ((IJavaScriptExecutor)Browser).ExecuteScript("window.unblockLoadBootResource()");
+    }
+
+    private void RenderComponentsWithPersistentStateAndValidate(
+        bool suppresEnhancedNavigation,
+        string mode,
+        Type renderMode,
+        string streaming,
+        string interactiveRuntime = null,
+        string stateValue = null)
+    {
+        stateValue ??= "restored";
+        // No need to navigate if we are using enhanced navigation, the tests will have already navigated to the page via a link.
+        if (suppresEnhancedNavigation)
+        {
+            // In this case we suppress auto start to check some server side state before we boot Blazor.
+            if (streaming == null)
+            {
+                Navigate($"subdir/persistent-state/page-with-components?render-mode={mode}&suppress-autostart");
+            }
+            else
+            {
+                Navigate($"subdir/persistent-state/page-with-components?render-mode={mode}&streaming-id={streaming}&suppress-autostart");
+            }
+
+            AssertPageState(
+                mode: mode,
+                renderMode: renderMode.Name,
+                interactive: false,
+                stateFound: true,
+                stateValue: stateValue,
+                streamingId: streaming,
+                streamingCompleted: false,
+                interactiveRuntime: interactiveRuntime);
+
+            Browser.Click(By.Id("call-blazor-start"));
+        }
+
+        AssertPageState(
+            mode: mode,
+            renderMode: renderMode.Name,
+            interactive: streaming == null,
+            stateFound: true,
+            stateValue: stateValue,
+            streamingId: streaming,
+            streamingCompleted: false,
+            interactiveRuntime: interactiveRuntime);
+
+        if (streaming != null)
+        {
+            Browser.Click(By.Id("end-streaming"));
+        }
+
+        AssertPageState(
+            mode: mode,
+            renderMode: renderMode.Name,
+            interactive: true,
+            stateFound: true,
+            stateValue: stateValue,
+            streamingId: streaming,
+            streamingCompleted: true,
+            interactiveRuntime: interactiveRuntime);
+    }
+
+    private void AssertPageState(
+        string mode,
+        string renderMode,
+        bool interactive,
+        bool stateFound,
+        string stateValue,
+        string streamingId = null,
+        bool streamingCompleted = false,
+        string interactiveRuntime = null)
+    {
+        Browser.Equal($"Render mode: {renderMode}", () => Browser.FindElement(By.Id("render-mode")).Text);
+        Browser.Equal($"Streaming id:{streamingId}", () => Browser.FindElement(By.Id("streaming-id")).Text);
+        Browser.Equal($"Interactive: {interactive}", () => Browser.FindElement(By.Id("interactive")).Text);
+        if (streamingId == null || streamingCompleted)
+        {
+            interactiveRuntime = !interactive ? "none" : mode == "server" || mode == "wasm" ? mode : (interactiveRuntime ?? throw new InvalidOperationException("Specify interactiveRuntime for auto mode"));
+
+            Browser.Equal($"Interactive runtime: {interactiveRuntime}", () => Browser.FindElement(By.Id("interactive-runtime")).Text);
+            Browser.Equal($"State found:{stateFound}", () => Browser.FindElement(By.Id("state-found")).Text);
+            Browser.Equal($"State value:{stateValue}", () => Browser.FindElement(By.Id("state-value")).Text);
+        }
+        else
+        {
+            Browser.Equal("Streaming: True", () => Browser.FindElement(By.Id("streaming")).Text);
+        }
+    }
+
+    private void SuppressEnhancedNavigation(bool shouldSuppress)
+        => EnhancedNavigationTestUtil.SuppressEnhancedNavigation(this, shouldSuppress);
+}

+ 2 - 1
src/Components/test/testassets/BasicTestApp/PreserveStateService.cs

@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using Microsoft.AspNetCore.Components;
+using Microsoft.AspNetCore.Components.Web;
 
 namespace BasicTestApp;
 
@@ -15,7 +16,7 @@ public class PreserveStateService : IDisposable
     public PreserveStateService(PersistentComponentState componentApplicationState)
     {
         _componentApplicationState = componentApplicationState;
-        _persistingSubscription = _componentApplicationState.RegisterOnPersisting(PersistState);
+        _persistingSubscription = _componentApplicationState.RegisterOnPersisting(PersistState, RenderMode.InteractiveAuto);
         TryRestoreState();
     }
 

+ 24 - 0
src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistServerState.razor

@@ -0,0 +1,24 @@
+@page "/persist-server-state"
+@attribute [RenderModeInteractiveServer]
+
+<p>Server Persist State Component</p>
+
+@if (!destroyAndRecreate)
+{
+    <TestContentPackage.PersistStateComponent KeyName="server" />
+}
+else
+{
+    <TestContentPackage.PersistStateComponent KeyName="server" />
+}
+
+<button id="destroy-and-recreate" @onclick="DestroyAndRecreate">Destroy and recreate</button>
+
+@code {
+    private bool destroyAndRecreate = false;
+
+    private void DestroyAndRecreate()
+    {
+        destroyAndRecreate = true;
+    }
+}

+ 33 - 0
src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistStateComponents.razor

@@ -0,0 +1,33 @@
+@page "/persist-state"
+@using Microsoft.AspNetCore.Components.Web
+
+<h1>Persist State Components</h1>
+
+@if (Server.GetValueOrDefault()) {
+    <strong>Server Persist State Component</strong>
+    <TestContentPackage.PersistStateComponent KeyName="server" @rendermode="@RenderMode.InteractiveServer" />
+    <hr />
+}
+
+@if (WebAssembly.GetValueOrDefault()) {
+    <strong>WebAssembly Persist State Component</strong>
+    <TestContentPackage.PersistStateComponent KeyName="wasm" @rendermode="@RenderMode.InteractiveWebAssembly" />
+    <hr />
+}
+
+@if (Auto.GetValueOrDefault()) {
+    <strong>Auto Persist State Component</strong>
+    <TestContentPackage.PersistStateComponent KeyName="auto" @rendermode="@RenderMode.InteractiveAuto" />
+    <hr />
+}
+
+@code {
+    [Parameter, SupplyParameterFromQuery(Name = "server")]
+    public bool? Server { get; set; }
+
+    [Parameter, SupplyParameterFromQuery(Name = "wasm")]
+    public bool? WebAssembly { get; set; }
+
+    [Parameter, SupplyParameterFromQuery(Name = "auto")]
+    public bool? Auto { get; set; }
+}

+ 36 - 0
src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/EndStreamingPage.razor

@@ -0,0 +1,36 @@
+@page "/persistent-state/end-streaming"
+@using Components.TestServer.Services
+
+<h3>End streaming page</h3>
+
+<h4>This page is used to terminate the streaming rendering operations for
+    persistent component state tests. It accepts a streaming-id query parameter
+    which is used to complete streaming on the request with the associated streaming-id.
+</h4>
+
+<p>Streaming finished for stream id: @StreamingId</p>
+
+<script>
+    setTimeout(() => {
+        window.close();
+    }, 3000);
+</script>
+
+@code {
+
+    [SupplyParameterFromQuery(Name = "streaming-id")] public string StreamingId { get; set; }
+
+    [Inject] public AsyncOperationService StreamingManager { get; set; }
+
+    protected override void OnInitialized()
+    {
+        if (string.IsNullOrEmpty(StreamingId))
+        {
+            throw new InvalidOperationException("StreamingId is required.");
+        }
+        else
+        {
+            StreamingManager.Complete(StreamingId);
+        }
+    }
+}

+ 65 - 0
src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/PageWithComponents.razor

@@ -0,0 +1,65 @@
+@page "/persistent-state/page-with-components"
+@using TestContentPackage.PersistentComponents
+
+<h3>Streaming page with components</h3>
+
+<h3>
+    This page can render one or more components with different render modes and different streaming rendering modes.
+    It accepts a render-mode query parameter which is used to determine the render mode for the components on the page.
+    It also accepts a streaming-id query parameter which is used to select whether to render a component that uses streaming rendering or not.
+</h3>
+
+<p id="render-mode">Render mode: @_renderMode?.GetType()?.Name</p>
+<p id="streaming-id">Streaming id:@StreamingId</p>
+@if (_renderMode != null)
+{
+    <CascadingValue Name="RunningOnServer" Value="true">
+        @if (!string.IsNullOrEmpty(StreamingId))
+        {
+            <StreamingComponentWithPersistentState @rendermode="@_renderMode" StreamingId="@StreamingId" ServerState="@ServerState" />
+        }
+        else
+        {
+            <NonStreamingComponentWithPersistentState @rendermode="@_renderMode" ServerState="@ServerState" />
+        }
+    </CascadingValue>
+}
+@if (!string.IsNullOrEmpty(StreamingId))
+{
+    <a id="end-streaming" href="@($"persistent-state/end-streaming?streaming-id={StreamingId}")" target="_blank">End streaming</a>
+}
+
+<a id="page-no-components-link" href=@($"persistent-state/page-no-components?render-mode={RenderMode}&streaming-id={StreamingId}")>Go to page with no components</a>
+
+
+@code {
+
+    private IComponentRenderMode _renderMode;
+
+    [SupplyParameterFromQuery(Name = "render-mode")] public string RenderMode { get; set; }
+
+    [SupplyParameterFromQuery(Name = "streaming-id")] public string StreamingId { get; set; }
+
+    [SupplyParameterFromQuery(Name = "server-state")] public string ServerState { get; set; }
+
+    protected override void OnInitialized()
+    {
+        if (!string.IsNullOrEmpty(RenderMode))
+        {
+            switch (RenderMode)
+            {
+                case "server":
+                    _renderMode = new InteractiveServerRenderMode(true);
+                    break;
+                case "wasm":
+                    _renderMode = new InteractiveWebAssemblyRenderMode(true);
+                    break;
+                case "auto":
+                    _renderMode = new InteractiveAutoRenderMode(true);
+                    break;
+                default:
+                    throw new ArgumentException($"Invalid render mode: {RenderMode}");
+            }
+        }
+    }
+}

+ 14 - 0
src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/PageWithoutComponents.razor

@@ -0,0 +1,14 @@
+@page "/persistent-state/page-no-components"
+
+<h3>This page does not render any component. We use it to test that persisted state is only provided at the time interactive components get activated on the page.</h3>
+
+<a id="page-with-components-link" href=@($"persistent-state/page-with-components?render-mode={RenderMode}&streaming-id={StreamingId}")>Go to page with components</a>
+
+<a id="page-with-components-link-and-state" href=@($"persistent-state/page-with-components?render-mode={RenderMode}&streaming-id={StreamingId}&server-state=other")>Go to page with components</a>
+
+
+@code {
+    [SupplyParameterFromQuery(Name = "render-mode")] public string RenderMode { get; set; }
+
+    [SupplyParameterFromQuery(Name = "streaming-id")] public string StreamingId { get; set; }
+}

+ 25 - 0
src/Components/test/testassets/Components.WasmMinimal/Pages/PersistWebAssemblyState.razor

@@ -0,0 +1,25 @@
+@page "/persist-wasm-state"
+@using Microsoft.AspNetCore.Components.Web
+@attribute [RenderModeInteractiveWebAssembly]
+
+<p>WebAssembly Persist State Component</p>
+
+@if (!destroyAndRecreate)
+{
+    <TestContentPackage.PersistStateComponent KeyName="wasm" />
+}
+else
+{
+    <TestContentPackage.PersistStateComponent KeyName="wasm" />
+}
+
+<button id="destroy-and-recreate" @onclick="DestroyAndRecreate">Destroy and recreate</button>
+
+@code {
+    private bool destroyAndRecreate = false;
+
+    private void DestroyAndRecreate()
+    {
+        destroyAndRecreate = true;
+    }
+}

+ 4 - 0
src/Components/test/testassets/Components.WasmMinimal/Program.cs

@@ -1,10 +1,14 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.ComponentModel;
 using System.Reflection;
+using Components.TestServer.Services;
 using Microsoft.AspNetCore.Components.WebAssembly.Hosting;
 
 Assembly.Load(nameof(TestContentPackage));
 
 var builder = WebAssemblyHostBuilder.CreateDefault(args);
+builder.Services.AddSingleton<AsyncOperationService>();
+
 await builder.Build().RunAsync();

+ 40 - 0
src/Components/test/testassets/TestContentPackage/PersistStateComponent.razor

@@ -0,0 +1,40 @@
+@implements IDisposable
+@inject PersistentComponentState ApplicationState
+
+<p>Application state is <span id="@KeyName">@value</span></p>
+<p>Render mode: <span id="render-mode-@KeyName">@_renderMode</span></p>
+
+@code {
+    [Parameter, EditorRequired]
+    public string KeyName { get; set; } = "";
+
+    private string value = "not started";
+    private string _renderMode = "SSR";
+
+    private PersistingComponentStateSubscription persistingSubscription;
+
+    protected override void OnInitialized()
+    {
+        persistingSubscription = ApplicationState.RegisterOnPersisting(() =>
+        {
+            ApplicationState.PersistAsJson(KeyName, $"restored");
+            return Task.CompletedTask;
+        });
+
+        if (!ApplicationState.TryTakeFromJson<string>(KeyName, out var restored))
+        {
+            value = "not restored";
+        }
+        else
+        {
+            value = restored!;
+        }
+
+        _renderMode = OperatingSystem.IsBrowser() ? "WebAssembly" : "Server";
+    }
+
+    void IDisposable.Dispose()
+    {
+        persistingSubscription.Dispose();
+    }
+}

+ 49 - 0
src/Components/test/testassets/TestContentPackage/PersistentComponents/NonStreamingComponentWithPersistentState.razor

@@ -0,0 +1,49 @@
+<p>Non streaming component with persistent state</p>
+
+<p>This component demonstrates state persistence in the absence of streaming rendering. When the component renders it will try to restore the state and if present display that it succeded in doing so and the restored value. If the state is not present, it will indicate it didn't find it and display a "fresh" value.</p>
+
+<p id="interactive">Interactive: @(!RunningOnServer)</p>
+<p id="interactive-runtime">Interactive runtime: @_interactiveRuntime</p>
+<p id="state-found">State found:@_stateFound</p>
+<p id="state-value">State value:@_stateValue</p>
+
+@code {
+
+    private bool _stateFound;
+    private string _stateValue;
+    private string _interactiveRuntime;
+
+    [Inject] public PersistentComponentState PersistentComponentState { get; set; }
+
+    [CascadingParameter(Name = nameof(RunningOnServer))] public bool RunningOnServer { get; set; }
+    [Parameter] public string ServerState { get; set; }
+
+    protected override void OnInitialized()
+    {
+        PersistentComponentState.RegisterOnPersisting(PersistState);
+
+        _stateFound = PersistentComponentState.TryTakeFromJson<string>("NonStreamingComponentWithPersistentState", out _stateValue);
+
+        if (!_stateFound)
+        {
+            _stateValue = "fresh";
+        }
+
+        if (RunningOnServer)
+        {
+            _interactiveRuntime = "none";
+            _stateFound = true;
+            _stateValue = ServerState ?? "restored";
+        }
+        else
+        {
+            _interactiveRuntime = OperatingSystem.IsBrowser() ? "wasm" : "server";
+        }
+    }
+
+    Task PersistState()
+    {
+        PersistentComponentState.PersistAsJson("NonStreamingComponentWithPersistentState", _stateValue);
+        return Task.CompletedTask;
+    }
+}

+ 77 - 0
src/Components/test/testassets/TestContentPackage/PersistentComponents/StreamingComponentWithPersistentState.razor

@@ -0,0 +1,77 @@
+@using Components.TestServer.Services
+@attribute [StreamRendering]
+<p>Streaming component with persistent state</p>
+
+<p>This component demonstrates state persistence alongside streaming rendering. When the component first renders, it'll emit a message "streaming" and yield until its notified via a call to <code>/persistent-state/end-streaming</code>. When the component renders it will try to restore the state and if present display that it succeded in doing so and the restored value. If the state is not present, it will indicate it didn't find it and display a "fresh" value.</p>
+
+    <p id="interactive">Interactive: @(!RunningOnServer)</p>
+@if (_streaming)
+{
+    <p id="streaming">Streaming: @_streaming</p>
+}
+else
+{
+    <p id="interactive-runtime">Interactive runtime: @_interactiveRuntime</p>
+    <p id="state-found">State found:@_stateFound</p>
+    <p id="state-value">State value:@_stateValue</p>
+}
+
+@code {
+
+    private bool _streaming;
+    private bool _stateFound;
+    private string _stateValue;
+    private string _interactiveRuntime;
+
+    [Inject] public PersistentComponentState PersistentComponentState { get; set; }
+
+    [Inject] public AsyncOperationService StreamingManager { get; set; }
+
+    [Parameter] public string StreamingId { get; set; }
+
+    [Parameter] public string ServerState { get; set; }
+
+    [CascadingParameter(Name = nameof(RunningOnServer))] public bool RunningOnServer { get; set; }
+
+    protected override async Task OnInitializedAsync()
+    {
+        if (string.IsNullOrEmpty(StreamingId))
+        {
+            throw new InvalidOperationException("StreamingId is required.");
+        }
+        PersistentComponentState.RegisterOnPersisting(PersistState);
+
+        if (RunningOnServer)
+        {
+            _interactiveRuntime = "none";
+            _streaming = true;
+            await StreamingManager.Start(StreamingId);
+            _streaming = false;
+        }else
+        {
+            _interactiveRuntime = OperatingSystem.IsBrowser() ? "wasm" : "server";
+        }
+
+        // We do this to ensure that the state remains accessible during the entire first render
+        // cycle (technically until the root component reaches quiescence).
+        await Task.Yield();
+        _stateFound = PersistentComponentState.TryTakeFromJson<string>("NonStreamingComponentWithPersistentState", out _stateValue);
+
+        if (!_stateFound)
+        {
+            _stateValue = "fresh";
+        }
+
+        if (RunningOnServer)
+        {
+            _stateFound = true;
+            _stateValue = ServerState ?? "restored";
+        }
+    }
+
+    Task PersistState()
+    {
+        PersistentComponentState.PersistAsJson("NonStreamingComponentWithPersistentState", _stateValue);
+        return Task.CompletedTask;
+    }
+}

+ 0 - 0
src/Components/test/testassets/Components.TestServer/Services/AsyncOperationService.cs → src/Components/test/testassets/TestContentPackage/Services/AsyncOperationService.cs


+ 2 - 4
src/Identity/Extensions.Core/src/IEmailSender.cs → src/Identity/Core/src/IEmailSender.cs

@@ -1,19 +1,17 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System.Threading.Tasks;
-
 namespace Microsoft.AspNetCore.Identity.UI.Services;
 
 /// <summary>
 /// This API supports the ASP.NET Core Identity infrastructure and is not intended to be used as a general purpose
-/// email abstraction. It should be implemented by the application so the Identity infrastructure can send confirmation emails.
+/// email abstraction. It should be implemented by the application so the Identity infrastructure can send confirmation and password reset emails.
 /// </summary>
 public interface IEmailSender
 {
     /// <summary>
     /// This API supports the ASP.NET Core Identity infrastructure and is not intended to be used as a general purpose
-    /// email abstraction. It should be implemented by the application so the Identity infrastructure can send confirmation emails.
+    /// email abstraction. It should be implemented by the application so the Identity infrastructure can send confirmation and apassword reset emails.
     /// </summary>
     /// <param name="email">The recipient's email address.</param>
     /// <param name="subject">The subject of the email.</param>

+ 41 - 0
src/Identity/Core/src/IEmailSenderOfT.cs

@@ -0,0 +1,41 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.AspNetCore.Identity;
+
+/// <summary>
+/// This API supports the ASP.NET Core Identity infrastructure and is not intended to be used as a general purpose
+/// email abstraction. It should be implemented by the application so the Identity infrastructure can send confirmation and password reset emails.
+/// </summary>
+public interface IEmailSender<TUser> where TUser : class
+{
+    /// <summary>
+    /// This API supports the ASP.NET Core Identity infrastructure and is not intended to be used as a general purpose
+    /// email abstraction. It should be implemented by the application so the Identity infrastructure can send confirmation emails.
+    /// </summary>
+    /// <param name="user">The user that is attempting to confirm their email.</param>
+    /// <param name="email">The recipient's email address.</param>
+    /// <param name="confirmationLink">The link to follow to confirm a user's email. Do not double encode this.</param>
+    /// <returns></returns>
+    Task SendConfirmationLinkAsync(TUser user, string email, string confirmationLink);
+
+    /// <summary>
+    /// This API supports the ASP.NET Core Identity infrastructure and is not intended to be used as a general purpose
+    /// email abstraction. It should be implemented by the application so the Identity infrastructure can send password reset emails.
+    /// </summary>
+    /// <param name="user">The user that is attempting to reset their password.</param>
+    /// <param name="email">The recipient's email address.</param>
+    /// <param name="resetLink">The link to follow to reset the user password. Do not double encode this.</param>
+    /// <returns></returns>
+    Task SendPasswordResetLinkAsync(TUser user, string email, string resetLink);
+
+    /// <summary>
+    /// This API supports the ASP.NET Core Identity infrastructure and is not intended to be used as a general purpose
+    /// email abstraction. It should be implemented by the application so the Identity infrastructure can send password reset emails.
+    /// </summary>
+    /// <param name="user">The user that is attempting to reset their password.</param>
+    /// <param name="email">The recipient's email address.</param>
+    /// <param name="resetCode">The code to use to reset the user password. Do not double encode this.</param>
+    /// <returns></returns>
+    Task SendPasswordResetCodeAsync(TUser user, string email, string resetCode);
+}

+ 3 - 7
src/Identity/Core/src/IdentityApiEndpointRouteBuilderExtensions.cs

@@ -14,7 +14,6 @@ using Microsoft.AspNetCore.Http.HttpResults;
 using Microsoft.AspNetCore.Http.Metadata;
 using Microsoft.AspNetCore.Identity;
 using Microsoft.AspNetCore.Identity.Data;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.WebUtilities;
 using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Options;
@@ -45,7 +44,7 @@ public static class IdentityApiEndpointRouteBuilderExtensions
 
         var timeProvider = endpoints.ServiceProvider.GetRequiredService<TimeProvider>();
         var bearerTokenOptions = endpoints.ServiceProvider.GetRequiredService<IOptionsMonitor<BearerTokenOptions>>();
-        var emailSender = endpoints.ServiceProvider.GetRequiredService<IEmailSender>();
+        var emailSender = endpoints.ServiceProvider.GetRequiredService<IEmailSender<TUser>>();
         var linkGenerator = endpoints.ServiceProvider.GetRequiredService<LinkGenerator>();
 
         // We'll figure out a unique endpoint name based on the final route pattern during endpoint generation.
@@ -189,7 +188,6 @@ public static class IdentityApiEndpointRouteBuilderExtensions
             var finalPattern = ((RouteEndpointBuilder)endpointBuilder).RoutePattern.RawText;
             confirmEmailEndpointName = $"{nameof(MapIdentityApi)}-{finalPattern}";
             endpointBuilder.Metadata.Add(new EndpointNameMetadata(confirmEmailEndpointName));
-            endpointBuilder.Metadata.Add(new RouteNameMetadata(confirmEmailEndpointName));
         });
 
         routeGroup.MapPost("/resendConfirmationEmail", async Task<Ok>
@@ -216,8 +214,7 @@ public static class IdentityApiEndpointRouteBuilderExtensions
                 var code = await userManager.GeneratePasswordResetTokenAsync(user);
                 code = WebEncoders.Base64UrlEncode(Encoding.UTF8.GetBytes(code));
 
-                await emailSender.SendEmailAsync(resetRequest.Email, "Reset your password",
-                    $"Reset your password using the following code: {HtmlEncoder.Default.Encode(code)}");
+                await emailSender.SendPasswordResetCodeAsync(user, resetRequest.Email, HtmlEncoder.Default.Encode(code));
             }
 
             // Don't reveal that the user does not exist or is not confirmed, so don't return a 200 if we would have
@@ -416,8 +413,7 @@ public static class IdentityApiEndpointRouteBuilderExtensions
             var confirmEmailUrl = linkGenerator.GetUriByName(context, confirmEmailEndpointName, routeValues)
                 ?? throw new NotSupportedException($"Could not find endpoint named '{confirmEmailEndpointName}'.");
 
-            await emailSender.SendEmailAsync(email, "Confirm your email",
-                $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(confirmEmailUrl)}'>clicking here</a>.");
+            await emailSender.SendConfirmationLinkAsync(user, email, HtmlEncoder.Default.Encode(confirmEmailUrl));
         }
 
         return new IdentityEndpointsConventionBuilder(routeGroup);

+ 1 - 0
src/Identity/Core/src/IdentityBuilderExtensions.cs

@@ -97,6 +97,7 @@ public static class IdentityBuilderExtensions
 
         builder.AddSignInManager();
         builder.AddDefaultTokenProviders();
+        builder.Services.TryAddTransient(typeof(IEmailSender<>), typeof(DefaultMessageEmailSender<>));
         builder.Services.TryAddTransient<IEmailSender, NoOpEmailSender>();
         builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton<IConfigureOptions<JsonOptions>, IdentityEndpointsJsonOptionsSetup>());
         return builder;

+ 1 - 1
src/Identity/Core/src/Microsoft.AspNetCore.Identity.csproj

@@ -13,7 +13,7 @@
   </PropertyGroup>
 
   <ItemGroup>
-    <Compile Include="$(SharedSourceRoot)BearerToken\DTO\*.cs" LinkBase="DTO" />
+    <Compile Include="$(SharedSourceRoot)DefaultMessageEmailSender.cs" />
   </ItemGroup>
 
   <ItemGroup>

+ 0 - 2
src/Identity/Extensions.Core/src/NoOpEmailSender.cs → src/Identity/Core/src/NoOpEmailSender.cs

@@ -1,8 +1,6 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System.Threading.Tasks;
-
 namespace Microsoft.AspNetCore.Identity.UI.Services;
 
 /// <summary>

+ 9 - 0
src/Identity/Core/src/PublicAPI.Unshipped.txt

@@ -75,6 +75,10 @@ Microsoft.AspNetCore.Identity.Data.TwoFactorResponse.RecoveryCodesLeft.init -> v
 Microsoft.AspNetCore.Identity.Data.TwoFactorResponse.SharedKey.get -> string!
 Microsoft.AspNetCore.Identity.Data.TwoFactorResponse.SharedKey.init -> void
 Microsoft.AspNetCore.Identity.Data.TwoFactorResponse.TwoFactorResponse() -> void
+Microsoft.AspNetCore.Identity.IEmailSender<TUser>
+Microsoft.AspNetCore.Identity.IEmailSender<TUser>.SendConfirmationLinkAsync(TUser! user, string! email, string! confirmationLink) -> System.Threading.Tasks.Task!
+Microsoft.AspNetCore.Identity.IEmailSender<TUser>.SendPasswordResetCodeAsync(TUser! user, string! email, string! resetCode) -> System.Threading.Tasks.Task!
+Microsoft.AspNetCore.Identity.IEmailSender<TUser>.SendPasswordResetLinkAsync(TUser! user, string! email, string! resetLink) -> System.Threading.Tasks.Task!
 Microsoft.AspNetCore.Identity.SecurityStampValidator<TUser>.SecurityStampValidator(Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Identity.SecurityStampValidatorOptions!>! options, Microsoft.AspNetCore.Identity.SignInManager<TUser!>! signInManager, Microsoft.Extensions.Logging.ILoggerFactory! logger) -> void
 Microsoft.AspNetCore.Identity.SecurityStampValidator<TUser>.TimeProvider.get -> System.TimeProvider!
 Microsoft.AspNetCore.Identity.SecurityStampValidatorOptions.TimeProvider.get -> System.TimeProvider?
@@ -82,6 +86,11 @@ Microsoft.AspNetCore.Identity.SecurityStampValidatorOptions.TimeProvider.set ->
 Microsoft.AspNetCore.Identity.SignInManager<TUser>.AuthenticationScheme.get -> string!
 Microsoft.AspNetCore.Identity.SignInManager<TUser>.AuthenticationScheme.set -> void
 Microsoft.AspNetCore.Identity.TwoFactorSecurityStampValidator<TUser>.TwoFactorSecurityStampValidator(Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Identity.SecurityStampValidatorOptions!>! options, Microsoft.AspNetCore.Identity.SignInManager<TUser!>! signInManager, Microsoft.Extensions.Logging.ILoggerFactory! logger) -> void
+Microsoft.AspNetCore.Identity.UI.Services.IEmailSender
+Microsoft.AspNetCore.Identity.UI.Services.IEmailSender.SendEmailAsync(string! email, string! subject, string! htmlMessage) -> System.Threading.Tasks.Task!
+Microsoft.AspNetCore.Identity.UI.Services.NoOpEmailSender
+Microsoft.AspNetCore.Identity.UI.Services.NoOpEmailSender.NoOpEmailSender() -> void
+Microsoft.AspNetCore.Identity.UI.Services.NoOpEmailSender.SendEmailAsync(string! email, string! subject, string! htmlMessage) -> System.Threading.Tasks.Task!
 Microsoft.AspNetCore.Routing.IdentityApiEndpointRouteBuilderExtensions
 static Microsoft.AspNetCore.Identity.IdentityBuilderExtensions.AddApiEndpoints(this Microsoft.AspNetCore.Identity.IdentityBuilder! builder) -> Microsoft.AspNetCore.Identity.IdentityBuilder!
 static Microsoft.AspNetCore.Routing.IdentityApiEndpointRouteBuilderExtensions.MapIdentityApi<TUser>(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints) -> Microsoft.AspNetCore.Builder.IEndpointConventionBuilder!

+ 0 - 5
src/Identity/Extensions.Core/src/PublicAPI.Unshipped.txt

@@ -2,11 +2,6 @@
 Microsoft.AspNetCore.Identity.IdentitySchemaVersions
 Microsoft.AspNetCore.Identity.StoreOptions.SchemaVersion.get -> System.Version!
 Microsoft.AspNetCore.Identity.StoreOptions.SchemaVersion.set -> void
-Microsoft.AspNetCore.Identity.UI.Services.IEmailSender
-Microsoft.AspNetCore.Identity.UI.Services.IEmailSender.SendEmailAsync(string! email, string! subject, string! htmlMessage) -> System.Threading.Tasks.Task!
-Microsoft.AspNetCore.Identity.UI.Services.NoOpEmailSender
-Microsoft.AspNetCore.Identity.UI.Services.NoOpEmailSender.NoOpEmailSender() -> void
-Microsoft.AspNetCore.Identity.UI.Services.NoOpEmailSender.SendEmailAsync(string! email, string! subject, string! htmlMessage) -> System.Threading.Tasks.Task!
 static readonly Microsoft.AspNetCore.Identity.IdentitySchemaVersions.Default -> System.Version!
 static readonly Microsoft.AspNetCore.Identity.IdentitySchemaVersions.Version1 -> System.Version!
 static readonly Microsoft.AspNetCore.Identity.IdentitySchemaVersions.Version2 -> System.Version!

+ 3 - 5
src/Identity/UI/src/Areas/Identity/Pages/V4/Account/ExternalLogin.cshtml.cs

@@ -7,7 +7,6 @@ using System.Security.Claims;
 using System.Text;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -95,7 +94,7 @@ internal sealed class ExternalLoginModel<TUser> : ExternalLoginModel where TUser
     private readonly UserManager<TUser> _userManager;
     private readonly IUserStore<TUser> _userStore;
     private readonly IUserEmailStore<TUser> _emailStore;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
     private readonly ILogger<ExternalLoginModel> _logger;
 
     public ExternalLoginModel(
@@ -103,7 +102,7 @@ internal sealed class ExternalLoginModel<TUser> : ExternalLoginModel where TUser
         UserManager<TUser> userManager,
         IUserStore<TUser> userStore,
         ILogger<ExternalLoginModel> logger,
-        IEmailSender emailSender)
+        IEmailSender<TUser> emailSender)
     {
         _signInManager = signInManager;
         _userManager = userManager;
@@ -206,8 +205,7 @@ internal sealed class ExternalLoginModel<TUser> : ExternalLoginModel where TUser
                         values: new { area = "Identity", userId = userId, code = code },
                         protocol: Request.Scheme)!;
 
-                    await _emailSender.SendEmailAsync(Input.Email, "Confirm your email",
-                        $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+                    await _emailSender.SendConfirmationLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
                     // If account confirmation is required, we need to show the link if we don't have a real email sender
                     if (_userManager.Options.SignIn.RequireConfirmedAccount)

+ 3 - 7
src/Identity/UI/src/Areas/Identity/Pages/V4/Account/ForgotPassword.cshtml.cs

@@ -5,7 +5,6 @@ using System.ComponentModel.DataAnnotations;
 using System.Text;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -52,9 +51,9 @@ public abstract class ForgotPasswordModel : PageModel
 internal sealed class ForgotPasswordModel<TUser> : ForgotPasswordModel where TUser : class
 {
     private readonly UserManager<TUser> _userManager;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
 
-    public ForgotPasswordModel(UserManager<TUser> userManager, IEmailSender emailSender)
+    public ForgotPasswordModel(UserManager<TUser> userManager, IEmailSender<TUser> emailSender)
     {
         _userManager = userManager;
         _emailSender = emailSender;
@@ -81,10 +80,7 @@ internal sealed class ForgotPasswordModel<TUser> : ForgotPasswordModel where TUs
                 values: new { area = "Identity", code },
                 protocol: Request.Scheme)!;
 
-            await _emailSender.SendEmailAsync(
-                Input.Email,
-                "Reset Password",
-                $"Please reset your password by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+            await _emailSender.SendPasswordResetLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
             return RedirectToPage("./ForgotPasswordConfirmation");
         }

+ 4 - 11
src/Identity/UI/src/Areas/Identity/Pages/V4/Account/Manage/Email.cshtml.cs

@@ -4,7 +4,6 @@
 using System.ComponentModel.DataAnnotations;
 using System.Text;
 using System.Text.Encodings.Web;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -83,12 +82,12 @@ internal sealed class EmailModel<TUser> : EmailModel where TUser : class
 {
     private readonly UserManager<TUser> _userManager;
     private readonly SignInManager<TUser> _signInManager;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
 
     public EmailModel(
         UserManager<TUser> userManager,
         SignInManager<TUser> signInManager,
-        IEmailSender emailSender)
+        IEmailSender<TUser> emailSender)
     {
         _userManager = userManager;
         _signInManager = signInManager;
@@ -145,10 +144,7 @@ internal sealed class EmailModel<TUser> : EmailModel where TUser : class
                 pageHandler: null,
                 values: new { area = "Identity", userId = userId, email = Input.NewEmail, code = code },
                 protocol: Request.Scheme)!;
-            await _emailSender.SendEmailAsync(
-                Input.NewEmail,
-                "Confirm your email",
-                $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+            await _emailSender.SendConfirmationLinkAsync(user, Input.NewEmail, HtmlEncoder.Default.Encode(callbackUrl));
 
             StatusMessage = "Confirmation link to change email sent. Please check your email.";
             return RedirectToPage();
@@ -181,10 +177,7 @@ internal sealed class EmailModel<TUser> : EmailModel where TUser : class
             pageHandler: null,
             values: new { area = "Identity", userId = userId, code = code },
             protocol: Request.Scheme);
-        await _emailSender.SendEmailAsync(
-            email!,
-            "Confirm your email",
-            $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl!)}'>clicking here</a>.");
+        await _emailSender.SendConfirmationLinkAsync(user, email!, HtmlEncoder.Default.Encode(callbackUrl!));
 
         StatusMessage = "Verification email sent. Please check your email.";
         return RedirectToPage();

+ 3 - 5
src/Identity/UI/src/Areas/Identity/Pages/V4/Account/Register.cshtml.cs

@@ -8,7 +8,6 @@ using System.Text;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Authentication;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -98,14 +97,14 @@ internal sealed class RegisterModel<TUser> : RegisterModel where TUser : class
     private readonly IUserStore<TUser> _userStore;
     private readonly IUserEmailStore<TUser> _emailStore;
     private readonly ILogger<RegisterModel> _logger;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
 
     public RegisterModel(
         UserManager<TUser> userManager,
         IUserStore<TUser> userStore,
         SignInManager<TUser> signInManager,
         ILogger<RegisterModel> logger,
-        IEmailSender emailSender)
+        IEmailSender<TUser> emailSender)
     {
         _userManager = userManager;
         _userStore = userStore;
@@ -146,8 +145,7 @@ internal sealed class RegisterModel<TUser> : RegisterModel where TUser : class
                     values: new { area = "Identity", userId = userId, code = code, returnUrl = returnUrl },
                     protocol: Request.Scheme)!;
 
-                await _emailSender.SendEmailAsync(Input.Email, "Confirm your email",
-                    $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+                await _emailSender.SendConfirmationLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
                 if (_userManager.Options.SignIn.RequireConfirmedAccount)
                 {

+ 3 - 4
src/Identity/UI/src/Areas/Identity/Pages/V4/Account/RegisterConfirmation.cshtml.cs

@@ -3,7 +3,6 @@
 
 using System.Text;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -46,9 +45,9 @@ public class RegisterConfirmationModel : PageModel
 internal sealed class RegisterConfirmationModel<TUser> : RegisterConfirmationModel where TUser : class
 {
     private readonly UserManager<TUser> _userManager;
-    private readonly IEmailSender _sender;
+    private readonly IEmailSender<TUser> _sender;
 
-    public RegisterConfirmationModel(UserManager<TUser> userManager, IEmailSender sender)
+    public RegisterConfirmationModel(UserManager<TUser> userManager, IEmailSender<TUser> sender)
     {
         _userManager = userManager;
         _sender = sender;
@@ -70,7 +69,7 @@ internal sealed class RegisterConfirmationModel<TUser> : RegisterConfirmationMod
 
         Email = email;
         // If the email sender is a no-op, display the confirm link in the page
-        DisplayConfirmAccountLink = _sender is NoOpEmailSender;
+        DisplayConfirmAccountLink = _sender is DefaultMessageEmailSender<TUser> defaultMessageSender && defaultMessageSender.IsNoOp;
         if (DisplayConfirmAccountLink)
         {
             var userId = await _userManager.GetUserIdAsync(user);

+ 3 - 7
src/Identity/UI/src/Areas/Identity/Pages/V4/Account/ResendEmailConfirmation.cshtml.cs

@@ -5,7 +5,6 @@ using System.ComponentModel.DataAnnotations;
 using System.Text;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -58,9 +57,9 @@ public class ResendEmailConfirmationModel : PageModel
 internal sealed class ResendEmailConfirmationModel<TUser> : ResendEmailConfirmationModel where TUser : class
 {
     private readonly UserManager<TUser> _userManager;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
 
-    public ResendEmailConfirmationModel(UserManager<TUser> userManager, IEmailSender emailSender)
+    public ResendEmailConfirmationModel(UserManager<TUser> userManager, IEmailSender<TUser> emailSender)
     {
         _userManager = userManager;
         _emailSender = emailSender;
@@ -92,10 +91,7 @@ internal sealed class ResendEmailConfirmationModel<TUser> : ResendEmailConfirmat
             pageHandler: null,
             values: new { userId = userId, code = code },
             protocol: Request.Scheme)!;
-        await _emailSender.SendEmailAsync(
-            Input.Email,
-            "Confirm your email",
-            $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+        await _emailSender.SendConfirmationLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
         ModelState.AddModelError(string.Empty, "Verification email sent. Please check your email.");
         return Page();

+ 3 - 5
src/Identity/UI/src/Areas/Identity/Pages/V5/Account/ExternalLogin.cshtml.cs

@@ -7,7 +7,6 @@ using System.Security.Claims;
 using System.Text;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -95,7 +94,7 @@ internal sealed class ExternalLoginModel<TUser> : ExternalLoginModel where TUser
     private readonly UserManager<TUser> _userManager;
     private readonly IUserStore<TUser> _userStore;
     private readonly IUserEmailStore<TUser> _emailStore;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
     private readonly ILogger<ExternalLoginModel> _logger;
 
     public ExternalLoginModel(
@@ -103,7 +102,7 @@ internal sealed class ExternalLoginModel<TUser> : ExternalLoginModel where TUser
         UserManager<TUser> userManager,
         IUserStore<TUser> userStore,
         ILogger<ExternalLoginModel> logger,
-        IEmailSender emailSender)
+        IEmailSender<TUser> emailSender)
     {
         _signInManager = signInManager;
         _userManager = userManager;
@@ -206,8 +205,7 @@ internal sealed class ExternalLoginModel<TUser> : ExternalLoginModel where TUser
                         values: new { area = "Identity", userId = userId, code = code },
                         protocol: Request.Scheme)!;
 
-                    await _emailSender.SendEmailAsync(Input.Email, "Confirm your email",
-                        $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+                    await _emailSender.SendConfirmationLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
                     // If account confirmation is required, we need to show the link if we don't have a real email sender
                     if (_userManager.Options.SignIn.RequireConfirmedAccount)

+ 3 - 7
src/Identity/UI/src/Areas/Identity/Pages/V5/Account/ForgotPassword.cshtml.cs

@@ -5,7 +5,6 @@ using System.ComponentModel.DataAnnotations;
 using System.Text;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -52,9 +51,9 @@ public abstract class ForgotPasswordModel : PageModel
 internal sealed class ForgotPasswordModel<TUser> : ForgotPasswordModel where TUser : class
 {
     private readonly UserManager<TUser> _userManager;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
 
-    public ForgotPasswordModel(UserManager<TUser> userManager, IEmailSender emailSender)
+    public ForgotPasswordModel(UserManager<TUser> userManager, IEmailSender<TUser> emailSender)
     {
         _userManager = userManager;
         _emailSender = emailSender;
@@ -81,10 +80,7 @@ internal sealed class ForgotPasswordModel<TUser> : ForgotPasswordModel where TUs
                 values: new { area = "Identity", code },
                 protocol: Request.Scheme)!;
 
-            await _emailSender.SendEmailAsync(
-                Input.Email,
-                "Reset Password",
-                $"Please reset your password by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+            await _emailSender.SendPasswordResetLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
             return RedirectToPage("./ForgotPasswordConfirmation");
         }

+ 4 - 11
src/Identity/UI/src/Areas/Identity/Pages/V5/Account/Manage/Email.cshtml.cs

@@ -4,7 +4,6 @@
 using System.ComponentModel.DataAnnotations;
 using System.Text;
 using System.Text.Encodings.Web;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -83,12 +82,12 @@ internal sealed class EmailModel<TUser> : EmailModel where TUser : class
 {
     private readonly UserManager<TUser> _userManager;
     private readonly SignInManager<TUser> _signInManager;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
 
     public EmailModel(
         UserManager<TUser> userManager,
         SignInManager<TUser> signInManager,
-        IEmailSender emailSender)
+        IEmailSender<TUser> emailSender)
     {
         _userManager = userManager;
         _signInManager = signInManager;
@@ -145,10 +144,7 @@ internal sealed class EmailModel<TUser> : EmailModel where TUser : class
                 pageHandler: null,
                 values: new { area = "Identity", userId = userId, email = Input.NewEmail, code = code },
                 protocol: Request.Scheme)!;
-            await _emailSender.SendEmailAsync(
-                Input.NewEmail,
-                "Confirm your email",
-                $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+            await _emailSender.SendConfirmationLinkAsync(user, Input.NewEmail, HtmlEncoder.Default.Encode(callbackUrl));
 
             StatusMessage = "Confirmation link to change email sent. Please check your email.";
             return RedirectToPage();
@@ -181,10 +177,7 @@ internal sealed class EmailModel<TUser> : EmailModel where TUser : class
             pageHandler: null,
             values: new { area = "Identity", userId = userId, code = code },
             protocol: Request.Scheme)!;
-        await _emailSender.SendEmailAsync(
-            email!,
-            "Confirm your email",
-            $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+        await _emailSender.SendConfirmationLinkAsync(user, email!, HtmlEncoder.Default.Encode(callbackUrl));
 
         StatusMessage = "Verification email sent. Please check your email.";
         return RedirectToPage();

+ 3 - 5
src/Identity/UI/src/Areas/Identity/Pages/V5/Account/Register.cshtml.cs

@@ -8,7 +8,6 @@ using System.Text;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Authentication;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -98,14 +97,14 @@ internal sealed class RegisterModel<TUser> : RegisterModel where TUser : class
     private readonly IUserStore<TUser> _userStore;
     private readonly IUserEmailStore<TUser> _emailStore;
     private readonly ILogger<RegisterModel> _logger;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
 
     public RegisterModel(
         UserManager<TUser> userManager,
         IUserStore<TUser> userStore,
         SignInManager<TUser> signInManager,
         ILogger<RegisterModel> logger,
-        IEmailSender emailSender)
+        IEmailSender<TUser>  emailSender)
     {
         _userManager = userManager;
         _userStore = userStore;
@@ -146,8 +145,7 @@ internal sealed class RegisterModel<TUser> : RegisterModel where TUser : class
                     values: new { area = "Identity", userId = userId, code = code, returnUrl = returnUrl },
                     protocol: Request.Scheme)!;
 
-                await _emailSender.SendEmailAsync(Input.Email, "Confirm your email",
-                    $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+                await _emailSender.SendConfirmationLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
                 if (_userManager.Options.SignIn.RequireConfirmedAccount)
                 {

+ 3 - 4
src/Identity/UI/src/Areas/Identity/Pages/V5/Account/RegisterConfirmation.cshtml.cs

@@ -3,7 +3,6 @@
 
 using System.Text;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -46,9 +45,9 @@ public class RegisterConfirmationModel : PageModel
 internal sealed class RegisterConfirmationModel<TUser> : RegisterConfirmationModel where TUser : class
 {
     private readonly UserManager<TUser> _userManager;
-    private readonly IEmailSender _sender;
+    private readonly IEmailSender<TUser> _sender;
 
-    public RegisterConfirmationModel(UserManager<TUser> userManager, IEmailSender sender)
+    public RegisterConfirmationModel(UserManager<TUser> userManager, IEmailSender<TUser> sender)
     {
         _userManager = userManager;
         _sender = sender;
@@ -70,7 +69,7 @@ internal sealed class RegisterConfirmationModel<TUser> : RegisterConfirmationMod
 
         Email = email;
         // If the email sender is a no-op, display the confirm link in the page
-        DisplayConfirmAccountLink = _sender is NoOpEmailSender;
+        DisplayConfirmAccountLink = _sender is DefaultMessageEmailSender<TUser> defaultMessageSender && defaultMessageSender.IsNoOp;
         if (DisplayConfirmAccountLink)
         {
             var userId = await _userManager.GetUserIdAsync(user);

+ 3 - 7
src/Identity/UI/src/Areas/Identity/Pages/V5/Account/ResendEmailConfirmation.cshtml.cs

@@ -5,7 +5,6 @@ using System.ComponentModel.DataAnnotations;
 using System.Text;
 using System.Text.Encodings.Web;
 using Microsoft.AspNetCore.Authorization;
-using Microsoft.AspNetCore.Identity.UI.Services;
 using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
 using Microsoft.AspNetCore.WebUtilities;
@@ -58,9 +57,9 @@ public class ResendEmailConfirmationModel : PageModel
 internal sealed class ResendEmailConfirmationModel<TUser> : ResendEmailConfirmationModel where TUser : class
 {
     private readonly UserManager<TUser> _userManager;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<TUser> _emailSender;
 
-    public ResendEmailConfirmationModel(UserManager<TUser> userManager, IEmailSender emailSender)
+    public ResendEmailConfirmationModel(UserManager<TUser> userManager, IEmailSender<TUser> emailSender)
     {
         _userManager = userManager;
         _emailSender = emailSender;
@@ -92,10 +91,7 @@ internal sealed class ResendEmailConfirmationModel<TUser> : ResendEmailConfirmat
             pageHandler: null,
             values: new { userId = userId, code = code },
             protocol: Request.Scheme)!;
-        await _emailSender.SendEmailAsync(
-            Input.Email,
-            "Confirm your email",
-            $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+        await _emailSender.SendConfirmationLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
         ModelState.AddModelError(string.Empty, "Verification email sent. Please check your email.");
         return Page();

+ 1 - 0
src/Identity/UI/src/IdentityBuilderUIExtensions.cs

@@ -60,6 +60,7 @@ public static class IdentityBuilderUIExtensions
             typeof(IdentityDefaultUIConfigureOptions<>)
                 .MakeGenericType(builder.UserType));
         builder.Services.TryAddTransient<IEmailSender, NoOpEmailSender>();
+        builder.Services.TryAddTransient(typeof(IEmailSender<>), typeof(DefaultMessageEmailSender<>));
 
         return builder;
     }

+ 4 - 0
src/Identity/UI/src/Microsoft.AspNetCore.Identity.UI.csproj

@@ -20,6 +20,10 @@
     <StaticWebAssetsGetBuildAssetsTargets>GetIdentityUIAssets</StaticWebAssetsGetBuildAssetsTargets>
   </PropertyGroup>
 
+  <ItemGroup>
+    <Compile Include="$(SharedSourceRoot)DefaultMessageEmailSender.cs" />
+  </ItemGroup>
+
   <ItemGroup>
     <None Include="@(Content)" />
     <Content Remove="@(Content)" />

+ 2 - 2
src/Identity/UI/src/PublicAPI.Unshipped.txt

@@ -1,5 +1,5 @@
 #nullable enable
 *REMOVED*Microsoft.AspNetCore.Identity.UI.Services.IEmailSender
 *REMOVED*Microsoft.AspNetCore.Identity.UI.Services.IEmailSender.SendEmailAsync(string! email, string! subject, string! htmlMessage) -> System.Threading.Tasks.Task!
-Microsoft.AspNetCore.Identity.UI.Services.IEmailSender (forwarded, contained in Microsoft.Extensions.Identity.Core)
-Microsoft.AspNetCore.Identity.UI.Services.IEmailSender.SendEmailAsync(string! email, string! subject, string! htmlMessage) -> System.Threading.Tasks.Task! (forwarded, contained in Microsoft.Extensions.Identity.Core)
+Microsoft.AspNetCore.Identity.UI.Services.IEmailSender (forwarded, contained in Microsoft.AspNetCore.Identity)
+Microsoft.AspNetCore.Identity.UI.Services.IEmailSender.SendEmailAsync(string! email, string! subject, string! htmlMessage) -> System.Threading.Tasks.Task! (forwarded, contained in Microsoft.AspNetCore.Identity)

+ 3 - 4
src/Identity/samples/IdentitySample.DefaultUI/Areas/Identity/Pages/Account/Register.cshtml.cs

@@ -17,13 +17,13 @@ public class RegisterModel : PageModel
     private readonly SignInManager<ApplicationUser> _signInManager;
     private readonly UserManager<ApplicationUser> _userManager;
     private readonly ILogger<RegisterModel> _logger;
-    private readonly IEmailSender _emailSender;
+    private readonly IEmailSender<ApplicationUser> _emailSender;
 
     public RegisterModel(
         UserManager<ApplicationUser> userManager,
         SignInManager<ApplicationUser> signInManager,
         ILogger<RegisterModel> logger,
-        IEmailSender emailSender)
+        IEmailSender<ApplicationUser> emailSender)
     {
         _userManager = userManager;
         _signInManager = signInManager;
@@ -99,8 +99,7 @@ public class RegisterModel : PageModel
                     values: new { userId = user.Id, code = code },
                     protocol: Request.Scheme);
 
-                await _emailSender.SendEmailAsync(Input.Email, "Confirm your email",
-                    $"Please confirm your account by <a href='{HtmlEncoder.Default.Encode(callbackUrl)}'>clicking here</a>.");
+                await _emailSender.SendConfirmationLinkAsync(user, Input.Email, HtmlEncoder.Default.Encode(callbackUrl));
 
                 if (_userManager.Options.SignIn.RequireConfirmedAccount)
                 {

+ 40 - 0
src/Identity/test/Identity.FunctionalTests/MapIdentityApiTests.cs

@@ -562,6 +562,27 @@ public class MapIdentityApiTests : LoggedTest
         AssertOk(await client.PostAsJsonAsync("/identity/login", new { Email, Password }));
     }
 
+    [Fact]
+    public async Task AccountConfirmationEmailCanBeCustomized()
+    {
+        var emailSender = new TestEmailSender();
+        var customEmailSender = new TestCustomEmailSender(emailSender);
+
+        await using var app = await CreateAppAsync(services =>
+        {
+            AddIdentityApiEndpoints(services);
+            services.AddSingleton<IEmailSender<ApplicationUser>>(customEmailSender);
+        });
+        using var client = app.GetTestClient();
+
+        await RegisterAsync(client);
+
+        var email = Assert.Single(emailSender.Emails);
+        Assert.Equal(Email, email.Address);
+        Assert.Equal(TestCustomEmailSender.CustomSubject, email.Subject);
+        Assert.Equal(TestCustomEmailSender.CustomMessage, email.HtmlMessage);
+    }
+
     [Fact]
     public async Task CanAddEndpointsToMultipleRouteGroupsForSameUserType()
     {
@@ -1509,5 +1530,24 @@ public class MapIdentityApiTests : LoggedTest
         }
     }
 
+    private sealed class TestCustomEmailSender(IEmailSender emailSender) : IEmailSender<ApplicationUser>
+    {
+        public const string CustomSubject = "Custom subject";
+        public const string CustomMessage = "Custom message";
+
+        public Task SendConfirmationLinkAsync(ApplicationUser user, string email, string confirmationLink)
+        {
+            Assert.Equal(user.Email, email);
+            emailSender.SendEmailAsync(email, "Custom subject", "Custom message");
+            return Task.CompletedTask;
+        }
+
+        public Task SendPasswordResetCodeAsync(ApplicationUser user, string email, string resetCode) =>
+            throw new NotImplementedException();
+
+        public Task SendPasswordResetLinkAsync(ApplicationUser user, string email, string resetLink) =>
+            throw new NotImplementedException();
+    }
+
     private sealed record TestEmail(string Address, string Subject, string HtmlMessage);
 }

+ 79 - 8
src/Mvc/Mvc.TagHelpers/test/PersistComponentStateTagHelperTest.cs

@@ -19,6 +19,7 @@ using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging.Abstractions;
 using Moq;
+using RenderMode = Microsoft.AspNetCore.Components.Web.RenderMode;
 
 namespace Microsoft.AspNetCore.Mvc.TagHelpers;
 
@@ -50,6 +51,28 @@ public class PersistComponentStateTagHelperTest
         Assert.Null(output.TagName);
     }
 
+    [Fact]
+    public async Task ExecuteAsync_DoesNotRenderWebAssemblyStateWhenStateWasNotPersisted()
+    {
+        // Arrange
+        var tagHelper = new PersistComponentStateTagHelper
+        {
+            ViewContext = GetViewContext(),
+            PersistenceMode = PersistenceMode.WebAssembly
+        };
+
+        var context = GetTagHelperContext();
+        var output = GetTagHelperOutput();
+
+        // Act
+        await tagHelper.ProcessAsync(context, output);
+
+        // Assert
+        var content = HtmlContentUtilities.HtmlContentToString(output.Content);
+        Assert.Empty(content);
+        Assert.Null(output.TagName);
+    }
+
     [Fact]
     public async Task ExecuteAsync_RendersWebAssemblyStateExplicitly()
     {
@@ -62,14 +85,21 @@ public class PersistComponentStateTagHelperTest
 
         var context = GetTagHelperContext();
         var output = GetTagHelperOutput();
+        var manager = tagHelper.ViewContext.HttpContext.RequestServices.GetRequiredService<ComponentStatePersistenceManager>();
 
         // Act
+        manager.State.RegisterOnPersisting(() =>
+        {
+            manager.State.PersistAsJson("state", "state value");
+            return Task.CompletedTask;
+        }, RenderMode.InteractiveWebAssembly);
         await tagHelper.ProcessAsync(context, output);
 
         // Assert
         var content = HtmlContentUtilities.HtmlContentToString(output.Content);
-        Assert.Equal("<!--Blazor-Component-State:e30=-->", content);
         Assert.Null(output.TagName);
+        var message = content["<!--Blazor-WebAssembly-Component-State:".Length..^"-->".Length];
+        Assert.True(message.Length > 0);
     }
 
     [Fact]
@@ -81,18 +111,25 @@ public class PersistComponentStateTagHelperTest
             ViewContext = GetViewContext()
         };
 
-        EndpointHtmlRenderer.UpdateSaveStateRenderMode(tagHelper.ViewContext.HttpContext, Components.Web.RenderMode.InteractiveWebAssembly);
+        EndpointHtmlRenderer.UpdateSaveStateRenderMode(tagHelper.ViewContext.HttpContext, RenderMode.InteractiveWebAssembly);
 
         var context = GetTagHelperContext();
         var output = GetTagHelperOutput();
+        var manager = tagHelper.ViewContext.HttpContext.RequestServices.GetRequiredService<ComponentStatePersistenceManager>();
 
         // Act
+        manager.State.RegisterOnPersisting(() =>
+        {
+            manager.State.PersistAsJson("state", "state value");
+            return Task.CompletedTask;
+        }, RenderMode.InteractiveWebAssembly);
         await tagHelper.ProcessAsync(context, output);
 
         // Assert
         var content = HtmlContentUtilities.HtmlContentToString(output.Content);
-        Assert.Equal("<!--Blazor-Component-State:e30=-->", content);
         Assert.Null(output.TagName);
+        var message = content["<!--Blazor-WebAssembly-Component-State:".Length..^"-->".Length];
+        Assert.True(message.Length > 0);
     }
 
     [Fact]
@@ -107,17 +144,44 @@ public class PersistComponentStateTagHelperTest
 
         var context = GetTagHelperContext();
         var output = GetTagHelperOutput();
+        var manager = tagHelper.ViewContext.HttpContext.RequestServices.GetRequiredService<ComponentStatePersistenceManager>();
 
         // Act
+        manager.State.RegisterOnPersisting(() =>
+        {
+            manager.State.PersistAsJson("state", "state value");
+            return Task.CompletedTask;
+        }, RenderMode.InteractiveServer);
+
         await tagHelper.ProcessAsync(context, output);
 
         // Assert
         var content = HtmlContentUtilities.HtmlContentToString(output.Content);
         Assert.NotEmpty(content);
-        var payload = content["<!--Blazor-Component-State:".Length..^"-->".Length];
+        var payload = content["<!--Blazor-Server-Component-State:".Length..^"-->".Length];
         var message = _protector.Unprotect(payload);
-        Assert.Equal("{}", message);
-        Assert.Null(output.TagName);
+        Assert.True(message.Length > 0);
+    }
+
+    [Fact]
+    public async Task ExecuteAsync_DoesNotRenderServerStateWhenStateWasNotPersisted()
+    {
+        // Arrange
+        var tagHelper = new PersistComponentStateTagHelper
+        {
+            ViewContext = GetViewContext(),
+            PersistenceMode = PersistenceMode.Server
+        };
+
+        var context = GetTagHelperContext();
+        var output = GetTagHelperOutput();
+
+        // Act
+        await tagHelper.ProcessAsync(context, output);
+
+        // Assert
+        var content = HtmlContentUtilities.HtmlContentToString(output.Content);
+        Assert.Empty(content);
     }
 
     [Fact]
@@ -133,16 +197,23 @@ public class PersistComponentStateTagHelperTest
 
         var context = GetTagHelperContext();
         var output = GetTagHelperOutput();
+        var manager = tagHelper.ViewContext.HttpContext.RequestServices.GetRequiredService<ComponentStatePersistenceManager>();
 
         // Act
+        manager.State.RegisterOnPersisting(() =>
+        {
+            manager.State.PersistAsJson("state", "state value");
+            return Task.CompletedTask;
+        }, RenderMode.InteractiveServer);
+
         await tagHelper.ProcessAsync(context, output);
 
         // Assert
         var content = HtmlContentUtilities.HtmlContentToString(output.Content);
         Assert.NotEmpty(content);
-        var payload = content["<!--Blazor-Component-State:".Length..^"-->".Length];
+        var payload = content["<!--Blazor-Server-Component-State:".Length..^"-->".Length];
         var message = _protector.Unprotect(payload);
-        Assert.Equal("{}", message);
+        Assert.True(message.Length > 0);
     }
 
     [Fact]

+ 18 - 1
src/Shared/Components/PrerenderComponentApplicationStore.cs

@@ -3,6 +3,7 @@
 
 using System.Diagnostics.CodeAnalysis;
 using System.Text.Json;
+using Microsoft.AspNetCore.Components.Web;
 
 namespace Microsoft.AspNetCore.Components;
 
@@ -10,6 +11,8 @@ namespace Microsoft.AspNetCore.Components;
 internal class PrerenderComponentApplicationStore : IPersistentComponentStateStore
 #pragma warning restore CA1852 // Seal internal types
 {
+    private bool _stateIsPersisted;
+
     public PrerenderComponentApplicationStore()
     {
         ExistingState = new();
@@ -52,7 +55,21 @@ internal class PrerenderComponentApplicationStore : IPersistentComponentStateSto
 
     public Task PersistStateAsync(IReadOnlyDictionary<string, byte[]> state)
     {
-        PersistedState = Convert.ToBase64String(SerializeState(state));
+        if (_stateIsPersisted)
+        {
+            throw new InvalidOperationException("State already persisted.");
+        }
+
+        _stateIsPersisted = true;
+
+        if (state is not null && state.Count > 0)
+        {
+            PersistedState = Convert.ToBase64String(SerializeState(state));
+        }
+
         return Task.CompletedTask;
     }
+
+    public virtual bool SupportsRenderMode(IComponentRenderMode renderMode) =>
+        renderMode is null || renderMode is InteractiveWebAssemblyRenderMode || renderMode is InteractiveAutoRenderMode;
 }

+ 5 - 0
src/Shared/Components/ProtectedPrerenderComponentApplicationStore.cs

@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using Microsoft.AspNetCore.Components.Web;
 using Microsoft.AspNetCore.DataProtection;
 
 namespace Microsoft.AspNetCore.Components;
@@ -28,4 +29,8 @@ internal sealed class ProtectedPrerenderComponentApplicationStore : PrerenderCom
 
     private void CreateProtector(IDataProtectionProvider dataProtectionProvider) =>
         _protector = dataProtectionProvider.CreateProtector("Microsoft.AspNetCore.Components.Server.State");
+
+    public override bool SupportsRenderMode(IComponentRenderMode renderMode) =>
+        renderMode is null ||
+        renderMode is InteractiveServerRenderMode || renderMode is InteractiveAutoRenderMode;
 }

+ 20 - 0
src/Shared/DefaultMessageEmailSender.cs

@@ -0,0 +1,20 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.AspNetCore.Identity.UI.Services;
+
+namespace Microsoft.AspNetCore.Identity;
+
+internal sealed class DefaultMessageEmailSender<TUser>(IEmailSender emailSender) : IEmailSender<TUser> where TUser : class
+{
+    internal bool IsNoOp => emailSender is NoOpEmailSender;
+
+    public Task SendConfirmationLinkAsync(TUser user, string email, string confirmationLink) =>
+        emailSender.SendEmailAsync(email, "Confirm your email", $"Please confirm your account by <a href='{confirmationLink}'>clicking here</a>.");
+
+    public Task SendPasswordResetLinkAsync(TUser user, string email, string resetLink) =>
+        emailSender.SendEmailAsync(email, "Reset your password", $"Please reset your password by <a href='{resetLink}'>clicking here</a>.");
+
+    public Task SendPasswordResetCodeAsync(TUser user, string email, string resetCode) =>
+        emailSender.SendEmailAsync(email, "Reset your password", $"Please reset your password using the following code: {resetCode}");
+}

+ 1 - 1
src/Shared/E2ETesting/BrowserFixture.cs

@@ -142,7 +142,7 @@ public class BrowserFixture : IAsyncLifetime
     {
         var opts = new ChromeOptions();
 
-        if (string.Equals(context, StreamingContext, StringComparison.Ordinal))
+        if (context?.StartsWith(StreamingContext, StringComparison.Ordinal) == true)
         {
             // Tells Selenium not to wait until the page navigation has completed before continuing with the tests
             opts.PageLoadStrategy = PageLoadStrategy.None;

この差分においてかなりの量のファイルが変更されているため、一部のファイルを表示していません