소스 검색

Add a startup filter which initializes the key ring before the server starts

Nate McMaster 8 년 전
부모
커밋
fe83e69b1a

+ 2 - 0
src/Microsoft.AspNetCore.DataProtection/DataProtectionServiceCollectionExtensions.cs

@@ -9,6 +9,7 @@ using Microsoft.AspNetCore.DataProtection.Internal;
 using Microsoft.AspNetCore.DataProtection.KeyManagement;
 using Microsoft.AspNetCore.DataProtection.KeyManagement.Internal;
 using Microsoft.AspNetCore.DataProtection.XmlEncryption;
+using Microsoft.AspNetCore.Hosting;
 using Microsoft.Extensions.DependencyInjection.Extensions;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Options;
@@ -77,6 +78,7 @@ namespace Microsoft.Extensions.DependencyInjection
 
             services.TryAddSingleton<IKeyManager, XmlKeyManager>();
             services.TryAddSingleton<IApplicationDiscriminator, HostingApplicationDiscriminator>();
+            services.TryAddEnumerable(ServiceDescriptor.Singleton<IStartupFilter, DataProtectionStartupFilter>());
 
             // Internal services
             services.TryAddSingleton<IDefaultKeyResolver, DefaultKeyResolver>();

+ 43 - 0
src/Microsoft.AspNetCore.DataProtection/Internal/DataProtectionStartupFilter.cs

@@ -0,0 +1,43 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.DataProtection.KeyManagement.Internal;
+using Microsoft.AspNetCore.Hosting;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.AspNetCore.DataProtection.Internal
+{
+    internal class DataProtectionStartupFilter : IStartupFilter
+    {
+        private readonly IKeyRingProvider _keyRingProvider;
+        private readonly ILogger<DataProtectionStartupFilter> _logger;
+
+        public DataProtectionStartupFilter(IKeyRingProvider keyRingProvider, ILoggerFactory loggerFactory)
+        {
+            _keyRingProvider = keyRingProvider;
+            _logger = loggerFactory.CreateLogger<DataProtectionStartupFilter>();
+        }
+
+        public Action<IApplicationBuilder> Configure(Action<IApplicationBuilder> next)
+        {
+            try
+            {
+                // It doesn't look like much, but this preloads the key ring,
+                // which in turn may load data from remote stores like Redis or Azure.
+                var keyRing = _keyRingProvider.GetCurrentKeyRing();
+
+                _logger.KeyRingWasLoadedOnStartup(keyRing.DefaultKeyId);
+            }
+            catch (Exception ex)
+            {
+                // This should be non-fatal, so swallow, log, and allow server startup to continue.
+                // The KeyRingProvider may be able to try again on the first request.
+                _logger.KeyRingFailedToLoadOnStartup(ex);
+            }
+
+            return next;
+        }
+    }
+}

+ 23 - 1
src/Microsoft.AspNetCore.DataProtection/LoggingExtensions.cs

@@ -129,6 +129,10 @@ namespace Microsoft.Extensions.Logging
 
         private static Action<ILogger, Exception> _policyResolutionStatesThatANewKeyShouldBeAddedToTheKeyRing;
 
+        private static Action<ILogger, Guid, Exception> _keyRingWasLoadedOnStartup;
+
+        private static Action<ILogger, Exception> _keyRingFailedToLoadOnStartup;
+
         private static Action<ILogger, Exception> _usingEphemeralKeyRepository;
 
         private static Action<ILogger, string, Exception> _usingRegistryAsKeyRepositoryWithDPAPI;
@@ -388,6 +392,14 @@ namespace Microsoft.Extensions.Logging
             _usingAzureAsKeyRepository = LoggerMessage.Define<string>(eventId: 0,
                 logLevel: LogLevel.Information,
                 formatString: "Azure Web Sites environment detected. Using '{FullName}' as key repository; keys will not be encrypted at rest.");
+            _keyRingWasLoadedOnStartup = LoggerMessage.Define<Guid>(
+                eventId: 0,
+                logLevel: LogLevel.Debug,
+                formatString: "Key ring with default key {KeyId:B} was loaded during application startup.");
+            _keyRingFailedToLoadOnStartup = LoggerMessage.Define(
+                eventId: 0,
+                logLevel: LogLevel.Information,
+                formatString: "Key ring failed to load during application startup.");
         }
 
         /// <summary>
