Browse Source

Implement minimal RateLimitingMiddleware (#41008)

* First RateLimiting commit

* More

* ctrl-s

* lil bit

* Small

* Check-in launchSettings.json files from Middleware (#40695)

* sln

* More

* More

* More

* Remove stuff

* rm

* Not SharedFx

* Feedback

* Internal+IVT

* Feedback

* Feedback chunk 1

* Feedback chunk 2

* Small feedback

* Fix extension methods

* Small fix

* Func

* Config status code

* Feedback, add servicecollection extension

* Update API

* Some feedback

* Lil more feedback

* Partially fix test

* Fix/Add tests

* Add another test
William Godbe 3 years ago
parent
commit
4e7e7da88d

+ 38 - 2
AspNetCore.sln

@@ -1698,9 +1698,13 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ResultsOfTGenerator", "src\
 EndProject
 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "OpenApi", "OpenApi", "{2299CCD8-8F9C-4F2B-A633-9BF4DA81022B}"
 EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.OpenApi.Tests", "src\OpenApi\test\Microsoft.AspNetCore.OpenApi.Tests.csproj", "{3AEFB466-6310-4F3F-923F-9154224E3629}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.OpenApi.Tests", "src\OpenApi\test\Microsoft.AspNetCore.OpenApi.Tests.csproj", "{3AEFB466-6310-4F3F-923F-9154224E3629}"
 EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.OpenApi", "src\OpenApi\src\Microsoft.AspNetCore.OpenApi.csproj", "{EFC8EA45-572D-4D8D-A597-9045A2D8EC40}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.OpenApi", "src\OpenApi\src\Microsoft.AspNetCore.OpenApi.csproj", "{EFC8EA45-572D-4D8D-A597-9045A2D8EC40}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.RateLimiting", "src\Middleware\RateLimiting\src\Microsoft.AspNetCore.RateLimiting.csproj", "{8EE73488-2B92-42BD-96C9-0DD65405C828}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.RateLimiting.Tests", "src\Middleware\RateLimiting\test\Microsoft.AspNetCore.RateLimiting.Tests.csproj", "{41FF4F96-98D2-4482-A2A7-4B179E80D285}"
 EndProject
 Global
 	GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -10191,6 +10195,38 @@ Global
 		{EFC8EA45-572D-4D8D-A597-9045A2D8EC40}.Release|x64.Build.0 = Release|Any CPU
 		{EFC8EA45-572D-4D8D-A597-9045A2D8EC40}.Release|x86.ActiveCfg = Release|Any CPU
 		{EFC8EA45-572D-4D8D-A597-9045A2D8EC40}.Release|x86.Build.0 = Release|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Debug|Any CPU.Build.0 = Debug|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Debug|arm64.ActiveCfg = Debug|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Debug|arm64.Build.0 = Debug|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Debug|x64.ActiveCfg = Debug|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Debug|x64.Build.0 = Debug|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Debug|x86.ActiveCfg = Debug|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Debug|x86.Build.0 = Debug|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Release|Any CPU.ActiveCfg = Release|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Release|Any CPU.Build.0 = Release|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Release|arm64.ActiveCfg = Release|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Release|arm64.Build.0 = Release|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Release|x64.ActiveCfg = Release|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Release|x64.Build.0 = Release|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Release|x86.ActiveCfg = Release|Any CPU
+		{8EE73488-2B92-42BD-96C9-0DD65405C828}.Release|x86.Build.0 = Release|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Debug|Any CPU.Build.0 = Debug|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Debug|arm64.ActiveCfg = Debug|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Debug|arm64.Build.0 = Debug|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Debug|x64.ActiveCfg = Debug|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Debug|x64.Build.0 = Debug|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Debug|x86.ActiveCfg = Debug|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Debug|x86.Build.0 = Debug|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Release|Any CPU.ActiveCfg = Release|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Release|Any CPU.Build.0 = Release|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Release|arm64.ActiveCfg = Release|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Release|arm64.Build.0 = Release|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Release|x64.ActiveCfg = Release|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Release|x64.Build.0 = Release|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Release|x86.ActiveCfg = Release|Any CPU
+		{41FF4F96-98D2-4482-A2A7-4B179E80D285}.Release|x86.Build.0 = Release|Any CPU
 	EndGlobalSection
 	GlobalSection(SolutionProperties) = preSolution
 		HideSolutionNode = FALSE

+ 1 - 0
eng/ProjectReferences.props

@@ -89,6 +89,7 @@
     <ProjectReferenceProvider Include="Microsoft.AspNetCore.Localization.Routing" ProjectPath="$(RepoRoot)src\Middleware\Localization.Routing\src\Microsoft.AspNetCore.Localization.Routing.csproj" />
     <ProjectReferenceProvider Include="Microsoft.AspNetCore.Localization" ProjectPath="$(RepoRoot)src\Middleware\Localization\src\Microsoft.AspNetCore.Localization.csproj" />
     <ProjectReferenceProvider Include="Microsoft.AspNetCore.MiddlewareAnalysis" ProjectPath="$(RepoRoot)src\Middleware\MiddlewareAnalysis\src\Microsoft.AspNetCore.MiddlewareAnalysis.csproj" />
+    <ProjectReferenceProvider Include="Microsoft.AspNetCore.RateLimiting" ProjectPath="$(RepoRoot)src\Middleware\RateLimiting\src\Microsoft.AspNetCore.RateLimiting.csproj" />
     <ProjectReferenceProvider Include="Microsoft.AspNetCore.ResponseCaching.Abstractions" ProjectPath="$(RepoRoot)src\Middleware\ResponseCaching.Abstractions\src\Microsoft.AspNetCore.ResponseCaching.Abstractions.csproj" />
     <ProjectReferenceProvider Include="Microsoft.AspNetCore.ResponseCaching" ProjectPath="$(RepoRoot)src\Middleware\ResponseCaching\src\Microsoft.AspNetCore.ResponseCaching.csproj" />
     <ProjectReferenceProvider Include="Microsoft.AspNetCore.ResponseCompression" ProjectPath="$(RepoRoot)src\Middleware\ResponseCompression\src\Microsoft.AspNetCore.ResponseCompression.csproj" />

+ 3 - 1
src/Middleware/Middleware.slnf

@@ -76,6 +76,8 @@
       "src\\Middleware\\MiddlewareAnalysis\\samples\\MiddlewareAnalysisSample\\MiddlewareAnalysisSample.csproj",
       "src\\Middleware\\MiddlewareAnalysis\\src\\Microsoft.AspNetCore.MiddlewareAnalysis.csproj",
       "src\\Middleware\\MiddlewareAnalysis\\test\\Microsoft.AspNetCore.MiddlewareAnalysis.Tests.csproj",
+      "src\\Middleware\\RateLimiting\\src\\Microsoft.AspNetCore.RateLimiting.csproj",
+      "src\\Middleware\\RateLimiting\\test\\Microsoft.AspNetCore.RateLimiting.Tests.csproj",
       "src\\Middleware\\ResponseCaching.Abstractions\\src\\Microsoft.AspNetCore.ResponseCaching.Abstractions.csproj",
       "src\\Middleware\\ResponseCaching\\samples\\ResponseCachingSample\\ResponseCachingSample.csproj",
       "src\\Middleware\\ResponseCaching\\src\\Microsoft.AspNetCore.ResponseCaching.csproj",
@@ -115,4 +117,4 @@
       "src\\Servers\\Kestrel\\Transport.Sockets\\src\\Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj"
     ]
   }
