瀏覽代碼

SignOut 2fa remember me cookie when validation fails (#38262)

Hao Kung 4 年之前
父節點
當前提交
6d3d274df3

+ 8 - 1
src/Identity/Core/src/IdentityCookiesBuilderExtensions.cs

@@ -81,7 +81,14 @@ public static class IdentityCookieAuthenticationBuilderExtensions
     /// <returns>The <see cref="OptionsBuilder{TOptions}"/> which can be used to configure the cookie authentication.</returns>
     public static OptionsBuilder<CookieAuthenticationOptions> AddTwoFactorRememberMeCookie(this AuthenticationBuilder builder)
     {
-        builder.AddCookie(IdentityConstants.TwoFactorRememberMeScheme, o => o.Cookie.Name = IdentityConstants.TwoFactorRememberMeScheme);
+        builder.AddCookie(IdentityConstants.TwoFactorRememberMeScheme, o =>
+        {
+            o.Cookie.Name = IdentityConstants.TwoFactorRememberMeScheme;
+            o.Events = new CookieAuthenticationEvents
+            {
+                OnValidatePrincipal = SecurityStampValidator.ValidateAsync<ITwoFactorSecurityStampValidator>
+            };
+        });
         return new OptionsBuilder<CookieAuthenticationOptions>(builder.Services, IdentityConstants.TwoFactorRememberMeScheme);
     }
 

+ 1 - 0
src/Identity/Core/src/SecurityStampValidator.cs

@@ -142,6 +142,7 @@ public class SecurityStampValidator<TUser> : ISecurityStampValidator where TUser
                 Logger.LogDebug(EventIds.SecurityStampValidationFailed, "Security stamp validation failed, rejecting cookie.");
                 context.RejectPrincipal();
                 await SignInManager.SignOutAsync();
+                await SignInManager.Context.SignOutAsync(IdentityConstants.TwoFactorRememberMeScheme);
             }
         }
     }

+ 9 - 1
src/Identity/test/Identity.Test/SecurityStampValidatorTest.cs

@@ -115,10 +115,13 @@ public class SecurityStampTest
             signInManager.Setup(s => s.CreateUserPrincipalAsync(user)).ReturnsAsync(principal).Verifiable();
         }
 
+        var authService = new Mock<IAuthenticationService>();
+        authService.Setup(c => c.SignOutAsync(httpContext.Object, IdentityConstants.TwoFactorRememberMeScheme, /*properties*/null)).Returns(Task.CompletedTask).Verifiable();
         var services = new ServiceCollection();
         services.AddSingleton(options.Object);
         services.AddSingleton(signInManager.Object);
         services.AddSingleton<ISecurityStampValidator>(new SecurityStampValidator<PocoUser>(options.Object, signInManager.Object, new SystemClock(), new LoggerFactory()));
+        services.AddSingleton(authService.Object);
         httpContext.Setup(c => c.RequestServices).Returns(services.BuildServiceProvider());
 
         await testCode.Invoke();
@@ -154,7 +157,6 @@ public class SecurityStampTest
         var user = new PocoUser("test");
         var httpContext = new Mock<HttpContext>();
 
-
         var userManager = MockHelpers.MockUserManager<PocoUser>();
 
         var claimsManager = new Mock<IUserClaimsPrincipalFactory<PocoUser>>();
@@ -208,10 +210,13 @@ public class SecurityStampTest
         var signInManager = new Mock<SignInManager<PocoUser>>(userManager.Object,
             contextAccessor.Object, claimsManager.Object, identityOptions.Object, null, new Mock<IAuthenticationSchemeProvider>().Object, new DefaultUserConfirmation<PocoUser>());
         signInManager.Setup(s => s.ValidateSecurityStampAsync(It.IsAny<ClaimsPrincipal>())).ReturnsAsync(default(PocoUser)).Verifiable();
+        var authService = new Mock<IAuthenticationService>();
+        authService.Setup(c => c.SignOutAsync(httpContext.Object, IdentityConstants.TwoFactorRememberMeScheme, /*properties*/null)).Returns(Task.CompletedTask).Verifiable();
         var services = new ServiceCollection();
         services.AddSingleton(options.Object);
         services.AddSingleton(signInManager.Object);
         services.AddSingleton<ISecurityStampValidator>(new SecurityStampValidator<PocoUser>(options.Object, signInManager.Object, new SystemClock(), new LoggerFactory()));
+        services.AddSingleton(authService.Object);
         httpContext.Setup(c => c.RequestServices).Returns(services.BuildServiceProvider());
         var id = new ClaimsIdentity(IdentityConstants.ApplicationScheme);
         id.AddClaim(new Claim(ClaimTypes.NameIdentifier, user.Id));
@@ -332,10 +337,13 @@ public class SecurityStampTest
             contextAccessor.Object, claimsManager.Object, identityOptions.Object, null, new Mock<IAuthenticationSchemeProvider>().Object, new DefaultUserConfirmation<PocoUser>());
         signInManager.Setup(s => s.ValidateTwoFactorSecurityStampAsync(It.IsAny<ClaimsPrincipal>())).ReturnsAsync(shouldStampValidate ? user : default).Verifiable();
 
+        var authService = new Mock<IAuthenticationService>();
+        authService.Setup(c => c.SignOutAsync(httpContext.Object, IdentityConstants.TwoFactorRememberMeScheme, /*properties*/null)).Returns(Task.CompletedTask).Verifiable();
         var services = new ServiceCollection();
         services.AddSingleton(options.Object);
         services.AddSingleton(signInManager.Object);
         services.AddSingleton<ITwoFactorSecurityStampValidator>(new TwoFactorSecurityStampValidator<PocoUser>(options.Object, signInManager.Object, new SystemClock(), new LoggerFactory()));
