Просмотр исходного кода

Optimize CancellationToken handling (#51660)

Pent Ploompuu 2 лет назад
Родитель
Сommit
89c2ef34b3

+ 1 - 1
src/Hosting/Hosting/src/GenericHost/GenericWebHostBuilder.cs

@@ -324,7 +324,7 @@ internal sealed class GenericWebHostBuilder : WebHostBuilderBase, ISupportsStart
             {
                 services.Configure<GenericWebHostServiceOptions>(options =>
                 {
-                    options.ConfigureApplication = app => configure(app);
+                    options.ConfigureApplication = configure;
                 });
             }
         });

+ 3 - 15
src/Hosting/Hosting/src/Internal/ApplicationLifetime.cs

@@ -58,7 +58,7 @@ internal sealed class ApplicationLifetime : IApplicationLifetime, Extensions.Hos
         {
             try
             {
-                ExecuteHandlers(_stoppingSource);
+                _stoppingSource.Cancel();
             }
             catch (Exception ex)
             {
@@ -76,7 +76,7 @@ internal sealed class ApplicationLifetime : IApplicationLifetime, Extensions.Hos
     {
         try
         {
-            ExecuteHandlers(_startedSource);
+            _startedSource.Cancel();
         }
         catch (Exception ex)
         {
@@ -93,7 +93,7 @@ internal sealed class ApplicationLifetime : IApplicationLifetime, Extensions.Hos
     {
         try
         {
-            ExecuteHandlers(_stoppedSource);
+            _stoppedSource.Cancel();
         }
         catch (Exception ex)
         {
@@ -102,16 +102,4 @@ internal sealed class ApplicationLifetime : IApplicationLifetime, Extensions.Hos
                                      ex);
         }
     }
-
-    private static void ExecuteHandlers(CancellationTokenSource cancel)
-    {
-        // Noop if this is already cancelled
-        if (cancel.IsCancellationRequested)
-        {
-            return;
-        }
-
-        // Run the cancellation token callbacks
-        cancel.Cancel(throwOnFirstException: false);
-    }
 }

+ 9 - 23
src/Hosting/Hosting/src/Internal/HostedServiceExecutor.cs

@@ -2,32 +2,27 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using Microsoft.Extensions.Hosting;
-using Microsoft.Extensions.Logging;
 
 namespace Microsoft.AspNetCore.Hosting;
 
 internal sealed class HostedServiceExecutor
 {
     private readonly IEnumerable<IHostedService> _services;
-    private readonly ILogger<HostedServiceExecutor> _logger;
 
-    public HostedServiceExecutor(ILogger<HostedServiceExecutor> logger, IEnumerable<IHostedService> services)
+    public HostedServiceExecutor(IEnumerable<IHostedService> services)
     {
-        _logger = logger;
         _services = services;
     }
 
-    public Task StartAsync(CancellationToken token)
+    public async Task StartAsync(CancellationToken token)
     {
-        return ExecuteAsync(service => service.StartAsync(token));
-    }
-
-    public Task StopAsync(CancellationToken token)
-    {
-        return ExecuteAsync(service => service.StopAsync(token), throwOnFirstFailure: false);
+        foreach (var service in _services)
+        {
+            await service.StartAsync(token);
+        }
     }
 
-    private async Task ExecuteAsync(Func<IHostedService, Task> callback, bool throwOnFirstFailure = true)
+    public async Task StopAsync(CancellationToken token)
     {
         List<Exception>? exceptions = null;
 
@@ -35,20 +30,11 @@ internal sealed class HostedServiceExecutor
         {
             try
             {
-                await callback(service);
+                await service.StopAsync(token);
             }
             catch (Exception ex)
             {
-                if (throwOnFirstFailure)
-                {
-                    throw;
-                }
-
-                if (exceptions == null)
-                {
-                    exceptions = new List<Exception>();
-                }
-
+                exceptions ??= [];
                 exceptions.Add(ex);
             }
         }

+ 6 - 13
src/Hosting/Hosting/src/Internal/WebHost.cs

@@ -286,16 +286,9 @@ internal sealed partial class WebHost : IWebHost, IAsyncDisposable
 
         Log.Shutdown(_logger);
 
-        using var timeoutCTS = new CancellationTokenSource(Options.ShutdownTimeout);
-        var timeoutToken = timeoutCTS.Token;
-        if (!cancellationToken.CanBeCanceled)
-        {
-            cancellationToken = timeoutToken;
-        }
-        else
-        {
-            cancellationToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutToken).Token;
-        }
+        using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+        cts.CancelAfter(Options.ShutdownTimeout);
+        cancellationToken = cts.Token;
 
         // Fire IApplicationLifetime.Stopping
         _applicationLifetime?.StopApplication();
@@ -340,17 +333,17 @@ internal sealed partial class WebHost : IWebHost, IAsyncDisposable
         await DisposeServiceProviderAsync(_hostingServiceProvider).ConfigureAwait(false);
     }
 
-    private static async ValueTask DisposeServiceProviderAsync(IServiceProvider? serviceProvider)
+    private static ValueTask DisposeServiceProviderAsync(IServiceProvider? serviceProvider)
     {
         switch (serviceProvider)
         {
             case IAsyncDisposable asyncDisposable:
-                await asyncDisposable.DisposeAsync().ConfigureAwait(false);
-                break;
+                return asyncDisposable.DisposeAsync();
             case IDisposable disposable:
                 disposable.Dispose();
                 break;
         }
+        return default;
     }
 
     private static partial class Log

+ 4 - 12
src/Hosting/Hosting/src/Startup/ConventionBasedStartup.cs

@@ -23,13 +23,9 @@ internal sealed class ConventionBasedStartup : IStartup
         {
             _methods.ConfigureDelegate(app);
         }
-        catch (Exception ex)
+        catch (TargetInvocationException ex)
         {
-            if (ex is TargetInvocationException)
-            {
-                ExceptionDispatchInfo.Capture(ex.InnerException!).Throw();
-            }
-
+            ExceptionDispatchInfo.Capture(ex.InnerException!).Throw();
             throw;
         }
     }
@@ -40,13 +36,9 @@ internal sealed class ConventionBasedStartup : IStartup
         {
             return _methods.ConfigureServicesDelegate(services);
         }
-        catch (Exception ex)
+        catch (TargetInvocationException ex)
         {
-            if (ex is TargetInvocationException)
-            {
-                ExceptionDispatchInfo.Capture(ex.InnerException!).Throw();
-            }
-
+            ExceptionDispatchInfo.Capture(ex.InnerException!).Throw();
             throw;
         }
     }