-}
+}

+ 19 - 0
src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj

@@ -0,0 +1,19 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <Description>ASP.NET Core middleware for enforcing rate limiting in an application</Description>
+    <TargetFramework>$(DefaultNetCoreTargetFramework)</TargetFramework>
+    <GenerateDocumentationFile>true</GenerateDocumentationFile>
+    <PackageTags>aspnetcore</PackageTags>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <Reference Include="Microsoft.AspNetCore.Http.Abstractions" />
+    <Reference Include="Microsoft.Extensions.Logging.Abstractions" />
+    <Reference Include="Microsoft.Extensions.Options" />
+    <Reference Include="System.Threading.RateLimiting" />
+
+    <Compile Include="$(SharedSourceRoot)ValueStopwatch\*.cs" />
+  </ItemGroup>
+
+</Project>

+ 36 - 0
src/Middleware/RateLimiting/src/NoLimiter.cs

@@ -0,0 +1,36 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Threading.RateLimiting;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+internal class NoLimiter<TResource> : PartitionedRateLimiter<TResource>
+{
+    public override int GetAvailablePermits(TResource resourceID)
+    {
+        return 1;
+    }
+
+    protected override RateLimitLease AcquireCore(TResource resourceID, int permitCount)
+    {
+        return new NoLimiterLease();
+    }
+
+    protected override ValueTask<RateLimitLease> WaitAsyncCore(TResource resourceID, int permitCount, CancellationToken cancellationToken)
+    {
+        return new ValueTask<RateLimitLease>(new NoLimiterLease());
+    }
+}
+
+internal class NoLimiterLease : RateLimitLease
+{
+    public override bool IsAcquired => true;
+
+    public override IEnumerable<string> MetadataNames => new List<string>();
+
+    public override bool TryGetMetadata(string metadataName, out object? metadata)
+    {
+        metadata = null;
+        return false;
+    }
+}