+        services.AddSingleton(authService.Object);
         httpContext.Setup(c => c.RequestServices).Returns(services.BuildServiceProvider());
 
         var principal = await signInManager.Object.StoreRememberClient(user);

+ 26 - 32
src/Identity/test/InMemory.Test/FunctionalTest.cs

@@ -1,14 +1,10 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System;
-using System.IO;
-using System.Linq;
 using System.Net;
 using System.Net.Http;
 using System.Security.Claims;
 using System.Text;
-using System.Threading.Tasks;
 using System.Xml;
 using System.Xml.Linq;
 using Microsoft.AspNetCore.Authentication;
@@ -20,7 +16,6 @@ using Microsoft.AspNetCore.TestHost;
 using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Hosting;
 using Microsoft.Net.Http.Headers;
-using Xunit;
 
 namespace Microsoft.AspNetCore.Identity.InMemory;
 
@@ -67,8 +62,10 @@ public class FunctionalTest
         Assert.Null(transaction3.SetCookie);
     }
 
-    [Fact]
-    public async Task CanCreateMeLoginAndCookieStopsWorkingAfterExpiration()
+    [Theory]
+    [InlineData(true)]
+    [InlineData(false)]
+    public async Task CanCreateMeLoginAndCookieStopsWorkingAfterExpiration(bool testCore)
     {
         var clock = new TestClock();
         var server = await CreateServer(services =>
@@ -79,7 +76,7 @@ public class FunctionalTest
                 options.SlidingExpiration = false;
             });
             services.AddSingleton<ISystemClock>(clock);
-        });
+        }, testCore: testCore);
 
         var transaction1 = await SendAsync(server, "http://example.com/createMe");
         Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode);
@@ -108,12 +105,14 @@ public class FunctionalTest
     }
 
     [Theory]
-    [InlineData(true)]
-    [InlineData(false)]
-    public async Task CanCreateMeLoginAndSecurityStampExtendsExpiration(bool rememberMe)
+    [InlineData(true, true)]
+    [InlineData(true, false)]
+    [InlineData(false, true)]
+    [InlineData(false, false)]
+    public async Task CanCreateMeLoginAndSecurityStampExtendsExpiration(bool rememberMe, bool testCore)
     {
         var clock = new TestClock();
-        var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock));
+        var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock), testCore: testCore);
 
         var transaction1 = await SendAsync(server, "http://example.com/createMe");
         Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode);
@@ -153,8 +152,10 @@ public class FunctionalTest
         Assert.Equal("hao", FindClaimValue(transaction6, ClaimTypes.Name));
     }
 
-    [Fact]
-    public async Task CanAccessOldPrincipalDuringSecurityStampReplacement()
+    [Theory]
+    [InlineData(true)]
+    [InlineData(false)]
+    public async Task CanAccessOldPrincipalDuringSecurityStampReplacement(bool testCore)
     {
         var clock = new TestClock();
         var server = await CreateServer(services =>
@@ -170,7 +171,7 @@ public class FunctionalTest
                 };
             });
             services.AddSingleton<ISystemClock>(clock);
-        });
+        }, testCore: testCore);
 
         var transaction1 = await SendAsync(server, "http://example.com/createMe");
         Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode);
@@ -204,11 +205,13 @@ public class FunctionalTest
         Assert.Equal("hao", FindClaimValue(transaction6, ClaimTypes.Name));
     }
 
-    [Fact]
-    public async Task TwoFactorRememberCookieVerification()
+    [Theory]
+    [InlineData(true)]
+    [InlineData(false)]
+    public async Task TwoFactorRememberCookieVerification(bool testCore)
     {
         var clock = new TestClock();
-        var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock));
+        var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock), testCore: testCore);
 
         var transaction1 = await SendAsync(server, "http://example.com/createMe");
         Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode);
@@ -231,11 +234,13 @@ public class FunctionalTest
         Assert.Equal(HttpStatusCode.OK, transaction4.Response.StatusCode);
     }
 
-    [Fact]
-    public async Task TwoFactorRememberCookieClearedBySecurityStampChange()
+    [Theory]
+    [InlineData(true)]
+    [InlineData(false)]
+    public async Task TwoFactorRememberCookieClearedBySecurityStampChange(bool testCore)
     {
         var clock = new TestClock();
-        var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock));
+        var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock), testCore: testCore);
 
         var transaction1 = await SendAsync(server, "http://example.com/createMe");
         Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode);
@@ -275,17 +280,6 @@ public class FunctionalTest
         return claim.Attribute("value").Value;
     }
 
-    private static async Task<XElement> GetAuthData(TestServer server, string url, string cookie)
-    {
-        var request = new HttpRequestMessage(HttpMethod.Get, url);
-        request.Headers.Add("Cookie", cookie);
-
-        var response2 = await server.CreateClient().SendAsync(request);
-        var text = await response2.Content.ReadAsStringAsync();
-        var me = XElement.Parse(text);
-        return me;
-    }
-
     private static async Task<TestServer> CreateServer(Action<IServiceCollection> configureServices = null, Func<HttpContext, Task> testpath = null, Uri baseAddress = null, bool testCore = false)
     {
         var host = new HostBuilder()