Ver código fonte

Delay socket receive/send until first read/flush (#34458)

* Delay socket receive/send until first read/flush
- Today when the socket connection is accepted or connection, we implicitly start reading and writing from the socket. This can prevent certain scenarios where users want to get access to the raw socket before any operations happen (like in hand off scenarios). This change defers the reads and writes until read or flush is called on the transport's IDuplexPipe.
- Added test to make sure we can read from the socket via IConnectionocket feature.
- Added DuplicateAndClose test
David Fowler 4 anos atrás
pai
commit
e2acbb98b7

+ 0 - 1
src/Servers/Kestrel/Transport.Sockets/src/Client/SocketConnectionFactory.cs

@@ -81,7 +81,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets
                 _outputOptions,
                 _options.WaitForDataBeforeAllocatingBuffer);
 
-            socketConnection.Start();
             return socketConnection;
         }
 

+ 25 - 0
src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.DuplexPipe.cs

@@ -0,0 +1,25 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.IO.Pipelines;
+
+namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
+{
+    internal sealed partial class SocketConnection
+    {
+        // We could implement this on SocketConnection to remove an extra allocation but this is a
+        // bit cleaner
+        private class SocketDuplexPipe : IDuplexPipe
+        {
+            public SocketDuplexPipe(SocketConnection connection)
+            {
+                Input = new SocketPipeReader(connection);
+                Output = new SocketPipeWriter(connection);
+            }
+
+            public PipeReader Input { get; }
+
+            public PipeWriter Output { get; }
+        }
+    }
+}

+ 77 - 0
src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.PipeReader.cs

@@ -0,0 +1,77 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.IO.Pipelines;
+
+namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
+{
+    internal sealed partial class SocketConnection
+    {
+        private class SocketPipeReader : PipeReader
+        {
+            private readonly SocketConnection _socketConnection;
+            private readonly PipeReader _reader;
+
+            public SocketPipeReader(SocketConnection socketConnection)
+            {
+                _socketConnection = socketConnection;
+                _reader = socketConnection.InnerTransport.Input;
+            }
+
+            public override void AdvanceTo(SequencePosition consumed)
+            {
+                _reader.AdvanceTo(consumed);
+            }
+
+            public override void AdvanceTo(SequencePosition consumed, SequencePosition examined)
+            {
+                _reader.AdvanceTo(consumed, examined);
+            }
+
+            public override void CancelPendingRead()
+            {
+                _reader.CancelPendingRead();
+            }
+
+            public override void Complete(Exception? exception = null)
+            {
+                _reader.Complete(exception);
+            }
+
+            public override ValueTask CompleteAsync(Exception? exception = null)
+            {
+                return _reader.CompleteAsync(exception);
+            }
+
+            public override Task CopyToAsync(PipeWriter destination, CancellationToken cancellationToken = default)
+            {
+                _socketConnection.EnsureStarted();
+                return _reader.CopyToAsync(destination, cancellationToken);
+            }
+
+            public override Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default)
+            {
+                _socketConnection.EnsureStarted();
+                return _reader.CopyToAsync(destination, cancellationToken);
+            }
+
+            protected override ValueTask<ReadResult> ReadAtLeastAsyncCore(int minimumSize, CancellationToken cancellationToken)
+            {
+                _socketConnection.EnsureStarted();
+                return _reader.ReadAtLeastAsync(minimumSize, cancellationToken);
+            }
+
+            public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
+            {
+                _socketConnection.EnsureStarted();
+                return _reader.ReadAsync(cancellationToken);
+            }
+
+            public override bool TryRead(out ReadResult result)
+            {
+                _socketConnection.EnsureStarted();
+                return _reader.TryRead(out result);
+            }
+        }
+    }
+}

+ 68 - 0
src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.PipeWriter.cs

@@ -0,0 +1,68 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.IO.Pipelines;
+
+namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
+{
+    internal sealed partial class SocketConnection
+    {
+        private class SocketPipeWriter : PipeWriter
+        {
+            private readonly SocketConnection _socketConnection;
+            private readonly PipeWriter _writer;
+
+            public SocketPipeWriter(SocketConnection socketConnection)
+            {
+                _socketConnection = socketConnection;
+                _writer = socketConnection.InnerTransport.Output;
+            }
+
+            public override bool CanGetUnflushedBytes => _writer.CanGetUnflushedBytes;
+
+            public override long UnflushedBytes => _writer.UnflushedBytes;
+
+            public override void Advance(int bytes)
+            {
+                _writer.Advance(bytes);
+            }
+
+            public override void CancelPendingFlush()
+            {
+                _writer.CancelPendingFlush();
+            }
+
+            public override void Complete(Exception? exception = null)
+            {
+                _writer.Complete(exception);
+            }
+
+            public override ValueTask CompleteAsync(Exception? exception = null)
+            {
+                return _writer.CompleteAsync(exception);
+            }
+
+            public override ValueTask<FlushResult> WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
+            {
+                _socketConnection.EnsureStarted();
+                return _writer.WriteAsync(source, cancellationToken);
+            }
+
+            public override ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken = default)
+            {
+                _socketConnection.EnsureStarted();
+                return _writer.FlushAsync(cancellationToken);
+            }
+
+            public override Memory<byte> GetMemory(int sizeHint = 0)
+            {
+                return _writer.GetMemory(sizeHint);
+            }
+
+            public override Span<byte> GetSpan(int sizeHint = 0)
+            {
+                return _writer.GetSpan(sizeHint);
+            }
+        }
+    }
+}