+ 1 - 1
src/Hosting/Hosting/src/WebHostBuilderExtensions.cs

@@ -46,7 +46,7 @@ public static class WebHostBuilderExtensions
         {
             services.AddSingleton<IStartup>(sp =>
             {
-                return new DelegateStartup(sp.GetRequiredService<IServiceProviderFactory<IServiceCollection>>(), (app => configureApp(app)));
+                return new DelegateStartup(sp.GetRequiredService<IServiceProviderFactory<IServiceCollection>>(), configureApp);
             });
         });
     }

+ 2 - 8
src/Hosting/Hosting/src/WebHostExtensions.cs

@@ -159,14 +159,8 @@ public static class WebHostExtensions
         },
         applicationLifetime);
 
-        var waitForStop = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
-        applicationLifetime.ApplicationStopping.Register(obj =>
-        {
-            var tcs = (TaskCompletionSource)obj!;
-            tcs.TrySetResult();
-        }, waitForStop);
-
-        await waitForStop.Task;
+        await Task.Delay(Timeout.Infinite, applicationLifetime.ApplicationStopping)
+            .ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing | ConfigureAwaitOptions.ContinueOnCapturedContext);
 
         // WebHost will use its default ShutdownTimeout if none is specified.
 #pragma warning disable CA2016 // Forward the 'CancellationToken' parameter to methods. StopAsync should not be canceled by the token to RunAsync.

+ 2 - 3
src/Http/Http/src/Timeouts/RequestTimeoutsMiddleware.cs

@@ -107,10 +107,9 @@ internal sealed class RequestTimeoutsMiddleware
                 await _next(context);
             }
             catch (OperationCanceledException operationCanceledException)
