Просмотр исходного кода

Add a feature for accessing the AuthenticateResult (#33408)

Brennan 4 лет назад
Родитель
Сommit
dbf84eaa5a

+ 19 - 0
src/Http/Authentication.Abstractions/src/IAuthenticateResultFeature.cs

@@ -0,0 +1,19 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using Microsoft.AspNetCore.Http.Features.Authentication;
+
+namespace Microsoft.AspNetCore.Authentication
+{
+    /// <summary>
+    /// Used to capture the <see cref="AuthenticateResult"/> from the authorization middleware.
+    /// </summary>
+    public interface IAuthenticateResultFeature
+    {
+        /// <summary>
+        /// The <see cref="AuthenticateResult"/> from the authorization middleware.
+        /// Set to null if the <see cref="IHttpAuthenticationFeature.User"/> property is set after the authorization middleware.
+        /// </summary>
+        AuthenticateResult? AuthenticateResult { get; set; }
+    }
+}

+ 3 - 0
src/Http/Authentication.Abstractions/src/PublicAPI.Unshipped.txt

@@ -1 +1,4 @@
 #nullable enable
+Microsoft.AspNetCore.Authentication.IAuthenticateResultFeature
+Microsoft.AspNetCore.Authentication.IAuthenticateResultFeature.AuthenticateResult.get -> Microsoft.AspNetCore.Authentication.AuthenticateResult?
+Microsoft.AspNetCore.Authentication.IAuthenticateResultFeature.AuthenticateResult.set -> void

+ 42 - 0
src/Security/Authentication/Core/src/AuthenticationFeatures.cs

@@ -0,0 +1,42 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.Security.Claims;
+using Microsoft.AspNetCore.Http.Features.Authentication;
+
+namespace Microsoft.AspNetCore.Authentication
+{
+    /// <summary>
+    /// Keeps the User and AuthenticationResult consistent with each other
+    /// </summary>
+    internal sealed class AuthenticationFeatures : IAuthenticateResultFeature, IHttpAuthenticationFeature
+    {
+        private ClaimsPrincipal? _user;
+        private AuthenticateResult? _result;
+
+        public AuthenticationFeatures(AuthenticateResult result)
+        {
+            AuthenticateResult = result;
+        }
+
+        public AuthenticateResult? AuthenticateResult
+        {
+            get => _result;
+            set
+            {
+                _result = value;
+                _user = _result?.Principal;
+            }
+        }
+
+        public ClaimsPrincipal? User
+        {
+            get => _user;
+            set
+            {
+                _user = value;
+                _result = null;
+            }
+        }
+    }
+}

+ 7 - 0
src/Security/Authentication/Core/src/AuthenticationMiddleware.cs

@@ -4,6 +4,7 @@
 using System;
 using System.Threading.Tasks;
 using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Features.Authentication;
 using Microsoft.Extensions.DependencyInjection;
 
 namespace Microsoft.AspNetCore.Authentication
@@ -71,6 +72,12 @@ namespace Microsoft.AspNetCore.Authentication
                 {
                     context.User = result.Principal;
                 }
+                if (result?.Succeeded ?? false)
+                {
+                    var authFeatures = new AuthenticationFeatures(result);
+                    context.Features.Set<IHttpAuthenticationFeature>(authFeatures);
+                    context.Features.Set<IAuthenticateResultFeature>(authFeatures);
+                }
             }
 
             await _next(context);

+ 124 - 1
src/Security/Authentication/test/AuthenticationMiddlewareTests.cs

@@ -3,12 +3,15 @@
 using System;
 using System.Security.Claims;
 using System.Threading.Tasks;
+using Microsoft.AspNetCore.Authentication.JwtBearer;
 using Microsoft.AspNetCore.Builder;
 using Microsoft.AspNetCore.Hosting;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.TestHost;
 using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+using Moq;
 using Xunit;
 
 namespace Microsoft.AspNetCore.Authentication
@@ -54,6 +57,126 @@ namespace Microsoft.AspNetCore.Authentication
             Assert.Equal(607, (int)response.StatusCode);
         }
 