+ 17 - 12
src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs

@@ -33,6 +33,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
         private readonly TaskCompletionSource _waitForConnectionClosedTcs = new TaskCompletionSource();
         private bool _connectionClosed;
         private readonly bool _waitForData;
+        private int _connectionStarted;
 
         internal SocketConnection(Socket socket,
                                   MemoryPool<byte> memoryPool,
@@ -67,31 +68,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
 
             var pair = DuplexPipe.CreateConnectionPair(inputOptions, outputOptions);
 
-            // Set the transport and connection id
-            Transport = _originalTransport = pair.Transport;
+            _originalTransport = pair.Transport;
             Application = pair.Application;
 
+            Transport = new SocketDuplexPipe(this);
+
             InitiaizeFeatures();
         }
 
+        public IDuplexPipe InnerTransport => _originalTransport;
+
         public PipeWriter Input => Application.Output;
 
         public PipeReader Output => Application.Input;
 
         public override MemoryPool<byte> MemoryPool { get; }
 
-        public void Start()
+        private void EnsureStarted()
         {
-            try
+            if (_connectionStarted == 1 || Interlocked.CompareExchange(ref _connectionStarted, 1, 0) == 1)
             {
-                // Spawn send and receive logic
-                _receivingTask = DoReceive();
-                _sendingTask = DoSend();
-            }
-            catch (Exception ex)
-            {
-                _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(Start)}.");
+                return;
             }
+
+            // Offload these to avoid potentially blocking the first read/write/flush
+            _receivingTask = Task.Run(DoReceive);
+            _sendingTask = Task.Run(DoSend);
         }
 
         public override void Abort(ConnectionAbortedException abortReason)
@@ -106,6 +108,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
         // Only called after connection middleware is complete which means the ConnectionClosed token has fired.
         public override async ValueTask DisposeAsync()
         {
+            // Just in case we haven't started the connection, start it here so we can clean up properly.
+            EnsureStarted();
+
             _originalTransport.Input.Complete();
             _originalTransport.Output.Complete();
 
@@ -125,7 +130,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
             }
             catch (Exception ex)
             {
-                _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(Start)}.");
+                _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.");
             }
             finally
             {

+ 0 - 2
src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs

@@ -136,8 +136,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets
                         setting.OutputOptions,
                         waitForData: _options.WaitForDataBeforeAllocatingBuffer);
 
-                    connection.Start();
-
                     _settingsIndex = (_settingsIndex + 1) % _settingsCount;
 
                     return connection;

+ 160 - 0
src/Servers/Kestrel/test/Sockets.FunctionalTests/SocketTranspotTests.cs

@@ -1,15 +1,22 @@
+using System.Buffers;
+using System.Diagnostics;
 using System.Net;
 using System.Net.Http;
 using System.Net.Sockets;
+using System.Text;
 using System.Threading.Tasks;
 using Microsoft.AspNetCore.Builder;
 using Microsoft.AspNetCore.Connections.Features;
 using Microsoft.AspNetCore.Hosting;
+using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
 using Microsoft.AspNetCore.Server.Kestrel.FunctionalTests;
 using Microsoft.AspNetCore.Testing;
 using Microsoft.Extensions.Hosting;
 using Xunit;
 
+using KestrelHttpVersion = Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpVersion;
+using KestrelHttpMethod = Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpMethod;
+
 namespace Sockets.FunctionalTests
 {
     public class SocketTranspotTests : LoggedTestBase
@@ -50,5 +57,158 @@ namespace Sockets.FunctionalTests
 
             await host.StopAsync();
         }