+ 6 - 0
src/Middleware/RateLimiting/src/Properties/AssemblyInfo.cs

@@ -0,0 +1,6 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Runtime.CompilerServices;
+
+[assembly: InternalsVisibleTo("Microsoft.AspNetCore.RateLimiting.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] 

+ 1 - 0
src/Middleware/RateLimiting/src/PublicAPI.Shipped.txt

@@ -0,0 +1 @@
+#nullable enable

+ 11 - 0
src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt

@@ -0,0 +1,11 @@
+Microsoft.AspNetCore.RateLimiting.RateLimiterOptions
+Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.DefaultRejectionStatusCode.get -> int
+Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.DefaultRejectionStatusCode.set -> void
+Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.Limiter.get -> System.Threading.RateLimiting.PartitionedRateLimiter<Microsoft.AspNetCore.Http.HttpContext!>!
+Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.Limiter.set -> void
+Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.OnRejected.get -> System.Func<Microsoft.AspNetCore.Http.HttpContext!, System.Threading.RateLimiting.RateLimitLease!, System.Threading.Tasks.Task!>!
+Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.OnRejected.set -> void
+Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.RateLimiterOptions() -> void
+Microsoft.AspNetCore.RateLimiting.RateLimitingApplicationBuilderExtensions
+static Microsoft.AspNetCore.RateLimiting.RateLimitingApplicationBuilderExtensions.UseRateLimiter(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app) -> Microsoft.AspNetCore.Builder.IApplicationBuilder!
+static Microsoft.AspNetCore.RateLimiting.RateLimitingApplicationBuilderExtensions.UseRateLimiter(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options) -> Microsoft.AspNetCore.Builder.IApplicationBuilder!

+ 48 - 0
src/Middleware/RateLimiting/src/RateLimiterOptions.cs

@@ -0,0 +1,48 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Threading.RateLimiting;
+using Microsoft.AspNetCore.Http;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+
+/// <summary>
+/// Specifies options for the rate limiting middleware.
+/// </summary>
+public sealed class RateLimiterOptions
+{
+    // TODO - Provide a default?
+    private PartitionedRateLimiter<HttpContext> _limiter = new NoLimiter<HttpContext>();
+    private Func<HttpContext, RateLimitLease, Task> _onRejected = (context, lease) =>
+    {
+        return Task.CompletedTask;
+    };
+
+    /// <summary>
+    /// Gets or sets the <see cref="PartitionedRateLimiter{TResource}"/>
+    /// </summary>
+    public PartitionedRateLimiter<HttpContext> Limiter
+    {
+        get => _limiter;
+        set => _limiter = value ?? throw new ArgumentNullException(nameof(value));
+    }
+
+    /// <summary>
+    /// Gets or sets a <see cref="Func{HttpContext, RateLimitLease, Task}"/> that handles requests rejected by this middleware.
+    /// </summary>
+    public Func<HttpContext, RateLimitLease, Task> OnRejected
+    {
+        get => _onRejected;
+        set => _onRejected = value ?? throw new ArgumentNullException(nameof(value));
+    }
+
+    /// <summary>
+    /// Gets or sets the default status code to set on the response when a request is rejected.
+    /// Defaults to <see cref="StatusCodes.Status503ServiceUnavailable"/>.
+    /// </summary>
+    /// <remarks>
+    /// This status code will be set before <see cref="OnRejected"/> is called, so any status code set by
+    /// <see cref="OnRejected"/> will "win" over this default.
+    /// </remarks>
+    public int DefaultRejectionStatusCode { get; set; } = StatusCodes.Status503ServiceUnavailable;
+}

+ 39 - 0
src/Middleware/RateLimiting/src/RateLimitingApplicationBuilderExtensions.cs

@@ -0,0 +1,39 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.AspNetCore.Builder;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+
+/// <summary>
+/// Extension methods for the RateLimiting middleware.
+/// </summary>
+public static class RateLimitingApplicationBuilderExtensions
+{
+    /// <summary>
+    /// Enables rate limiting for the application.
+    /// </summary>
+    /// <param name="app"></param>
+    /// <returns></returns>
+    public static IApplicationBuilder UseRateLimiter(this IApplicationBuilder app)
+    {
+        ArgumentNullException.ThrowIfNull(app);
+
+        return app.UseMiddleware<RateLimitingMiddleware>();
+    }
+
+    /// <summary>
+    /// Enables rate limiting for the application.
+    /// </summary>
+    /// <param name="app"></param>
+    /// <param name="options"></param>
+    /// <returns></returns>
+    public static IApplicationBuilder UseRateLimiter(this IApplicationBuilder app, RateLimiterOptions options)
+    {
+        ArgumentNullException.ThrowIfNull(app, nameof(app));
+        ArgumentNullException.ThrowIfNull(options, nameof(options));
+
+        return app.UseMiddleware<RateLimitingMiddleware>(Options.Create(options));
+    }
+}

+ 78 - 0
src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs

@@ -0,0 +1,78 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Threading.RateLimiting;
+using Microsoft.AspNetCore.Http;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+
+/// <summary>
+/// Limits the rate of requests allowed in the application, based on limits set by a user-provided <see cref="PartitionedRateLimiter{TResource}"/>.
+/// </summary>
+internal sealed partial class RateLimitingMiddleware
+{
+    private readonly RequestDelegate _next;
+    private readonly Func<HttpContext, RateLimitLease, Task> _onRejected;
+    private readonly ILogger _logger;
+    private readonly PartitionedRateLimiter<HttpContext> _limiter;
+    private readonly int _rejectionStatusCode;
+
+    /// <summary>
+    /// Creates a new <see cref="RateLimitingMiddleware"/>.
+    /// </summary>
+    /// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
+    /// <param name="logger">The <see cref="ILogger"/> used for logging.</param>
+    /// <param name="options">The options for the middleware.</param>
+    public RateLimitingMiddleware(RequestDelegate next, ILogger<RateLimitingMiddleware> logger, IOptions<RateLimiterOptions> options)
+    {
+        _next = next ?? throw new ArgumentNullException(nameof(next));
+
+        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
+
+        _limiter = options.Value.Limiter;
+        _onRejected = options.Value.OnRejected;
+        _rejectionStatusCode = options.Value.DefaultRejectionStatusCode;
+    }
+
+    // TODO - EventSource?
+    /// <summary>
+    /// Invokes the logic of the middleware.
+    /// </summary>
+    /// <param name="context">The <see cref="HttpContext"/>.</param>
+    /// <returns>A <see cref="Task"/> that completes when the request leaves.</returns>
+    public async Task Invoke(HttpContext context)
+    {
+        using var lease = await TryAcquireAsync(context);
+        if (lease.IsAcquired)
+        {
+            await _next(context);
+        }
+        else
+        {
+            RateLimiterLog.RequestRejectedLimitsExceeded(_logger);
+            // OnRejected "wins" over DefaultRejectionStatusCode - we set DefaultRejectionStatusCode first,
+            // then call OnRejected in case it wants to do any further modification of the status code.
+            context.Response.StatusCode = _rejectionStatusCode;
+            await _onRejected(context, lease);
+        }
+    }
+
+    private ValueTask<RateLimitLease> TryAcquireAsync(HttpContext context)
+    {
+        var lease = _limiter.Acquire(context);
+        if (lease.IsAcquired)
+        {
+            return ValueTask.FromResult(lease);
+        }
+
+        return _limiter.WaitAsync(context, cancellationToken: context.RequestAborted);
+    }
+
+    private static partial class RateLimiterLog
+    {
+        [LoggerMessage(1, LogLevel.Debug, "Rate limits exceeded, rejecting this request.", EventName = "RequestRejectedLimitsExceeded")]
+        internal static partial void RequestRejectedLimitsExceeded(ILogger logger);
+    }
+}

+ 11 - 0
src/Middleware/RateLimiting/test/Microsoft.AspNetCore.RateLimiting.Tests.csproj

@@ -0,0 +1,11 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>$(DefaultNetCoreTargetFramework)</TargetFramework>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <Reference Include="Microsoft.AspNetCore.Http" />
+    <Reference Include="Microsoft.AspNetCore.RateLimiting" />
+  </ItemGroup>
+</Project>

+ 53 - 0
src/Middleware/RateLimiting/test/RateLimitingApplicationBuilderExtensionsTests.cs

@@ -0,0 +1,53 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Testing;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+
+public class RateLimitingApplicationBuilderExtensionsTests : LoggedTest
+{
+
+    [Fact]
+    public void UseRateLimiter_ThrowsOnNullAppBuilder()
+    {
+        Assert.Throws<ArgumentNullException>(() => RateLimitingApplicationBuilderExtensions.UseRateLimiter(null));
+    }
+
+    [Fact]
+    public void UseRateLimiter_ThrowsOnNullOptions()
+    {
+        var appBuilder = new ApplicationBuilder(new ServiceCollection().BuildServiceProvider());
+        Assert.Throws<ArgumentNullException>(() => appBuilder.UseRateLimiter(null));
+    }
+
+    [Fact]
+    public void UseRateLimiter_RespectsOptions()
+    {
+        // These are the options that should get used
+        var options = new RateLimiterOptions();
+        options.DefaultRejectionStatusCode = 429;
+        options.Limiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(false));
+
+        // These should not get used
+        var services = new ServiceCollection();
+        services.Configure<RateLimiterOptions>(options =>
+        {
+            options.Limiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(false));
+            options.DefaultRejectionStatusCode = 404;
+        });
+        services.AddLogging();
+        var serviceProvider = services.BuildServiceProvider();
+        var appBuilder = new ApplicationBuilder(serviceProvider);
+
+        // Act
+        appBuilder.UseRateLimiter(options);
+        var app = appBuilder.Build();
+        var context = new DefaultHttpContext();
+        app.Invoke(context);
+        Assert.Equal(429, context.Response.StatusCode);
+    }
+}

