Browse Source

Kestrel Content-Length handling changes (#43103)

* Forbid Content-Length on 1xx, 204, or any CONNECT 2xx responses.

* Reject non-zero Content-Length for 205 responses.
Aditya Mandaleeka 3 years ago
parent
commit
72ee5732f0

+ 3 - 0
src/Servers/Kestrel/Core/src/CoreStrings.resx

@@ -710,4 +710,7 @@ For more information on configuring HTTPS see https://go.microsoft.com/fwlink/?l
   <data name="DynamicPortOnMultipleTransportsNotSupported" xml:space="preserve">
     <value>Dynamic port binding is not supported when binding multiple transports. HTTP/3 not enabled. A port must be specified to support TCP based HTTP/1.1 and HTTP/2, and QUIC based HTTP/3 with the same endpoint.</value>
   </data>
+  <data name="NonzeroContentLengthNotAllowedOn205" xml:space="preserve">
+    <value>Responses with status code 205 cannot have a non-zero Content-Length value.</value>
+  </data>
 </root>

+ 51 - 7
src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs

@@ -1147,7 +1147,21 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
 
         if (!_canWriteResponseBody && hasTransferEncoding)
         {
-            RejectNonBodyTransferEncodingResponse(appCompleted);
+            RejectInvalidHeaderForNonBodyResponse(appCompleted, HeaderNames.TransferEncoding);
+        }
+        else if (responseHeaders.ContentLength.HasValue)
+        {
+            if (!CanIncludeResponseContentLengthHeader())
+            {
+                RejectInvalidHeaderForNonBodyResponse(appCompleted, HeaderNames.ContentLength);
+            }
+            else if (StatusCode == StatusCodes.Status205ResetContent && responseHeaders.ContentLength.Value != 0)
+            {
+                // It is valid for a 205 response to have a Content-Length but it must be 0
+                // since 205 implies that no additional content will be provided.
+                // https://httpwg.org/specs/rfc7231.html#rfc.section.6.3.6
+                RejectNonzeroContentLengthOn205Response(appCompleted);
+            }
         }
         else if (StatusCode == StatusCodes.Status101SwitchingProtocols)
         {
@@ -1157,7 +1171,6 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
         {
             if ((appCompleted || !_canWriteResponseBody) && !_hasAdvanced) // Avoid setting contentLength of 0 if we wrote data before calling CreateResponseHeaders
             {
-                // Don't set the Content-Length header automatically for HEAD requests, 204 responses, or 304 responses.
                 if (CanAutoSetContentLengthZeroResponseHeader())
                 {
                     // Since the app has completed writing or cannot write to the response, we can safely set the Content-Length to 0.
@@ -1219,6 +1232,28 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
         return responseHeaders;
     }
 
+    private bool CanIncludeResponseContentLengthHeader()
+    {
+        // Section 4.3.6 of RFC7231
+        if (Is1xxCode(StatusCode) || StatusCode == StatusCodes.Status204NoContent)
+        {
+            // A server MUST NOT send a Content-Length header field in any response
+            // with a status code of 1xx (Informational) or 204 (No Content).
+            return false;
+        }
+        else if (Method == HttpMethod.Connect && Is2xxCode(StatusCode))
+        {
+            // A server MUST NOT send a Content-Length header field in any 2xx
+            // (Successful) response to a CONNECT request.
+            return false;
+        }
+
+        return true;
+
+        static bool Is1xxCode(int code) => code >= StatusCodes.Status100Continue && code < StatusCodes.Status200OK;
+        static bool Is2xxCode(int code) => code >= StatusCodes.Status200OK && code < StatusCodes.Status300MultipleChoices;
+    }
+
     private bool CanWriteResponseBody()
     {
         // List of status codes taken from Microsoft.Net.Http.Server.Response
@@ -1230,9 +1265,12 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
 
     private bool CanAutoSetContentLengthZeroResponseHeader()
     {
-        return Method != HttpMethod.Head &&
-               StatusCode != StatusCodes.Status204NoContent &&
-               StatusCode != StatusCodes.Status304NotModified;
+        return CanIncludeResponseContentLengthHeader() &&
+            // Responses to HEAD may omit Content-Length (Section 4.3.6 of RFC7231).
+            Method != HttpMethod.Head &&
+            // 304s should only include specific fields, of which Content-Length is
+            // not one (Section 4.1 of RFC7232).
+            StatusCode != StatusCodes.Status304NotModified;
     }
 
     private static void ThrowResponseAlreadyStartedException(string value)
@@ -1240,9 +1278,15 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
         throw new InvalidOperationException(CoreStrings.FormatParameterReadOnlyAfterResponseStarted(value));
     }
 
-    private void RejectNonBodyTransferEncodingResponse(bool appCompleted)
+    private void RejectInvalidHeaderForNonBodyResponse(bool appCompleted, string headerName)
+        => RejectInvalidResponse(appCompleted, CoreStrings.FormatHeaderNotAllowedOnResponse(headerName, StatusCode));
+
+    private void RejectNonzeroContentLengthOn205Response(bool appCompleted)
+        => RejectInvalidResponse(appCompleted, CoreStrings.NonzeroContentLengthNotAllowedOn205);
+
+    private void RejectInvalidResponse(bool appCompleted, string message)
     {
-        var ex = new InvalidOperationException(CoreStrings.FormatHeaderNotAllowedOnResponse("Transfer-Encoding", StatusCode));
+        var ex = new InvalidOperationException(message);
         if (!appCompleted)
         {
             // Back out of header creation surface exception in user code

+ 2 - 2
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs

@@ -2671,7 +2671,7 @@ public class Http2ConnectionTests : Http2TestBase
         await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, headers);
 
         await ExpectAsync(Http2FrameType.HEADERS,
-            withLength: 36,
+            withLength: 32,
             withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
             withStreamId: 1);
 
@@ -4574,7 +4574,7 @@ public class Http2ConnectionTests : Http2TestBase
         await SendEmptyContinuationFrameAsync(1, Http2ContinuationFrameFlags.END_HEADERS);
 
         await ExpectAsync(Http2FrameType.HEADERS,
-            withLength: 36,
+            withLength: 32,
             withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
             withStreamId: 1);
 

+ 2 - 3
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs

@@ -282,7 +282,7 @@ public class Http2StreamTests : Http2TestBase
         await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, headers);
 
         var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
-            withLength: 52,
+            withLength: 48,
             withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
             withStreamId: 1);
 
@@ -290,11 +290,10 @@ public class Http2StreamTests : Http2TestBase
 
         _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
 
-        Assert.Equal(4, _decodedHeaders.Count);
+        Assert.Equal(3, _decodedHeaders.Count);
         Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
         Assert.Equal("200", _decodedHeaders[InternalHeaderNames.Status]);
         Assert.Equal("CONNECT", _decodedHeaders["Method"]);
-        Assert.Equal("0", _decodedHeaders["content-length"]);
     }
 
     [Fact]

+ 2 - 3
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2WebSocketTests.cs

@@ -69,7 +69,7 @@ public class Http2WebSocketTests : Http2TestBase
         await SendDataAsync(1, Array.Empty<byte>(), endStream: true);
 
         var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
-            withLength: 36,
+            withLength: 32,
             withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
             withStreamId: 1);
 
@@ -77,10 +77,9 @@ public class Http2WebSocketTests : Http2TestBase
 
         _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
 
-        Assert.Equal(3, _decodedHeaders.Count);
+        Assert.Equal(2, _decodedHeaders.Count);
         Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
         Assert.Equal("200", _decodedHeaders[InternalHeaderNames.Status]);
-        Assert.Equal("0", _decodedHeaders["content-length"]);
     }
 
     [Fact]