+        [Fact]
+        public async Task IAuthenticateResultFeature_SetOnSuccessfulAuthenticate()
+        {
+            var authenticationService = new Mock<IAuthenticationService>();
+            authenticationService.Setup(s => s.AuthenticateAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
+                .Returns(Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(), "custom"))));
+            var schemeProvider = new Mock<IAuthenticationSchemeProvider>();
+            schemeProvider.Setup(p => p.GetDefaultAuthenticateSchemeAsync())
+                .Returns(Task.FromResult(new AuthenticationScheme("custom", "custom", typeof(JwtBearerHandler))));
+            var middleware = new AuthenticationMiddleware(c => Task.CompletedTask, schemeProvider.Object);
+            var context = GetHttpContext(authenticationService: authenticationService.Object);
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.NotNull(authenticateResultFeature);
+            Assert.NotNull(authenticateResultFeature.AuthenticateResult);
+            Assert.True(authenticateResultFeature.AuthenticateResult.Succeeded);
+            Assert.Same(context.User, authenticateResultFeature.AuthenticateResult.Principal);
+        }
+
+        [Fact]
+        public async Task IAuthenticateResultFeature_NotSetOnUnsuccessfulAuthenticate()
+        {
+            var authenticationService = new Mock<IAuthenticationService>();
+            authenticationService.Setup(s => s.AuthenticateAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
+                .Returns(Task.FromResult(AuthenticateResult.Fail("not authenticated")));
+            var schemeProvider = new Mock<IAuthenticationSchemeProvider>();
+            schemeProvider.Setup(p => p.GetDefaultAuthenticateSchemeAsync())
+                .Returns(Task.FromResult(new AuthenticationScheme("custom", "custom", typeof(JwtBearerHandler))));
+            var middleware = new AuthenticationMiddleware(c => Task.CompletedTask, schemeProvider.Object);
+            var context = GetHttpContext(authenticationService: authenticationService.Object);
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.Null(authenticateResultFeature);
+        }
+
+        [Fact]
+        public async Task IAuthenticateResultFeature_NullResultWhenUserSetAfter()
+        {
+            var authenticationService = new Mock<IAuthenticationService>();
+            authenticationService.Setup(s => s.AuthenticateAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
+                .Returns(Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(), "custom"))));
+            var schemeProvider = new Mock<IAuthenticationSchemeProvider>();
+            schemeProvider.Setup(p => p.GetDefaultAuthenticateSchemeAsync())
+                .Returns(Task.FromResult(new AuthenticationScheme("custom", "custom", typeof(JwtBearerHandler))));
+            var middleware = new AuthenticationMiddleware(c => Task.CompletedTask, schemeProvider.Object);
+            var context = GetHttpContext(authenticationService: authenticationService.Object);
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.NotNull(authenticateResultFeature);
+            Assert.NotNull(authenticateResultFeature.AuthenticateResult);
+            Assert.True(authenticateResultFeature.AuthenticateResult.Succeeded);
+            Assert.Same(context.User, authenticateResultFeature.AuthenticateResult.Principal);
+
+            context.User = new ClaimsPrincipal();
+            Assert.Null(authenticateResultFeature.AuthenticateResult);
+        }
+
+        [Fact]
+        public async Task IAuthenticateResultFeature_SettingResultSetsUser()
+        {
+            var authenticationService = new Mock<IAuthenticationService>();
+            authenticationService.Setup(s => s.AuthenticateAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
+                .Returns(Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(), "custom"))));
+            var schemeProvider = new Mock<IAuthenticationSchemeProvider>();
+            schemeProvider.Setup(p => p.GetDefaultAuthenticateSchemeAsync())
+                .Returns(Task.FromResult(new AuthenticationScheme("custom", "custom", typeof(JwtBearerHandler))));
+            var middleware = new AuthenticationMiddleware(c => Task.CompletedTask, schemeProvider.Object);
+            var context = GetHttpContext(authenticationService: authenticationService.Object);
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.NotNull(authenticateResultFeature);
+            Assert.NotNull(authenticateResultFeature.AuthenticateResult);
+            Assert.True(authenticateResultFeature.AuthenticateResult.Succeeded);
+            Assert.Same(context.User, authenticateResultFeature.AuthenticateResult.Principal);
+
+            var newTicket = new AuthenticationTicket(new ClaimsPrincipal(), "");
+            authenticateResultFeature.AuthenticateResult = AuthenticateResult.Success(newTicket);
+            Assert.Same(context.User, newTicket.Principal);
+        }
+
+        private HttpContext GetHttpContext(
+            Action<IServiceCollection> registerServices = null,
+            IAuthenticationService authenticationService = null)
+        {
+            // ServiceProvider
+            var serviceCollection = new ServiceCollection();
+
+            authenticationService = authenticationService ?? Mock.Of<IAuthenticationService>();
+
+            serviceCollection.AddSingleton(authenticationService);
+            serviceCollection.AddOptions();
+            serviceCollection.AddLogging();
+            serviceCollection.AddAuthentication();
+            registerServices?.Invoke(serviceCollection);
+
+            var serviceProvider = serviceCollection.BuildServiceProvider();
+
+            //// HttpContext
+            var httpContext = new DefaultHttpContext();
+            httpContext.RequestServices = serviceProvider;
+
+            return httpContext;
+        }
+
         private class ThreeOhFiveHandler : StatusCodeHandler {
             public ThreeOhFiveHandler() : base(305) { }
         }