+ 122 - 0
src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs

@@ -0,0 +1,122 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Testing;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+
+public class RateLimitingMiddlewareTests : LoggedTest
+{
+    [Fact]
+    public void Ctor_ThrowsExceptionsWhenNullArgs()
+    {
+        var options = CreateOptionsAccessor();
+        options.Value.Limiter = new TestPartitionedRateLimiter<HttpContext>();
+
+        Assert.Throws<ArgumentNullException>(() => new RateLimitingMiddleware(
+            null,
+            new NullLoggerFactory().CreateLogger<RateLimitingMiddleware>(),
+            options));
+
+        Assert.Throws<ArgumentNullException>(() => new RateLimitingMiddleware(c =>
+        {
+            return Task.CompletedTask;
+        },
+        null,
+        options));
+    }
+
+    [Fact]
+    public async Task RequestsCallNextIfAccepted()
+    {
+        var flag = false;
+        var options = CreateOptionsAccessor();
+        options.Value.Limiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(true));
+        var middleware = new RateLimitingMiddleware(c =>
+        {
+            flag = true;
+            return Task.CompletedTask;
+        },
+        new NullLoggerFactory().CreateLogger<RateLimitingMiddleware>(),
+        options);
+
+        await middleware.Invoke(new DefaultHttpContext());
+        Assert.True(flag);
+    }
+
+    [Fact]
+    public async Task RequestRejected_CallsOnRejectedAndGives503()
+    {
+        var onRejectedInvoked = false;
+        var options = CreateOptionsAccessor();
+        options.Value.Limiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(false));
+        options.Value.OnRejected = (httpContext, lease) =>
+        {
+            onRejectedInvoked = true;
+            return Task.CompletedTask;
+        };
+
+        var middleware = new RateLimitingMiddleware(c =>
+        {
+            return Task.CompletedTask;
+        },
+        new NullLoggerFactory().CreateLogger<RateLimitingMiddleware>(),
+        options);
+
+        var context = new DefaultHttpContext();
+        await middleware.Invoke(context).DefaultTimeout();
+        Assert.True(onRejectedInvoked);
+        Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode);
+    }
+
+    [Fact]
+    public async Task RequestRejected_WinsOverDefaultStatusCode()
+    {
+        var onRejectedInvoked = false;
+        var options = CreateOptionsAccessor();
+        options.Value.Limiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(false));
+        options.Value.OnRejected = (httpContext, lease) =>
+        {
+            onRejectedInvoked = true;
+            httpContext.Response.StatusCode = 429;
+            return Task.CompletedTask;
+        };
+
+        var middleware = new RateLimitingMiddleware(c =>
+        {
+            return Task.CompletedTask;
+        },
+        new NullLoggerFactory().CreateLogger<RateLimitingMiddleware>(),
+        options);
+
+        var context = new DefaultHttpContext();
+        await middleware.Invoke(context).DefaultTimeout();
+        Assert.True(onRejectedInvoked);
+        Assert.Equal(StatusCodes.Status429TooManyRequests, context.Response.StatusCode);
+    }
+
+    [Fact]
+    public async Task RequestAborted_ThrowsTaskCanceledException()
+    {
+        var options = CreateOptionsAccessor();
+        options.Value.Limiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(false));
+
+        var middleware = new RateLimitingMiddleware(c =>
+        {
+            return Task.CompletedTask;
+        },
+        new NullLoggerFactory().CreateLogger<RateLimitingMiddleware>(),
+        options);
+
+        var context = new DefaultHttpContext();
+        context.RequestAborted = new CancellationToken(true);
+        await Assert.ThrowsAsync<TaskCanceledException>(() => middleware.Invoke(context)).DefaultTimeout();
+    }
+
+    private IOptions<RateLimiterOptions> CreateOptionsAccessor() => Options.Create(new RateLimiterOptions());
+
+}