+            when (linkedCts.IsCancellationRequested && !originalToken.IsCancellationRequested)
             {
-                if (context.Response.HasStarted ||
-                    !linkedCts.Token.IsCancellationRequested ||
-                    originalToken.IsCancellationRequested)
+                if (context.Response.HasStarted)
                 {
                     // We can't produce a response, or it wasn't our timeout that caused this.
                     throw;

+ 13 - 25
src/Middleware/Session/src/DistributedSession.cs

@@ -218,9 +218,9 @@ public class DistributedSession : ISession
         // This will throw if called directly and a failure occurs. The user is expected to handle the failures.
         if (!_loaded)
         {
-            using (var timeout = new CancellationTokenSource(_ioTimeout))
+            using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
             {
-                var cts = CancellationTokenSource.CreateLinkedTokenSource(timeout.Token, cancellationToken);
+                cts.CancelAfter(_ioTimeout);
                 try
                 {
                     cts.Token.ThrowIfCancellationRequested();
@@ -234,14 +234,10 @@ public class DistributedSession : ISession
                         _logger.AccessingExpiredSession(_sessionKey);
                     }
                 }
-                catch (OperationCanceledException oex)
+                catch (OperationCanceledException oex) when (!cancellationToken.IsCancellationRequested && cts.IsCancellationRequested)
                 {
-                    if (timeout.Token.IsCancellationRequested)
-                    {
-                        _logger.SessionLoadingTimeout();
-                        throw new OperationCanceledException("Timed out loading the session.", oex, timeout.Token);
-                    }
-                    throw;
+                    _logger.SessionLoadingTimeout();
+                    throw new OperationCanceledException("Timed out loading the session.", oex, cts.Token);
                 }
             }
             _isAvailable = true;
@@ -252,9 +248,9 @@ public class DistributedSession : ISession
     /// <inheritdoc />
     public async Task CommitAsync(CancellationToken cancellationToken = default)
     {
-        using (var timeout = new CancellationTokenSource(_ioTimeout))
+        using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
         {
-            var cts = CancellationTokenSource.CreateLinkedTokenSource(timeout.Token, cancellationToken);
+            cts.CancelAfter(_ioTimeout);
             if (_isModified)
             {
                 if (_logger.IsEnabled(LogLevel.Information))
@@ -293,14 +289,10 @@ public class DistributedSession : ISession
                     _isModified = false;
                     _logger.SessionStored(_sessionKey, Id, _store.Count);
                 }
-                catch (OperationCanceledException oex)
+                catch (OperationCanceledException oex) when (!cancellationToken.IsCancellationRequested && cts.IsCancellationRequested)
                 {
-                    if (timeout.Token.IsCancellationRequested)
-                    {
-                        _logger.SessionCommitTimeout();
-                        throw new OperationCanceledException("Timed out committing the session.", oex, timeout.Token);
-                    }
-                    throw;
+                    _logger.SessionCommitTimeout();
+                    throw new OperationCanceledException("Timed out committing the session.", oex, cts.Token);
                 }
             }
             else
@@ -309,14 +301,10 @@ public class DistributedSession : ISession
                 {
                     await _cache.RefreshAsync(_sessionKey, cts.Token);
                 }
-                catch (OperationCanceledException oex)
+                catch (OperationCanceledException oex) when (!cancellationToken.IsCancellationRequested && cts.IsCancellationRequested)
                 {
-                    if (timeout.Token.IsCancellationRequested)
-                    {
-                        _logger.SessionRefreshTimeout();
-                        throw new OperationCanceledException("Timed out refreshing the session.", oex, timeout.Token);
-                    }
-                    throw;
+                    _logger.SessionRefreshTimeout();
+                    throw new OperationCanceledException("Timed out refreshing the session.", oex, cts.Token);
                 }
             }
         }

+ 4 - 4
src/Middleware/Spa/SpaProxy/src/SpaProxyLaunchManager.cs

@@ -74,8 +74,8 @@ internal sealed class SpaProxyLaunchManager : IDisposable
     {
         var httpClient = CreateHttpClient();
 
-        using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(10));
-        using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(timeout.Token, cancellationToken);
+        using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+        cancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(10));
         try
         {
             var response = await httpClient.GetAsync(_options.ServerUrl, cancellationTokenSource.Token);
@@ -105,8 +105,8 @@ internal sealed class SpaProxyLaunchManager : IDisposable
 
     private async Task<bool> ProbeSpaDevelopmentServerUrl(HttpClient httpClient, CancellationToken cancellationToken)
     {
-        using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(10));
-        using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(timeout.Token, cancellationToken);
+        using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+        cancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(10));
         try
         {
             var response = await httpClient.GetAsync(_options.ServerUrl, cancellationTokenSource.Token);

+ 2 - 2
src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs

@@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
 internal sealed class AddressBinder
 {
     // note this doesn't copy the ListenOptions[], only call this with an array that isn't mutated elsewhere
-    public static async Task BindAsync(ListenOptions[] listenOptions, AddressBindContext context, Func<ListenOptions, ListenOptions> useHttps, CancellationToken cancellationToken)
+    public static Task BindAsync(ListenOptions[] listenOptions, AddressBindContext context, Func<ListenOptions, ListenOptions> useHttps, CancellationToken cancellationToken)
     {
         var strategy = CreateStrategy(
             listenOptions,
@@ -29,7 +29,7 @@ internal sealed class AddressBinder
         context.ServerOptions.OptionsInUse.Clear();
         context.Addresses.Clear();
 
-        await strategy.BindAsync(context, cancellationToken).ConfigureAwait(false);
+        return strategy.BindAsync(context, cancellationToken);
     }
 
     private static IStrategy CreateStrategy(ListenOptions[] listenOptions, string[] addresses, bool preferAddresses, Func<ListenOptions, ListenOptions> useHttps)

+ 2 - 5
src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs

@@ -699,12 +699,9 @@ internal partial class Http1Connection : HttpProtocol, IRequestProcessor, IHttpO
         {
             isConsumed = ParseRequest(ref reader);
         }
-        catch (InvalidOperationException)
+        catch (InvalidOperationException) when (_requestProcessingStatus == RequestProcessingStatus.ParsingHeaders)
         {
-            if (_requestProcessingStatus == RequestProcessingStatus.ParsingHeaders)
-            {
-                KestrelBadHttpRequestException.Throw(RequestRejectionReason.MalformedRequestInvalidHeaders);
-            }
+            KestrelBadHttpRequestException.Throw(RequestRejectionReason.MalformedRequestInvalidHeaders);
             throw;
         }
 #pragma warning disable CS0618 // Type or member is obsolete

+ 3 - 3
src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs

@@ -354,8 +354,8 @@ internal sealed class KestrelServerImpl : IServer
                 }
 
                 // 5 is the default value for WebHost's "shutdownTimeoutSeconds", so use that.
-                using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
-                using var combinedCts = CancellationTokenSource.CreateLinkedTokenSource(_stopCts.Token, timeoutCts.Token);
+                using var cts = CancellationTokenSource.CreateLinkedTokenSource(_stopCts.Token);
+                cts.CancelAfter(TimeSpan.FromSeconds(5));
 
                 // TODO: It would be nice to start binding to new endpoints immediately and reconfigured endpoints as soon
                 // as the unbinding finished for the given endpoint rather than wait for all transports to unbind first.
@@ -364,7 +364,7 @@ internal sealed class KestrelServerImpl : IServer
                 {
                     configsToStop.Add(lo.EndpointConfig!);
                 }
-                await _transportManager.StopEndpointsAsync(configsToStop, combinedCts.Token).ConfigureAwait(false);
+                await _transportManager.StopEndpointsAsync(configsToStop, cts.Token).ConfigureAwait(false);
 
                 foreach (var listenOption in endpointsToStop)
                 {

+ 3 - 8
src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs

@@ -259,15 +259,10 @@ internal sealed class HttpsConnectionMiddleware
                 store.Open(OpenFlags.ReadOnly);
                 return store;
             }
-            catch (Exception exception)
+            catch (Exception exception) when (exception is CryptographicException || exception is SecurityException)
             {
-                if (exception is CryptographicException || exception is SecurityException)
-                {
-                    _logger.FailedToOpenStore(storeLocation, exception);
-                    return null;
-                }
-
-                throw;
+                _logger.FailedToOpenStore(storeLocation, exception);
+                return null;
             }
         }
 

+ 2 - 0
src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionContext.cs

@@ -182,11 +182,13 @@ internal partial class QuicConnectionContext : TransportMultiplexedConnection
                 _abortReason?.Throw();
             }
         }
+#if DEBUG
         catch (Exception ex)
         {
             Debug.Fail($"Unexpected exception in {nameof(QuicConnectionContext)}.{nameof(AcceptAsync)}: {ex}");
             throw;
         }
+#endif
 
         // Return null for graceful closure or cancellation.
         return null;