Browse Source

[SignalR] Pass a resource into IPolicyEvaluator for Hub method auth (#11070)

Brennan 6 years ago
parent
commit
5b31a9540a

+ 7 - 0
src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs

@@ -172,6 +172,13 @@ namespace Microsoft.AspNetCore.SignalR
             public void Reset() { }
         }
     }
+    public partial class HubInvocationContext
+    {
+        public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { }
+        public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
+        public System.Collections.Generic.IReadOnlyList<object> HubMethodArguments { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
+        public string HubMethodName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
+    }
     public abstract partial class HubLifetimeManager<THub> where THub : Microsoft.AspNetCore.SignalR.Hub
     {
         protected HubLifetimeManager() { }

+ 4 - 0
src/SignalR/server/Core/src/HubConnectionContext.cs

@@ -59,6 +59,8 @@ namespace Microsoft.AspNetCore.SignalR
             _connectionContext = connectionContext;
             _logger = loggerFactory.CreateLogger<HubConnectionContext>();
             ConnectionAborted = _connectionAbortedTokenSource.Token;
+
+            HubCallerContext = new DefaultHubCallerContext(this);
         }
 
         internal StreamTracker StreamTracker
@@ -75,6 +77,8 @@ namespace Microsoft.AspNetCore.SignalR
             }
         }
 
+        internal HubCallerContext HubCallerContext { get; }
+
         /// <summary>
         /// Gets a <see cref="CancellationToken"/> that notifies when the connection is aborted.
         /// </summary>

+ 21 - 0
src/SignalR/server/Core/src/HubInvocationContext.cs

@@ -0,0 +1,21 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.Collections.Generic;
+
+namespace Microsoft.AspNetCore.SignalR
+{
+    public class HubInvocationContext
+    {
+        public HubInvocationContext(HubCallerContext context, string hubMethodName, object[] hubMethodArguments)
+        {
+            HubMethodName = hubMethodName;
+            HubMethodArguments = hubMethodArguments;
+            Context = context;
+        }
+
+        public HubCallerContext Context { get; }
+        public string HubMethodName { get; }
+        public IReadOnlyList<object> HubMethodArguments { get; }
+    }
+}

+ 6 - 6
src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs

@@ -221,7 +221,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
             THub hub = null;
             try
             {
-                if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies))
+                if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor.Policies, descriptor.MethodExecutor.MethodInfo.Name, hubMethodInvocationMessage.Arguments))
                 {
                     Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
                     await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
@@ -479,11 +479,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
         private void InitializeHub(THub hub, HubConnectionContext connection)
         {
             hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId);
-            hub.Context = new DefaultHubCallerContext(connection);
+            hub.Context = connection.HubCallerContext;
             hub.Groups = _hubContext.Groups;
         }
 
-        private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
+        private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, IList<IAuthorizeData> policies, string hubMethodName, object[] hubMethodArguments)
         {
             // If there are no policies we don't need to run auth
             if (!policies.Any())
@@ -491,10 +491,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
                 return TaskCache.True;
             }
 
-            return IsHubMethodAuthorizedSlow(provider, principal, policies);
+            return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, hubMethodName, hubMethodArguments));
         }
 
-        private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
+        private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies, HubInvocationContext resource)
         {
             var authService = provider.GetRequiredService<IAuthorizationService>();
             var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();
@@ -503,7 +503,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
             // AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above
             Debug.Assert(authorizePolicy != null);
 
-            var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy);
+            var authorizationResult = await authService.AuthorizeAsync(principal, resource, authorizePolicy);
             // Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation
             return authorizationResult.Succeeded;
         }

+ 5 - 0
src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs

@@ -143,6 +143,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
         {
         }
 
+        [Authorize("test")]
+        public void MultiParamAuthMethod(string s1, string s2)
+        {
+        }
+
         public Task SendToAllExcept(string message, IReadOnlyList<string> excludedConnectionIds)
         {
             return Clients.AllExcept(excludedConnectionIds).SendAsync("Send", message);

+ 70 - 0
src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs

@@ -12,9 +12,11 @@ using System.Text;
 using System.Threading.Tasks;
 using MessagePack;
 using MessagePack.Formatters;
+using Microsoft.AspNetCore.Authorization;
 using Microsoft.AspNetCore.Connections;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Http.Connections.Features;
+using Microsoft.AspNetCore.Http.Connections.Internal;
 using Microsoft.AspNetCore.SignalR.Internal;
 using Microsoft.AspNetCore.SignalR.Protocol;
 using Microsoft.Extensions.DependencyInjection;
@@ -2198,6 +2200,69 @@ namespace Microsoft.AspNetCore.SignalR.Tests
             }
         }
 
+        private class TestAuthHandler : IAuthorizationHandler
+        {
+            public Task HandleAsync(AuthorizationHandlerContext context)
+            {
+                Assert.NotNull(context.Resource);
+                var resource = Assert.IsType<HubInvocationContext>(context.Resource);
+                Assert.Equal(nameof(MethodHub.MultiParamAuthMethod), resource.HubMethodName);
+                Assert.Equal(2, resource.HubMethodArguments?.Count);
+                Assert.Equal("Hello", resource.HubMethodArguments[0]);
+                Assert.Equal("World!", resource.HubMethodArguments[1]);
+                Assert.NotNull(resource.Context);
+                Assert.Equal(context.User, resource.Context.User);
+                Assert.NotNull(resource.Context.GetHttpContext());
+
+                return Task.CompletedTask;
+            }
+        }
+
+        [Fact]
+        public async Task HubMethodWithAuthorizationProvidesResourceToAuthHandlers()
+        {
+            using (StartVerifiableLog())
+            {
+                var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
+                {
+                    services.AddAuthorization(options =>
+                    {
+                        options.AddPolicy("test", policy =>
+                        {
+                            policy.RequireClaim(ClaimTypes.NameIdentifier);
+                            policy.AddAuthenticationSchemes("Default");
+                        });
+                    });
+
+                    services.AddSingleton<IAuthorizationHandler, TestAuthHandler>();
+                }, LoggerFactory);
+
+                var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
+
+                using (var client = new TestClient())
+                {
+                    client.Connection.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
+
+                    // Setup a HttpContext to make sure it flows to the AuthHandler correctly
+                    var httpConnectionContext = new HttpContextFeatureImpl();
+                    httpConnectionContext.HttpContext = new DefaultHttpContext();
+                    client.Connection.Features.Set<IHttpContextFeature>(httpConnectionContext);
+
+                    var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
+
+                    await client.Connected.OrTimeout();
+
+                    var message = await client.InvokeAsync(nameof(MethodHub.MultiParamAuthMethod), "Hello", "World!").OrTimeout();
+
+                    Assert.Null(message.Error);
+
+                    client.Dispose();
+
+                    await connectionHandlerTask.OrTimeout();
+                }
+            }
+        }
+
         [Fact]
         public async Task HubOptionsCanUseCustomJsonSerializerSettings()
         {
@@ -3632,5 +3697,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
             public int Bar { get; }
             public string Foo { get; }
         }
+
+        private class HttpContextFeatureImpl : IHttpContextFeature
+        {
+            public HttpContext HttpContext { get; set; }
+        }
     }
 }