+ 23 - 0
src/Middleware/RateLimiting/test/RateLimitingOptionsTests.cs

@@ -0,0 +1,23 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.AspNetCore.Http;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+
+public class RateLimitingOptionsTests
+{
+    [Fact]
+    public void ThrowsOnNullLimiter()
+    {
+        var options = new RateLimiterOptions();
+        Assert.Throws<ArgumentNullException>(() => options.Limiter = null);
+    }
+
+    [Fact]
+    public void ThrowsOnNullOnRejected()
+    {
+        var options = new RateLimiterOptions();
+        Assert.Throws<ArgumentNullException>(() => options.OnRejected = null);
+    }
+}

+ 85 - 0
src/Middleware/RateLimiting/test/TestPartitionedRateLimiter.cs

@@ -0,0 +1,85 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.RateLimiting;
+using System.Threading.Tasks;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+
+internal class TestPartitionedRateLimiter<TResource> : PartitionedRateLimiter<TResource>
+{
+    private List<RateLimiter> limiters = new List<RateLimiter>();
+
+    public TestPartitionedRateLimiter() { }
+
+    public TestPartitionedRateLimiter(RateLimiter limiter)
+    {
+        limiters.Add(limiter);
+    }
+
+    public void AddLimiter(RateLimiter limiter)
+    {
+        limiters.Add(limiter);
+    }
+
+    public override int GetAvailablePermits(TResource resourceID)
+    {
+        throw new NotImplementedException();
+    }
+
+    protected override RateLimitLease AcquireCore(TResource resourceID, int permitCount)
+    {
+        if (permitCount != 1)
+        {
+            throw new ArgumentException("Tests only support 1 permit at a time");
+        }    
+        var leases = new List<RateLimitLease>();
+        foreach (var limiter in limiters)
+        {
+            var lease = limiter.Acquire();
+            if (lease.IsAcquired)
+            {
+                leases.Add(lease);
+            }
+            else
+            {
+                foreach (var unusedLease in leases)
+                {
+                    unusedLease.Dispose();
+                }
+                return new TestRateLimitLease(false, null);
+            }
+        }
+        return new TestRateLimitLease(true, leases);
+    }
+
+    protected override async ValueTask<RateLimitLease> WaitAsyncCore(TResource resourceID, int permitCount, CancellationToken cancellationToken)
+    {
+        if (permitCount != 1)
+        {
+            throw new ArgumentException("Tests only support 1 permit at a time");
+        }
+        var leases = new List<RateLimitLease>();
+        foreach (var limiter in limiters)
+        {
+            leases.Add(await limiter.WaitAsync().ConfigureAwait(false));
+        }
+        foreach (var lease in leases)
+        {
+            if (!lease.IsAcquired)
+            {
+                foreach (var unusedLease in leases)
+                {
+                    unusedLease.Dispose();
+                }
+                return new TestRateLimitLease(false, null);
+            }    
+        }
+        return new TestRateLimitLease(true, leases);
+
+    }
+}