@@ -760,5 +772,15 @@ namespace Microsoft.Extensions.Logging
         {
             _usingAzureAsKeyRepository(logger, fullName, null);
         }
+
+        public static void KeyRingWasLoadedOnStartup(this ILogger logger, Guid defaultKeyId)
+        {
+            _keyRingWasLoadedOnStartup(logger, defaultKeyId, null);
+        }
+
+        public static void KeyRingFailedToLoadOnStartup(this ILogger logger, Exception innerException)
+        {
+            _keyRingFailedToLoadOnStartup(logger, innerException);
+        }
     }
-}
+}

+ 104 - 0
test/Microsoft.AspNetCore.DataProtection.Test/HostingTests.cs

@@ -0,0 +1,104 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.DataProtection.KeyManagement.Internal;
+using Microsoft.AspNetCore.Hosting;
+using Microsoft.AspNetCore.Hosting.Server;
+using Microsoft.AspNetCore.Http.Features;
+using Microsoft.AspNetCore.Testing;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.DependencyInjection.Extensions;
+using Moq;
+using Xunit;
+
+namespace Microsoft.AspNetCore.DataProtection.Test
+{
+    public class HostingTests
+    {
+        [Fact]
+        public async Task LoadsKeyRingBeforeServerStarts()
+        {
+            var tcs = new TaskCompletionSource<object>();
+            var mockKeyRing = new Mock<IKeyRingProvider>();
+            mockKeyRing.Setup(m => m.GetCurrentKeyRing())
+                .Returns(Mock.Of<IKeyRing>())
+                .Callback(() => tcs.TrySetResult(null));
+
+            var builder = new WebHostBuilder()
+                .UseStartup<TestStartup>()
+                .ConfigureServices(s =>
+                    s.AddDataProtection()
+                    .Services
+                    .Replace(ServiceDescriptor.Singleton(mockKeyRing.Object))
+                    .AddSingleton<IServer>(
+                        new FakeServer(onStart: () => tcs.TrySetException(new InvalidOperationException("Server was started before key ring was initialized")))));
+
+            using (var host = builder.Build())
+            {
+                await host.StartAsync();
+            }
+
+            await tcs.Task.TimeoutAfter(TimeSpan.FromSeconds(10));
+            mockKeyRing.VerifyAll();
+        }
+
+        [Fact]
+        public async Task StartupContinuesOnFailureToLoadKey()
+        {
+            var mockKeyRing = new Mock<IKeyRingProvider>();
+            mockKeyRing.Setup(m => m.GetCurrentKeyRing())
+                .Throws(new NotSupportedException("This mock doesn't actually work, but shouldn't kill the server"))
+                .Verifiable();
+
+            var builder = new WebHostBuilder()
+                .UseStartup<TestStartup>()
+                .ConfigureServices(s =>
+                    s.AddDataProtection()
+                    .Services
+                    .Replace(ServiceDescriptor.Singleton(mockKeyRing.Object))
+                    .AddSingleton(Mock.Of<IServer>()));
+
+            using (var host = builder.Build())
+            {
+                await host.StartAsync();
+            }
+            
+            mockKeyRing.VerifyAll();
+        }
+
+        private class TestStartup
+        {
+            public void Configure(IApplicationBuilder app)
+            {
+            }
+        }
+
+        public class FakeServer : IServer
+        {
+            private readonly Action _onStart;
+
+            public FakeServer(Action onStart)
+            {
+                _onStart = onStart;
+            }
+
+            public IFeatureCollection Features => new FeatureCollection();
+
+            public Task StartAsync<TContext>(IHttpApplication<TContext> application, CancellationToken cancellationToken)
+            {
+                _onStart();
+                return Task.CompletedTask;
+            }
+
+            public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask;
+
+            public void Dispose()
+            {
+            }
+        }
+    }
+}

+ 1 - 0
test/Microsoft.AspNetCore.DataProtection.Test/Microsoft.AspNetCore.DataProtection.Test.csproj

@@ -17,6 +17,7 @@
   </ItemGroup>
 
   <ItemGroup>
+    <PackageReference Include="Microsoft.AspNetCore.Hosting" Version="$(AspNetCoreVersion)" />
     <PackageReference Include="Microsoft.AspNetCore.Testing" Version="$(AspNetCoreVersion)" />
     <PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="$(AspNetCoreVersion)" />
     <PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(TestSdkVersion)" />