+ 1 - 2
src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs

@@ -156,11 +156,10 @@ public class Http3StreamTests : Http3TestBase
 
         var responseHeaders = await requestStream.ExpectHeadersAsync();
 
-        Assert.Equal(4, responseHeaders.Count);
+        Assert.Equal(3, responseHeaders.Count);
         Assert.Contains("date", responseHeaders.Keys, StringComparer.OrdinalIgnoreCase);
         Assert.Equal("200", responseHeaders[InternalHeaderNames.Status]);
         Assert.Equal("CONNECT", responseHeaders["Method"]);
-        Assert.Equal("0", responseHeaders["content-length"]);
     }
 
     [Fact]

+ 4 - 12
src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTargetProcessingTests.cs

@@ -90,11 +90,10 @@ public class RequestTargetProcessingTests : LoggedTest
     }
 
     [Theory]
-    [InlineData((int)HttpMethod.Options, "*")]
-    [InlineData((int)HttpMethod.Connect, "host")]
-    public async Task NonPathRequestTargetSetInRawTarget(int intMethod, string requestTarget)
+    [InlineData(HttpMethod.Options, "*")]
+    [InlineData(HttpMethod.Connect, "host")]
+    public async Task NonPathRequestTargetSetInRawTarget(HttpMethod method, string requestTarget)
     {
-        var method = (HttpMethod)intMethod;
         var testContext = new TestServiceContext(LoggerFactory);
 
         await using (var server = new TestServer(async context =>
@@ -104,8 +103,7 @@ public class RequestTargetProcessingTests : LoggedTest
             Assert.Empty(context.Request.PathBase.Value);
             Assert.Empty(context.Request.QueryString.Value);
 
-            context.Response.Headers["Content-Length"] = new[] { "11" };
-            await context.Response.WriteAsync("Hello World");
+            await context.Response.CompleteAsync();
         }, testContext))
         {
             using (var connection = server.CreateConnection())
@@ -119,12 +117,6 @@ public class RequestTargetProcessingTests : LoggedTest
                     $"Host: {host}",
                     "",
                     "");