+ 37 - 0
src/Middleware/RateLimiting/test/TestRateLimitLease.cs

@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Threading.RateLimiting;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+
+internal class TestRateLimitLease : RateLimitLease
+{
+    internal List<RateLimitLease> _leases;
+
+    public TestRateLimitLease(bool isAcquired, List<RateLimitLease> leases)
+    {
+        IsAcquired = isAcquired;
+        _leases = leases;
+    }
+
+    public override bool IsAcquired { get; }
+
+    public override IEnumerable<string> MetadataNames => throw new NotImplementedException();
+
+    public override bool TryGetMetadata(string metadataName, out object metadata)
+    {
+        throw new NotImplementedException();
+    }
+
+    protected override void Dispose(bool disposing)
+    {
+        if (_leases != null)
+        {
+            foreach (var lease in _leases)
+            {
+                lease.Dispose();
+            }
+        }
+    }
+}

+ 33 - 0
src/Middleware/RateLimiting/test/TestRateLimiter.cs

@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Threading.RateLimiting;
+
+namespace Microsoft.AspNetCore.RateLimiting;
+internal class TestRateLimiter : RateLimiter
+{
+    private readonly bool _alwaysAccept;
+
+    public TestRateLimiter(bool alwaysAccept)
+    {
+        _alwaysAccept = alwaysAccept;
+    }
+
+    public override TimeSpan? IdleDuration => throw new NotImplementedException();
+
+    public override int GetAvailablePermits()
+    {
+        throw new NotImplementedException();
+    }
+
+    protected override RateLimitLease AcquireCore(int permitCount)
+    {
+        return new TestRateLimitLease(_alwaysAccept, null);
+    }
+
+    protected override ValueTask<RateLimitLease> WaitAsyncCore(int permitCount, CancellationToken cancellationToken)
+    {
+        cancellationToken.ThrowIfCancellationRequested();
+        return new ValueTask<RateLimitLease>(new TestRateLimitLease(_alwaysAccept, null));
+    }
+}