@@ -77,7 +200,7 @@ namespace Microsoft.AspNetCore.Authentication
             {
                 _code = code;
             }
-            
+
             public Task<AuthenticateResult> AuthenticateAsync()
             {
                 throw new NotImplementedException();

+ 43 - 0
src/Security/Authorization/Policy/src/AuthenticationFeatures.cs

@@ -0,0 +1,43 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.Security.Claims;
+using Microsoft.AspNetCore.Authentication;
+using Microsoft.AspNetCore.Http.Features.Authentication;
+
+namespace Microsoft.AspNetCore.Authorization.Policy
+{
+    /// <summary>
+    /// Keeps the User and AuthenticationResult consistent with each other
+    /// </summary>
+    internal sealed class AuthenticationFeatures : IAuthenticateResultFeature, IHttpAuthenticationFeature
+    {
+        private ClaimsPrincipal? _user;
+        private AuthenticateResult? _result;
+
+        public AuthenticationFeatures(AuthenticateResult result)
+        {
+            AuthenticateResult = result;
+        }
+
+        public AuthenticateResult? AuthenticateResult
+        {
+            get => _result;
+            set
+            {
+                _result = value;
+                _user = _result?.Principal;
+            }
+        }
+
+        public ClaimsPrincipal? User
+        {
+            get => _user;
+            set
+            {
+                _user = value;
+                _result = null;
+            }
+        }
+    }
+}

+ 21 - 5
src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs

@@ -3,8 +3,10 @@
 
 using System;
 using System.Threading.Tasks;
+using Microsoft.AspNetCore.Authentication;
 using Microsoft.AspNetCore.Authorization.Policy;
 using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Features.Authentication;
 using Microsoft.Extensions.DependencyInjection;
 
 namespace Microsoft.AspNetCore.Authorization
@@ -29,7 +31,7 @@ namespace Microsoft.AspNetCore.Authorization
         /// </summary>
         /// <param name="next">The next middleware in the application middleware pipeline.</param>
         /// <param name="policyProvider">The <see cref="IAuthorizationPolicyProvider"/>.</param>
-        public AuthorizationMiddleware(RequestDelegate next, IAuthorizationPolicyProvider policyProvider) 
+        public AuthorizationMiddleware(RequestDelegate next, IAuthorizationPolicyProvider policyProvider)
         {
             _next = next ?? throw new ArgumentNullException(nameof(next));
             _policyProvider = policyProvider ?? throw new ArgumentNullException(nameof(policyProvider));
@@ -64,12 +66,26 @@ namespace Microsoft.AspNetCore.Authorization
                 return;
             }
 
-            // Policy evaluator has transient lifetime so it fetched from request services instead of injecting in constructor
+            // Policy evaluator has transient lifetime so it's fetched from request services instead of injecting in constructor
             var policyEvaluator = context.RequestServices.GetRequiredService<IPolicyEvaluator>();
 
             var authenticateResult = await policyEvaluator.AuthenticateAsync(policy, context);
 
-            // Allow Anonymous skips all authorization
+            if (authenticateResult?.Succeeded ?? false)
+            {
+                if (context.Features.Get<IAuthenticateResultFeature>() is IAuthenticateResultFeature authenticateResultFeature)
+                {
+                    authenticateResultFeature.AuthenticateResult = authenticateResult;
+                }
+                else
+                {
+                    var authFeatures = new AuthenticationFeatures(authenticateResult);
+                    context.Features.Set<IHttpAuthenticationFeature>(authFeatures);
+                    context.Features.Set<IAuthenticateResultFeature>(authFeatures);
+                }
+            }
+
+            // Allow Anonymous still wants to run authorization to populate the User but skips any failure/challenge handling
             if (endpoint?.Metadata.GetMetadata<IAllowAnonymous>() != null)
             {
                 await _next(context);
@@ -85,8 +101,8 @@ namespace Microsoft.AspNetCore.Authorization
             {
                 resource = context;
             }
-            
-            var authorizeResult = await policyEvaluator.AuthorizeAsync(policy, authenticateResult, context, resource);
+
+            var authorizeResult = await policyEvaluator.AuthorizeAsync(policy, authenticateResult!, context, resource);
             var authorizationMiddlewareResultHandler = context.RequestServices.GetRequiredService<IAuthorizationMiddlewareResultHandler>();
             await authorizationMiddlewareResultHandler.HandleAsync(_next, context, policy, authorizeResult);
         }

+ 11 - 1
src/Security/Authorization/Policy/src/PolicyEvaluator.cs

@@ -38,19 +38,29 @@ namespace Microsoft.AspNetCore.Authorization.Policy
             if (policy.AuthenticationSchemes != null && policy.AuthenticationSchemes.Count > 0)
             {
                 ClaimsPrincipal? newPrincipal = null;
+                DateTimeOffset? minExpiresUtc = null;
                 foreach (var scheme in policy.AuthenticationSchemes)
                 {
                     var result = await context.AuthenticateAsync(scheme);
                     if (result != null && result.Succeeded)
                     {
                         newPrincipal = SecurityHelper.MergeUserPrincipal(newPrincipal, result.Principal);
+
+                        if (minExpiresUtc is null || result.Properties?.ExpiresUtc < minExpiresUtc)
+                        {
+                            minExpiresUtc = result.Properties?.ExpiresUtc;
+                        }
                     }
                 }
 
                 if (newPrincipal != null)
                 {
                     context.User = newPrincipal;
-                    return AuthenticateResult.Success(new AuthenticationTicket(newPrincipal, string.Join(";", policy.AuthenticationSchemes)));
+                    var ticket = new AuthenticationTicket(newPrincipal, string.Join(";", policy.AuthenticationSchemes));
+                    // ExpiresUtc is the easiest property to reason about when dealing with multiple schemes
+                    // SignalR will use this property to evaluate auth expiration for long running connections
+                    ticket.Properties.ExpiresUtc = minExpiresUtc;
+                    return AuthenticateResult.Success(ticket);
                 }
                 else
                 {

+ 177 - 4
src/Security/Authorization/test/AuthorizationMiddlewareTests.cs

@@ -2,6 +2,7 @@
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
 
 using System;
+using System.Linq;
 using System.Security.Claims;
 using System.Threading.Tasks;
 using Microsoft.AspNetCore.Authentication;
@@ -9,7 +10,6 @@ using Microsoft.AspNetCore.Authorization.Policy;
 using Microsoft.AspNetCore.Authorization.Test.TestObjects;
 using Microsoft.AspNetCore.Http;
 using Microsoft.Extensions.DependencyInjection;
-using Microsoft.Extensions.Options;
 using Moq;
 using Xunit;
 
@@ -54,7 +54,7 @@ namespace Microsoft.AspNetCore.Authorization.Test
             // Assert
             Assert.False(next.Called);
         }
-        
+
         [Fact]
         public async Task HasEndpointWithoutAuth_AnonymousUser_Allows()
         {
@@ -156,7 +156,7 @@ namespace Microsoft.AspNetCore.Authorization.Test
             Assert.False(next.Called);
             Assert.True(authenticationService.ChallengeCalled);
         }
-        
+
         [Fact]
         public async Task HasEndpointWithAuth_AnonymousUser_ChallengePerScheme()
         {
@@ -367,7 +367,7 @@ namespace Microsoft.AspNetCore.Authorization.Test
             // Assert
             Assert.Equal(endpoint, resource);
         }
-        
+
         [Fact]
         public async Task Invoke_RequireUnknownRoleShouldForbid()
         {
@@ -435,6 +435,179 @@ namespace Microsoft.AspNetCore.Authorization.Test
             Assert.True(authenticationService.ForbidCalled);
         }
 
+        [Fact]
+        public async Task IAuthenticateResultFeature_SetOnSuccessfulAuthorize()
+        {
+            // Arrange
+            var policy = new AuthorizationPolicyBuilder().RequireClaim("Permission", "CanViewPage").Build();
+            var policyProvider = new Mock<IAuthorizationPolicyProvider>();
+            policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(policy);
+            var next = new TestRequestDelegate();
+
+            var middleware = CreateMiddleware(next.Invoke, policyProvider.Object);
+            var context = GetHttpContext(endpoint: CreateEndpoint(new AuthorizeAttribute()));
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            Assert.True(next.Called);
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.NotNull(authenticateResultFeature);
+            Assert.NotNull(authenticateResultFeature.AuthenticateResult);
+            Assert.Same(context.User, authenticateResultFeature.AuthenticateResult.Principal);
+        }
+
+        [Fact]
+        public async Task IAuthenticateResultFeature_NotSetOnUnsuccessfulAuthorize()
+        {
+            // Arrange
+            var policy = new AuthorizationPolicyBuilder().RequireRole("Wut").AddAuthenticationSchemes("NotImplemented").Build();
+            var policyProvider = new Mock<IAuthorizationPolicyProvider>();
+            policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(policy);
+            var next = new TestRequestDelegate();
+            var authenticationService = new TestAuthenticationService();
+
+            var middleware = CreateMiddleware(next.Invoke, policyProvider.Object);
+            var context = GetHttpContext(endpoint: CreateEndpoint(new AuthorizeAttribute(), new AllowAnonymousAttribute()), authenticationService: authenticationService);
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            Assert.True(next.Called);
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.Null(authenticateResultFeature);
+            Assert.True(authenticationService.AuthenticateCalled);
+        }
+
+        [Fact]
+        public async Task IAuthenticateResultFeature_ContainsLowestExpiration()
+        {
+            // Arrange
+            var policy = new AuthorizationPolicyBuilder().RequireRole("Wut").AddAuthenticationSchemes("Basic", "Bearer").Build();
+            var policyProvider = new Mock<IAuthorizationPolicyProvider>();
+            policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(policy);
+            var next = new TestRequestDelegate();
+
+            var firstExpiration = new DateTimeOffset(2021, 5, 12, 2, 3, 4, TimeSpan.Zero);
+            var secondExpiration = new DateTimeOffset(2021, 5, 11, 2, 3, 4, TimeSpan.Zero);
+            var authenticationService = new Mock<IAuthenticationService>();
+            authenticationService.Setup(s => s.AuthenticateAsync(It.IsAny<HttpContext>(), "Basic"))
+                .ReturnsAsync((HttpContext c, string scheme) =>
+                {
+                    var res = AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(c.User.Identities.FirstOrDefault(i => i.AuthenticationType == scheme)), scheme));
+                    res.Properties.ExpiresUtc = firstExpiration;
+                    return res;
+                });
+            authenticationService.Setup(s => s.AuthenticateAsync(It.IsAny<HttpContext>(), "Bearer"))
+                .ReturnsAsync((HttpContext c, string scheme) =>
+                {
+                    var res = AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(c.User.Identities.FirstOrDefault(i => i.AuthenticationType == scheme)), scheme));
+                    res.Properties.ExpiresUtc = secondExpiration;
+                    return res;
+                });
+
+            var middleware = CreateMiddleware(next.Invoke, policyProvider.Object);
+            var context = GetHttpContext(endpoint: CreateEndpoint(new AuthorizeAttribute()), authenticationService: authenticationService.Object);
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.NotNull(authenticateResultFeature);
+            Assert.NotNull(authenticateResultFeature.AuthenticateResult);
+            Assert.Same(context.User, authenticateResultFeature.AuthenticateResult.Principal);
+            Assert.Equal(secondExpiration, authenticateResultFeature.AuthenticateResult?.Properties?.ExpiresUtc);
+        }
+
+        [Fact]
+        public async Task IAuthenticateResultFeature_NullResultWhenUserSetAfter()
+        {
+            // Arrange
+            var policy = new AuthorizationPolicyBuilder().RequireClaim("Permission", "CanViewPage").Build();
+            var policyProvider = new Mock<IAuthorizationPolicyProvider>();
+            policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(policy);
+            var next = new TestRequestDelegate();
+
+            var middleware = CreateMiddleware(next.Invoke, policyProvider.Object);
+            var context = GetHttpContext(endpoint: CreateEndpoint(new AuthorizeAttribute()));
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            Assert.True(next.Called);
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.NotNull(authenticateResultFeature);
+            Assert.NotNull(authenticateResultFeature.AuthenticateResult);
+            Assert.Same(context.User, authenticateResultFeature.AuthenticateResult.Principal);
+
+            context.User = new ClaimsPrincipal();
+            Assert.Null(authenticateResultFeature.AuthenticateResult);
+        }
+
+        [Fact]
+        public async Task IAuthenticateResultFeature_SettingResultSetsUser()
+        {
+            // Arrange
+            var policy = new AuthorizationPolicyBuilder().RequireClaim("Permission", "CanViewPage").Build();
+            var policyProvider = new Mock<IAuthorizationPolicyProvider>();
+            policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(policy);
+            var next = new TestRequestDelegate();
+
+            var middleware = CreateMiddleware(next.Invoke, policyProvider.Object);
+            var context = GetHttpContext(endpoint: CreateEndpoint(new AuthorizeAttribute()));
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            Assert.True(next.Called);
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.NotNull(authenticateResultFeature);
+            Assert.NotNull(authenticateResultFeature.AuthenticateResult);
+            Assert.Same(context.User, authenticateResultFeature.AuthenticateResult.Principal);
+
+            var newTicket = new AuthenticationTicket(new ClaimsPrincipal(), "");
+            authenticateResultFeature.AuthenticateResult = AuthenticateResult.Success(newTicket);
+            Assert.Same(context.User, newTicket.Principal);
+        }
+
+        class TestAuthResultFeature : IAuthenticateResultFeature
+        {
+            public AuthenticateResult AuthenticateResult { get; set; }
+        }
+
+        [Fact]
+        public async Task IAuthenticateResultFeature_UsesExistingFeature()
+        {
+            // Arrange
+            var policy = new AuthorizationPolicyBuilder().RequireClaim("Permission", "CanViewPage").Build();
+            var policyProvider = new Mock<IAuthorizationPolicyProvider>();
+            policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(policy);
+            var next = new TestRequestDelegate();
+
+            var middleware = CreateMiddleware(next.Invoke, policyProvider.Object);
+            var context = GetHttpContext(endpoint: CreateEndpoint(new AuthorizeAttribute()));
+            var testAuthenticateResultFeature = new TestAuthResultFeature();
+            var authenticateResult = AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(), ""));
+            testAuthenticateResultFeature.AuthenticateResult = authenticateResult;
+            context.Features.Set<IAuthenticateResultFeature>(testAuthenticateResultFeature);
+
+            // Act
+            await middleware.Invoke(context);
+
+            // Assert
+            Assert.True(next.Called);
+            var authenticateResultFeature = context.Features.Get<IAuthenticateResultFeature>();
+            Assert.NotNull(authenticateResultFeature);
+            Assert.NotNull(authenticateResultFeature.AuthenticateResult);
+            Assert.Same(testAuthenticateResultFeature, authenticateResultFeature);
+            Assert.NotSame(authenticateResult, authenticateResultFeature.AuthenticateResult);
+        }
+
         private AuthorizationMiddleware CreateMiddleware(RequestDelegate requestDelegate = null, IAuthorizationPolicyProvider policyProvider = null)
         {
             requestDelegate = requestDelegate ?? ((context) => Task.CompletedTask);