-                await connection.Receive(
-                    "HTTP/1.1 200 OK",
-                    "Content-Length: 11",
-                    $"Date: {testContext.DateHeaderValue}",
-                    "",
-                    "Hello World");
             }
         }
     }

+ 124 - 0
src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs

@@ -591,6 +591,130 @@ public class ResponseTests : TestApplicationErrorLoggerLoggedTest
         }
     }
 
+    public static IEnumerable<object[]> Get1xxAnd204MethodCombinations()
+    {
+        // Status codes to test
+        var statusCodes = new int[] {
+                StatusCodes.Status100Continue,
+                StatusCodes.Status101SwitchingProtocols,
+                StatusCodes.Status102Processing,
+                StatusCodes.Status204NoContent,
+            };
+
+        // HTTP methods to test
+        var methods = new HttpMethod[] {
+                HttpMethod.Connect,
+                HttpMethod.Delete,
+                HttpMethod.Get,
+                HttpMethod.Head,
+                HttpMethod.Options,
+                HttpMethod.Patch,
+                HttpMethod.Post,
+                HttpMethod.Put,
+                HttpMethod.Trace
+            };
+
+        foreach (var statusCode in statusCodes)
+        {
+            foreach (var method in methods)
+            {
+                yield return new object[] { statusCode, method };
+            }
+        }
+    }
+
+    [Theory]
+    [MemberData(nameof(Get1xxAnd204MethodCombinations))]
+    public async Task AttemptingToWriteContentLengthFailsFor1xxAnd204Responses(int statusCode, HttpMethod method)
+        => await AttemptingToWriteContentLengthFails(statusCode, method).ConfigureAwait(true);
+
+    [Theory]
+    [InlineData(StatusCodes.Status200OK)]
+    [InlineData(StatusCodes.Status201Created)]
+    [InlineData(StatusCodes.Status202Accepted)]
+    [InlineData(StatusCodes.Status203NonAuthoritative)]
+    [InlineData(StatusCodes.Status204NoContent)]
+    [InlineData(StatusCodes.Status205ResetContent)]
+    [InlineData(StatusCodes.Status206PartialContent)]
+    [InlineData(StatusCodes.Status207MultiStatus)]
+    [InlineData(StatusCodes.Status208AlreadyReported)]
+    [InlineData(StatusCodes.Status226IMUsed)]
+    public async Task AttemptingToWriteContentLengthFailsFor2xxResponsesOnConnect(int statusCode)
+        => await AttemptingToWriteContentLengthFails(statusCode, HttpMethod.Connect).ConfigureAwait(true);
+
+    private async Task AttemptingToWriteContentLengthFails(int statusCode, HttpMethod method)
+    {
+        var responseWriteTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+        await using (var server = new TestServer(async httpContext =>
+        {
+            httpContext.Response.StatusCode = statusCode;
+            httpContext.Response.Headers.ContentLength = 0;
+
+            try
+            {
+                await httpContext.Response.StartAsync();
+            }
+            catch (Exception ex)
+            {
+                responseWriteTcs.TrySetException(ex);
+                throw;
+            }
+
+            responseWriteTcs.TrySetResult();
+        }, new TestServiceContext(LoggerFactory)))
+        {
+            using (var connection = server.CreateConnection())
+            {
+                await connection.Send(
+                    $"{HttpUtilities.MethodToString(method)} / HTTP/1.1",
+                    "Host:",
+                    "",
+                    "");
+
+                var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => responseWriteTcs.Task).DefaultTimeout();
+                Assert.Equal(CoreStrings.FormatHeaderNotAllowedOnResponse("Content-Length", statusCode), ex.Message);
+            }
+        }
+    }
+
+    [Fact]
+    public async Task AttemptingToWriteNonzeroContentLengthFailsFor205Response()
+    {
+        var responseWriteTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+        await using (var server = new TestServer(async httpContext =>
+        {
+            httpContext.Response.StatusCode = 205;
+            httpContext.Response.Headers.ContentLength = 1;
+
+            try
+            {
+                await httpContext.Response.StartAsync();
+            }
+            catch (Exception ex)
+            {
+                responseWriteTcs.TrySetException(ex);
+                throw;
+            }
+
+            responseWriteTcs.TrySetResult();
+        }, new TestServiceContext(LoggerFactory)))
+        {
+            using (var connection = server.CreateConnection())
+            {
+                await connection.Send(
+                    "GET / HTTP/1.1",
+                    "Host:",
+                    "",
+                    "");
+
+                var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => responseWriteTcs.Task).DefaultTimeout();
+                Assert.Equal(CoreStrings.NonzeroContentLengthNotAllowedOn205, ex.Message);
+            }
+        }
+    }
+
     [Theory]
     [InlineData(StatusCodes.Status204NoContent)]
     [InlineData(StatusCodes.Status304NotModified)]