+
+        [Fact]
+        public async Task CanReadAndWriteFromSocketFeatureInConnectionMiddleware()
+        {
+            var builder = TransportSelector.GetHostBuilder()
+                .ConfigureWebHost(webHostBuilder =>
+                {
+                    webHostBuilder
+                        .UseKestrel(options =>
+                        {
+                            options.ListenAnyIP(0, lo =>
+                            {
+                                lo.Use(next =>
+                                {
+                                    return async connection =>
+                                    {
+                                        var socket = connection.Features.Get<IConnectionSocketFeature>().Socket;
+                                        Assert.NotNull(socket);
+
+                                        var buffer = new byte[4096];
+
+                                        var read = await socket.ReceiveAsync(buffer, SocketFlags.None);
+
+                                        static void ParseHttp(ReadOnlySequence<byte> data)
+                                        {
+                                            var parser = new HttpParser<ParserHandler>();
+                                            var handler = new ParserHandler();
+
+                                            var reader = new SequenceReader<byte>(data);
+
+                                            // Assume we can parse the HTTP request in a single buffer
+                                            Assert.True(parser.ParseRequestLine(handler, ref reader));
+                                            Assert.True(parser.ParseHeaders(handler, ref reader));
+
+                                            Assert.Equal(KestrelHttpMethod.Get, handler.HttpMethod);
+                                            Assert.Equal(KestrelHttpVersion.Http11, handler.HttpVersion);
+                                        }
+
+                                        ParseHttp(new ReadOnlySequence<byte>(buffer[0..read]));
+
+                                        await socket.SendAsync(Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n"), SocketFlags.None);
+                                    };
+                                });
+                            });
+                        })
+                        .Configure(app => { });
+                })
+                .ConfigureServices(AddTestLogging);
+
+            using var host = builder.Build();
+            using var client = new HttpClient();
+
+            await host.StartAsync();
+
+            var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/");
+            response.EnsureSuccessStatusCode();
+
+            await host.StopAsync();
+        }
+
+        [ConditionalFact]
+        [OSSkipCondition(OperatingSystems.Linux)]
+        [OSSkipCondition(OperatingSystems.MacOSX)]
+        public async Task CanDuplicateAndCloseSocketFeatureInConnectionMiddleware()
+        {
+            var builder = TransportSelector.GetHostBuilder()
+                .ConfigureWebHost(webHostBuilder =>
+                {
+                    webHostBuilder
+                        .UseKestrel(options =>
+                        {
+                            options.ListenAnyIP(0, lo =>
+                            {
+                                lo.Use(next =>
+                                {
+                                    return async connection =>
+                                    {
+                                        var originalSocket = connection.Features.Get<IConnectionSocketFeature>().Socket;
+                                        Assert.NotNull(originalSocket);
+
+                                        var si = originalSocket.DuplicateAndClose(Process.GetCurrentProcess().Id);
+
+                                        using var socket = new Socket(si);
+                                        var buffer = new byte[4096];
+
+                                        var read = await socket.ReceiveAsync(buffer, SocketFlags.None);
+
+                                        static void ParseHttp(ReadOnlySequence<byte> data)
+                                        {
+                                            var parser = new HttpParser<ParserHandler>();
+                                            var handler = new ParserHandler();
+
+                                            var reader = new SequenceReader<byte>(data);
+
+                                            // Assume we can parse the HTTP request in a single buffer
+                                            Assert.True(parser.ParseRequestLine(handler, ref reader));
+                                            Assert.True(parser.ParseHeaders(handler, ref reader));
+
+                                            Assert.Equal(KestrelHttpMethod.Get, handler.HttpMethod);
+                                            Assert.Equal(KestrelHttpVersion.Http11, handler.HttpVersion);
+                                        }
+
+                                        ParseHttp(new ReadOnlySequence<byte>(buffer[0..read]));
+
+                                        await socket.SendAsync(Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n"), SocketFlags.None);
+                                    };
+                                });
+                            });
+                        })
+                        .Configure(app => { });
+                })
+                .ConfigureServices(AddTestLogging);
+
+            using var host = builder.Build();
+            using var client = new HttpClient();
+
+            await host.StartAsync();
+
+            var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/");
+            response.EnsureSuccessStatusCode();
+
+            await host.StopAsync();
+        }
+
+        private class ParserHandler : IHttpRequestLineHandler, IHttpHeadersHandler
+        {
+            public KestrelHttpVersion HttpVersion { get; set; }
+            public KestrelHttpMethod HttpMethod { get; set; }
+            public Dictionary<string, string> Headers = new();
+
+            public void OnHeader(ReadOnlySpan<byte> name, ReadOnlySpan<byte> value)
+            {
+                Headers[Encoding.ASCII.GetString(name)] = Encoding.ASCII.GetString(value);
+            }
+
+            public void OnHeadersComplete(bool endStream)
+            {
+            }
+
+            public void OnStartLine(HttpVersionAndMethod versionAndMethod, TargetOffsetPathLength targetPath, Span<byte> startLine)
+            {
+                HttpMethod = versionAndMethod.Method;
+                HttpVersion = versionAndMethod.Version;
+            }
+
+            public void OnStaticIndexedHeader(int index)
+            {
+            }
+
+            public void OnStaticIndexedHeader(int index, ReadOnlySpan<byte> value)
+            {
+            }
+        }
     }
 }