diff --git a/src/libraries/Common/tests/System/IO/DelegateDelegatingStream.cs b/src/libraries/Common/tests/System/IO/DelegateDelegatingStream.cs index f5316f41e3066..53de599d9b1fa 100644 --- a/src/libraries/Common/tests/System/IO/DelegateDelegatingStream.cs +++ b/src/libraries/Common/tests/System/IO/DelegateDelegatingStream.cs @@ -9,6 +9,8 @@ namespace System.IO /// Provides a stream whose implementation is supplied by delegates or by an inner stream. internal sealed class DelegateDelegatingStream : DelegatingStream { + public delegate int ReadSpanDelegate(Span buffer); + public static DelegateDelegatingStream NopDispose(Stream innerStream) => new DelegateDelegatingStream(innerStream) { @@ -27,6 +29,7 @@ public DelegateDelegatingStream(Stream innerStream) : base(innerStream) { } public Func GetPositionFunc { get; set; } public Action SetPositionFunc { get; set; } public Func ReadFunc { get; set; } + public ReadSpanDelegate ReadSpanFunc { get; set; } public Func> ReadAsyncArrayFunc { get; set; } public Func, CancellationToken, ValueTask> ReadAsyncMemoryFunc { get; set; } public Func SeekFunc { get; set; } @@ -48,6 +51,7 @@ public DelegateDelegatingStream(Stream innerStream) : base(innerStream) { } public override long Position => GetPositionFunc != null ? GetPositionFunc() : base.Position; public override int Read(byte[] buffer, int offset, int count) => ReadFunc != null ? ReadFunc(buffer, offset, count) : base.Read(buffer, offset, count); + public override int Read(Span buffer) => ReadSpanFunc != null ? ReadSpanFunc(buffer) : base.Read(buffer); public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadAsyncArrayFunc != null ? ReadAsyncArrayFunc(buffer, offset, count, cancellationToken) : base.ReadAsync(buffer, offset, count, cancellationToken); public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) => ReadAsyncMemoryFunc != null ? ReadAsyncMemoryFunc(buffer, cancellationToken) : base.ReadAsync(buffer, cancellationToken); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs index 6e7264401705d..bf6d08a092366 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs @@ -94,7 +94,7 @@ public override int Read(Span buffer) } // We're only here if we need more data to make forward progress. - _connection.Fill(); + Fill(); // Now that we have more, see if we can get any response data, and if // we can we're done. @@ -210,7 +210,7 @@ private async ValueTask ReadAsyncCore(Memory buffer, CancellationToke } // We're only here if we need more data to make forward progress. - await _connection.FillAsync(async: true).ConfigureAwait(false); + await FillAsync().ConfigureAwait(false); // Now that we have more, see if we can get any response data, and if // we can we're done. @@ -273,7 +273,7 @@ private async Task CopyToAsyncCore(Stream destination, CancellationToken cancell return; } - await _connection.FillAsync(async: true).ConfigureAwait(false); + await FillAsync().ConfigureAwait(false); } } catch (Exception exc) when (CancellationHelper.ShouldWrapInOperationCanceledException(exc, cancellationToken)) @@ -323,7 +323,7 @@ private int ReadChunksFromConnectionBuffer(Span buffer, CancellationTokenR Debug.Assert(_chunkBytesRemaining == 0, $"Expected {nameof(_chunkBytesRemaining)} == 0, got {_chunkBytesRemaining}"); // Read the chunk header line. - if (!_connection.TryReadNextChunkedLine(readingHeader: false, out currentLine)) + if (!_connection.TryReadNextChunkedLine(out currentLine)) { // Could not get a whole line, so we can't parse the chunk header. return default; @@ -379,7 +379,7 @@ private int ReadChunksFromConnectionBuffer(Span buffer, CancellationTokenR case ParsingState.ExpectChunkTerminator: Debug.Assert(_chunkBytesRemaining == 0, $"Expected {nameof(_chunkBytesRemaining)} == 0, got {_chunkBytesRemaining}"); - if (!_connection.TryReadNextChunkedLine(readingHeader: false, out currentLine)) + if (!_connection.TryReadNextChunkedLine(out currentLine)) { return default; } @@ -395,38 +395,23 @@ private int ReadChunksFromConnectionBuffer(Span buffer, CancellationTokenR case ParsingState.ConsumeTrailers: Debug.Assert(_chunkBytesRemaining == 0, $"Expected {nameof(_chunkBytesRemaining)} == 0, got {_chunkBytesRemaining}"); - while (true) + // Consume the receive buffer. If the stream is disposed, pass a null response to avoid + // processing headers for a connection returned to the pool. + if (_connection.ParseHeaders(IsDisposed ? null : _response, isFromTrailer: true)) { - if (!_connection.TryReadNextChunkedLine(readingHeader: true, out currentLine)) - { - break; - } - - if (currentLine.IsEmpty) - { - // Dispose of the registration and then check whether cancellation has been - // requested. This is necessary to make deterministic a race condition between - // cancellation being requested and unregistering from the token. Otherwise, - // it's possible cancellation could be requested just before we unregister and - // we then return a connection to the pool that has been or will be disposed - // (e.g. if a timer is used and has already queued its callback but the - // callback hasn't yet run). - cancellationRegistration.Dispose(); - CancellationHelper.ThrowIfCancellationRequested(cancellationRegistration.Token); - - _state = ParsingState.Done; - _connection.CompleteResponse(); - _connection = null; - - break; - } - // Parse the trailer. - else if (!IsDisposed) - { - // Make sure that we don't inadvertently consume trailing headers - // while draining a connection that's being returned back to the pool. - HttpConnection.ParseHeaderNameValue(_connection, currentLine, _response, isFromTrailer: true); - } + // Dispose of the registration and then check whether cancellation has been + // requested. This is necessary to make deterministic a race condition between + // cancellation being requested and unregistering from the token. Otherwise, + // it's possible cancellation could be requested just before we unregister and + // we then return a connection to the pool that has been or will be disposed + // (e.g. if a timer is used and has already queued its callback but the + // callback hasn't yet run). + cancellationRegistration.Dispose(); + CancellationHelper.ThrowIfCancellationRequested(cancellationRegistration.Token); + + _state = ParsingState.Done; + _connection.CompleteResponse(); + _connection = null; } return default; @@ -528,7 +513,7 @@ public override async ValueTask DrainAsync(int maxDrainBytes) } } - await _connection.FillAsync(async: true).ConfigureAwait(false); + await FillAsync().ConfigureAwait(false); } } finally @@ -537,6 +522,24 @@ public override async ValueTask DrainAsync(int maxDrainBytes) cts?.Dispose(); } } + + private void Fill() + { + Debug.Assert(_connection is not null); + ValueTask fillTask = _state == ParsingState.ConsumeTrailers + ? _connection.FillForHeadersAsync(async: false) + : _connection.FillAsync(async: false); + Debug.Assert(fillTask.IsCompleted); + fillTask.GetAwaiter().GetResult(); + } + + private ValueTask FillAsync() + { + Debug.Assert(_connection is not null); + return _state == ParsingState.ConsumeTrailers + ? _connection.FillForHeadersAsync(async: true) + : _connection.FillAsync(async: true); + } } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index d0bef3bc27a23..b3993091a2f6f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -615,7 +615,11 @@ public async Task SendAsyncCore(HttpRequestMessage request, // Parse the response status line. var response = new HttpResponseMessage() { RequestMessage = request, Content = new HttpConnectionResponseContent() }; - ParseStatusLine((await ReadNextResponseHeaderLineAsync(async).ConfigureAwait(false)).Span, response); + + while (!ParseStatusLine(response)) + { + await FillForHeadersAsync(async).ConfigureAwait(false); + } if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.ResponseHeadersStart(); @@ -648,22 +652,22 @@ public async Task SendAsyncCore(HttpRequestMessage request, if (NetEventSource.Log.IsEnabled()) Trace($"Current {response.StatusCode} response is an interim response or not expected, need to read for a final response."); // Discard headers that come with the interim 1xx responses. - // RFC7231: 1xx responses are terminated by the first empty line after the status-line. - while (!IsLineEmpty(await ReadNextResponseHeaderLineAsync(async).ConfigureAwait(false))); + while (!ParseHeaders(response: null, isFromTrailer: false)) + { + await FillForHeadersAsync(async).ConfigureAwait(false); + } // Parse the status line for next response. - ParseStatusLine((await ReadNextResponseHeaderLineAsync(async).ConfigureAwait(false)).Span, response); + while (!ParseStatusLine(response)) + { + await FillForHeadersAsync(async).ConfigureAwait(false); + } } // Parse the response headers. Logic after this point depends on being able to examine headers in the response object. - while (true) + while (!ParseHeaders(response, isFromTrailer: false)) { - ReadOnlyMemory line = await ReadNextResponseHeaderLineAsync(async, foldedHeadersAllowed: true).ConfigureAwait(false); - if (IsLineEmpty(line)) - { - break; - } - ParseHeaderNameValue(this, line.Span, response, isFromTrailer: false); + await FillForHeadersAsync(async).ConfigureAwait(false); } if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.ResponseHeadersStop(); @@ -908,8 +912,6 @@ private CancellationTokenRegistration RegisterCancellation(CancellationToken can }, _weakThisRef); } - private static bool IsLineEmpty(ReadOnlyMemory line) => line.Length == 0; - private async ValueTask SendRequestContentAsync(HttpRequestMessage request, HttpContentWriteStream stream, bool async, CancellationToken cancellationToken) { // Now that we're sending content, prohibit retries of this request by setting this flag. @@ -972,7 +974,35 @@ private async Task SendRequestContentWithExpect100ContinueAsync( } } - private static void ParseStatusLine(ReadOnlySpan line, HttpResponseMessage response) + private bool ParseStatusLine(HttpResponseMessage response) + { + Span buffer = new Span(_readBuffer, _readOffset, _readLength - _readOffset); + + int lineFeedIndex = buffer.IndexOf((byte)'\n'); + if (lineFeedIndex >= 0) + { + _readOffset += lineFeedIndex + 1; + _allowedReadLineBytes -= lineFeedIndex + 1; + + int carriageReturnIndex = lineFeedIndex - 1; + int length = (uint)carriageReturnIndex < (uint)buffer.Length && buffer[carriageReturnIndex] == '\r' + ? carriageReturnIndex + : lineFeedIndex; + + ParseStatusLineCore(buffer.Slice(0, length), response); + return true; + } + else + { + if (_allowedReadLineBytes <= buffer.Length) + { + ThrowExceededAllowedReadLineBytes(); + } + return false; + } + } + + private static void ParseStatusLineCore(Span line, HttpResponseMessage response) { // We sent the request version as either 1.0 or 1.1. // We expect a response version of the form 1.X, where X is a single digit as per RFC. @@ -1045,97 +1075,209 @@ private static void ParseStatusLine(ReadOnlySpan line, HttpResponseMessage } } - private static void ParseHeaderNameValue(HttpConnection connection, ReadOnlySpan line, HttpResponseMessage response, bool isFromTrailer) + private bool ParseHeaders(HttpResponseMessage? response, bool isFromTrailer) { - Debug.Assert(line.Length > 0); + Span buffer = new Span(_readBuffer, _readOffset, _readLength - _readOffset); + + (bool finished, int bytesConsumed) = ParseHeadersCore(buffer, response, isFromTrailer); - int pos = 0; - while (line[pos] != (byte)':' && line[pos] != (byte)' ') + int bytesScanned = finished ? bytesConsumed : buffer.Length; + if (_allowedReadLineBytes < bytesScanned) { - pos++; - if (pos == line.Length) - { - // Invalid header line that doesn't contain ':'. - throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_line, Encoding.ASCII.GetString(line))); - } + ThrowExceededAllowedReadLineBytes(); } - if (pos == 0) + _readOffset += bytesConsumed; + _allowedReadLineBytes -= bytesConsumed; + Debug.Assert(_allowedReadLineBytes >= 0); + + return finished; + } + + private (bool finished, int bytesConsumed) ParseHeadersCore(Span buffer, HttpResponseMessage? response, bool isFromTrailer) + { + int originalBufferLength = buffer.Length; + + while (true) { - // Invalid empty header name. - throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, "")); + int colIdx = buffer.IndexOfAny((byte)':', (byte)'\n'); + if (colIdx < 0) + { + return (finished: false, bytesConsumed: originalBufferLength - buffer.Length); + } + + if (buffer[colIdx] == '\n') + { + if ((colIdx == 1 && buffer[0] == '\r') || colIdx == 0) + { + return (finished: true, bytesConsumed: originalBufferLength - buffer.Length + colIdx + 1); + } + + ThrowForInvalidHeaderLine(buffer, colIdx); + } + + int valueStartIdx = colIdx + 1; + if ((uint)valueStartIdx >= (uint)buffer.Length) + { + return (finished: false, bytesConsumed: originalBufferLength - buffer.Length); + } + + // Iterate over the value and handle any line folds (new lines followed by SP/HTAB). + // valueIterator refers to the remainder of the buffer that we can still scan for new lines. + Span valueIterator = buffer.Slice(valueStartIdx); + + while (true) + { + int lfIdx = valueIterator.IndexOf((byte)'\n'); + if ((uint)lfIdx >= (uint)valueIterator.Length) + { + return (finished: false, bytesConsumed: originalBufferLength - buffer.Length); + } + + int crIdx = lfIdx - 1; + int crOrLfIdx = (uint)crIdx < (uint)valueIterator.Length && valueIterator[crIdx] == '\r' + ? crIdx + : lfIdx; + + int spIdx = lfIdx + 1; + if ((uint)spIdx >= (uint)valueIterator.Length) + { + return (finished: false, bytesConsumed: originalBufferLength - buffer.Length); + } + + if (valueIterator[spIdx] is not (byte)'\t' and not (byte)' ') + { + // Found the end of the header value. + + if (response is not null) + { + ReadOnlySpan headerName = buffer.Slice(0, valueStartIdx - 1); + ReadOnlySpan headerValue = buffer.Slice(valueStartIdx, buffer.Length - valueIterator.Length + crOrLfIdx - valueStartIdx); + AddResponseHeader(headerName, headerValue, response, isFromTrailer); + } + + buffer = buffer.Slice(buffer.Length - valueIterator.Length + spIdx); + break; + } + + // Found an obs-fold (CRLFHT/CRLFSP). + // Replace the CRLF with SPSP and keep looking for the final newline. + valueIterator[crOrLfIdx] = (byte)' '; + valueIterator[lfIdx] = (byte)' '; + + valueIterator = valueIterator.Slice(spIdx + 1); + } } - if (!HeaderDescriptor.TryGet(line.Slice(0, pos), out HeaderDescriptor descriptor)) + static void ThrowForInvalidHeaderLine(ReadOnlySpan buffer, int newLineIndex) => + throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_line, Encoding.ASCII.GetString(buffer.Slice(0, newLineIndex)))); + } + + private void AddResponseHeader(ReadOnlySpan name, ReadOnlySpan value, HttpResponseMessage response, bool isFromTrailer) + { + // Skip trailing whitespace and check for empty length. + while (true) { - // Invalid header name. - throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, Encoding.ASCII.GetString(line.Slice(0, pos)))); + int spIdx = name.Length - 1; + + if ((uint)spIdx < (uint)name.Length) + { + if (name[spIdx] != ' ') + { + // hot path + break; + } + + name = name.Slice(0, spIdx); + } + else + { + ThrowForEmptyHeaderName(); + } } - if (isFromTrailer && (descriptor.HeaderType & HttpHeaderType.NonTrailing) == HttpHeaderType.NonTrailing) + // Skip leading OWS for value. + // hot path: loop body runs only once. + while (value.Length != 0 && value[0] is (byte)' ' or (byte)'\t') { - // Disallowed trailer fields. - // A recipient MUST ignore fields that are forbidden to be sent in a trailer. - if (NetEventSource.Log.IsEnabled()) connection.Trace($"Stripping forbidden {descriptor.Name} from trailer headers."); - return; + value = value.Slice(1); } - // Eat any trailing whitespace - while (line[pos] == (byte)' ') + // Skip trailing whitespace for value. + while (true) { - pos++; - if (pos == line.Length) + int spIdx = value.Length - 1; + + if ((uint)spIdx >= (uint)value.Length || value[spIdx] != ' ') { - // Invalid header line that doesn't contain ':'. - throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_line, Encoding.ASCII.GetString(line))); + // hot path + break; } + + value = value.Slice(0, spIdx); } - if (line[pos++] != ':') + if (!HeaderDescriptor.TryGet(name, out HeaderDescriptor descriptor)) { - // Invalid header line that doesn't contain ':'. - throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_line, Encoding.ASCII.GetString(line))); + ThrowForInvalidHeaderName(name); } - // Skip whitespace after colon - while (pos < line.Length && (line[pos] == (byte)' ' || line[pos] == (byte)'\t')) + Encoding? valueEncoding = _pool.Settings._responseHeaderEncodingSelector?.Invoke(descriptor.Name, _currentRequest!); + + HttpHeaderType headerType = descriptor.HeaderType; + + // Request headers returned on the response must be treated as custom headers. + if ((headerType & HttpHeaderType.Request) != 0) { - pos++; + descriptor = descriptor.AsCustomHeader(); } - Debug.Assert(response.RequestMessage != null); - Encoding? valueEncoding = connection._pool.Settings._responseHeaderEncodingSelector?.Invoke(descriptor.Name, response.RequestMessage); + string headerValue; + HttpHeaders headers; - // Note we ignore the return value from TryAddWithoutValidation. If the header can't be added, we silently drop it. - ReadOnlySpan value = line.Slice(pos); if (isFromTrailer) { - string headerValue = descriptor.GetHeaderValue(value, valueEncoding); - response.TrailingHeaders.TryAddWithoutValidation((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue); + if ((headerType & HttpHeaderType.NonTrailing) != 0) + { + // Disallowed trailer fields. + // A recipient MUST ignore fields that are forbidden to be sent in a trailer. + return; + } + + headerValue = descriptor.GetHeaderValue(value, valueEncoding); + headers = response.TrailingHeaders; } - else if ((descriptor.HeaderType & HttpHeaderType.Content) == HttpHeaderType.Content) + else if ((headerType & HttpHeaderType.Content) != 0) { - string headerValue = descriptor.GetHeaderValue(value, valueEncoding); - response.Content!.Headers.TryAddWithoutValidation(descriptor, headerValue); + headerValue = descriptor.GetHeaderValue(value, valueEncoding); + headers = response.Content!.Headers; } else { - string headerValue = connection.GetResponseHeaderValueWithCaching(descriptor, value, valueEncoding); + headerValue = GetResponseHeaderValueWithCaching(descriptor, value, valueEncoding); + headers = response.Headers; if (descriptor.Equals(KnownHeaders.KeepAlive)) { // We are intentionally going against RFC to honor the Keep-Alive header even if // we haven't received a Keep-Alive connection token to maximize compat with servers. - connection.ProcessKeepAliveHeader(headerValue); + ProcessKeepAliveHeader(headerValue); } - - // Request headers returned on the response must be treated as custom headers. - response.Headers.TryAddWithoutValidation( - (descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, - headerValue); } + + bool added = headers.TryAddWithoutValidation(descriptor, headerValue); + Debug.Assert(added); + + static void ThrowForEmptyHeaderName() => + throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, "")); + + static void ThrowForInvalidHeaderName(ReadOnlySpan name) => + throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, Encoding.ASCII.GetString(name))); } + private void ThrowExceededAllowedReadLineBytes() => + throw new HttpRequestException(SR.Format(SR.net_http_response_headers_exceeded_length, _pool.Settings.MaxResponseHeadersByteLength)); + private void ProcessKeepAliveHeader(string keepAlive) { var parsedValues = new UnvalidatedObjectCollection(); @@ -1516,15 +1658,14 @@ private ValueTask WriteToStreamAsync(ReadOnlyMemory source, bool async) } } - private bool TryReadNextChunkedLine(bool readingHeader, out ReadOnlySpan line) + private bool TryReadNextChunkedLine(out ReadOnlySpan line) { - int maxByteLength = readingHeader ? _allowedReadLineBytes : MaxChunkBytesAllowed; var buffer = new ReadOnlySpan(_readBuffer, _readOffset, _readLength - _readOffset); int lineFeedIndex = buffer.IndexOf((byte)'\n'); if (lineFeedIndex < 0) { - if (buffer.Length < maxByteLength) + if (buffer.Length < MaxChunkBytesAllowed) { line = default; return false; @@ -1533,16 +1674,10 @@ private bool TryReadNextChunkedLine(bool readingHeader, out ReadOnlySpan l else { int bytesConsumed = lineFeedIndex + 1; - int maxBytesRemaining = maxByteLength - bytesConsumed; - if (maxBytesRemaining >= 0) + if (bytesConsumed <= MaxChunkBytesAllowed) { _readOffset += bytesConsumed; - if (readingHeader) - { - _allowedReadLineBytes = maxBytesRemaining; - } - int carriageReturnIndex = lineFeedIndex - 1; int length = (uint)carriageReturnIndex < (uint)buffer.Length && buffer[carriageReturnIndex] == '\r' @@ -1554,117 +1689,7 @@ private bool TryReadNextChunkedLine(bool readingHeader, out ReadOnlySpan l } } - string message = readingHeader - ? SR.Format(SR.net_http_response_headers_exceeded_length, _pool.Settings.MaxResponseHeadersByteLength) - : SR.net_http_chunk_too_large; - - throw new HttpRequestException(message); - } - - private async ValueTask> ReadNextResponseHeaderLineAsync(bool async, bool foldedHeadersAllowed = false) - { - int previouslyScannedBytes = 0; - while (true) - { - int scanOffset = _readOffset + previouslyScannedBytes; - int lfIndex = Array.IndexOf(_readBuffer, (byte)'\n', scanOffset, _readLength - scanOffset); - if (lfIndex >= 0) - { - int startIndex = _readOffset; - int length = lfIndex - startIndex; - if (lfIndex > 0 && _readBuffer[lfIndex - 1] == '\r') - { - length--; - } - - // If this isn't the ending header, we need to account for the possibility - // of folded headers, which per RFC2616 are headers split across multiple - // lines, where the continuation line begins with a space or horizontal tab. - // The feature was deprecated in RFC 7230 3.2.4, but some servers still use it. - if (foldedHeadersAllowed && length > 0) - { - // If the newline is the last character we've buffered, we need at least - // one more character in order to see whether it's space/tab, in which - // case it's a folded header. - if (lfIndex + 1 == _readLength) - { - // The LF is at the end of the buffer, so we need to read more - // to determine whether there's a continuation. We'll read - // and then loop back around again, but to avoid needing to - // rescan the whole header, reposition to one character before - // the newline so that we'll find it quickly. - int backPos = _readBuffer[lfIndex - 1] == '\r' ? lfIndex - 2 : lfIndex - 1; - Debug.Assert(backPos >= 0); - previouslyScannedBytes = backPos - _readOffset; - _allowedReadLineBytes -= backPos - scanOffset; - ThrowIfExceededAllowedReadLineBytes(); - await FillAsync(async).ConfigureAwait(false); - continue; - } - - // We have at least one more character we can look at. - Debug.Assert(lfIndex + 1 < _readLength); - char nextChar = (char)_readBuffer[lfIndex + 1]; - if (nextChar == ' ' || nextChar == '\t') - { - // The next header is a continuation. - - // Folded headers are only allowed within header field values, not within header field names, - // so if we haven't seen a colon, this is invalid. - if (Array.IndexOf(_readBuffer, (byte)':', _readOffset, lfIndex - _readOffset) == -1) - { - throw new HttpRequestException(SR.net_http_invalid_response_header_folder); - } - - // When we return the line, we need the interim newlines filtered out. According - // to RFC 7230 3.2.4, a valid approach to dealing with them is to "replace each - // received obs-fold with one or more SP octets prior to interpreting the field - // value or forwarding the message downstream", so that's what we do. - _readBuffer[lfIndex] = (byte)' '; - if (_readBuffer[lfIndex - 1] == '\r') - { - _readBuffer[lfIndex - 1] = (byte)' '; - } - - // Update how much we've read, and simply go back to search for the next newline. - previouslyScannedBytes = (lfIndex + 1 - _readOffset); - _allowedReadLineBytes -= (lfIndex + 1 - scanOffset); - ThrowIfExceededAllowedReadLineBytes(); - continue; - } - - // Not at the end of a header with a continuation. - } - - // Advance read position past the LF - _allowedReadLineBytes -= lfIndex + 1 - scanOffset; - ThrowIfExceededAllowedReadLineBytes(); - _readOffset = lfIndex + 1; - - return new ReadOnlyMemory(_readBuffer, startIndex, length); - } - - // Couldn't find LF. Read more. Note this may cause _readOffset to change. - previouslyScannedBytes = _readLength - _readOffset; - _allowedReadLineBytes -= _readLength - scanOffset; - ThrowIfExceededAllowedReadLineBytes(); - await FillAsync(async).ConfigureAwait(false); - } - } - - private void ThrowIfExceededAllowedReadLineBytes() - { - if (_allowedReadLineBytes < 0) - { - throw new HttpRequestException(SR.Format(SR.net_http_response_headers_exceeded_length, _pool.Settings.MaxResponseHeadersByteLength)); - } - } - - private void Fill() - { - ValueTask fillTask = FillAsync(async: false); - Debug.Assert(fillTask.IsCompleted); - fillTask.GetAwaiter().GetResult(); + throw new HttpRequestException(SR.net_http_chunk_too_large); } // Does not throw on EOF. Also assumes there is no buffered data. @@ -1729,6 +1754,89 @@ await _stream.ReadAsync(new Memory(_readBuffer, _readLength, _readBuffer.L _readLength += bytesRead; } + private ValueTask FillForHeadersAsync(bool async) + { + // If the read offset is 0, it means we haven't consumed any data since the last FillAsync. + // If so, read until we either find the next new line or we hit the MaxResponseHeadersLength limit. + return _readOffset == 0 + ? ReadUntilEndOfHeaderAsync(async) + : FillAsync(async); + + // This method guarantees that the next call to ParseHeaders will consume at least one header. + // This is the slow path, but guarantees O(n) worst-case parsing complexity. + async ValueTask ReadUntilEndOfHeaderAsync(bool async) + { + int searchOffset = _readLength; + if (searchOffset > 0) + { + // The last character we've buffered could be a new line, + // we just haven't checked the byte following it to see if it's a space or tab. + searchOffset--; + } + + while (true) + { + await FillAsync(async).ConfigureAwait(false); + Debug.Assert(_readOffset == 0); + + // There's no need to search the whole buffer, only look through the new bytes we just read. + if (TryFindEndOfLine(new ReadOnlySpan(_readBuffer, searchOffset, _readLength - searchOffset), out int offset)) + { + break; + } + + searchOffset += offset; + + if (searchOffset != _readLength) + { + Debug.Assert(searchOffset == _readLength - 1 && _readBuffer[searchOffset] == '\n'); + if (_readLength <= 2) + { + // There are no headers - we start off with a new line. + // This is reachable from ChunkedEncodingReadStream if the buffers allign just right and there are no trailing headers. + break; + } + } + + if (_readLength >= _allowedReadLineBytes) + { + ThrowExceededAllowedReadLineBytes(); + } + } + + static bool TryFindEndOfLine(ReadOnlySpan buffer, out int searchOffset) + { + int originalBufferLength = buffer.Length; + + while (true) + { + int newLineOffset = buffer.IndexOf((byte)'\n'); + if (newLineOffset < 0) + { + searchOffset = originalBufferLength; + return false; + } + + int tabOrSpaceIndex = newLineOffset + 1; + if (tabOrSpaceIndex == buffer.Length) + { + // The new line is the last character, read again to make sure it doesn't continue with space or tab. + searchOffset = originalBufferLength - 1; + return false; + } + + if (buffer[tabOrSpaceIndex] is not (byte)'\t' and not (byte)' ') + { + searchOffset = 0; + return true; + } + + buffer = buffer.Slice(tabOrSpaceIndex + 1); + } + } + } + } + private void ReadFromBuffer(Span buffer) { Debug.Assert(buffer.Length <= _readLength - _readOffset); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 6fc99b5764602..4e51d356fc655 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -6,10 +6,12 @@ using System.IO; using System.IO.Pipes; using System.Linq; +using System.Net.Http.Headers; using System.Net.Quic; using System.Net.Security; using System.Net.Sockets; using System.Net.Test.Common; +using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Security.Authentication; @@ -1231,6 +1233,165 @@ public void Expect100ContinueTimeout_SetAfterUse_Throws() public sealed class SocketsHttpHandler_HttpClientHandler_MaxResponseHeadersLength_Http11 : HttpClientHandler_MaxResponseHeadersLength_Test { public SocketsHttpHandler_HttpClientHandler_MaxResponseHeadersLength_Http11(ITestOutputHelper output) : base(output) { } + + [Theory] + [InlineData(null, 63 * 1024)] + [InlineData(null, 65 * 1024)] + [InlineData(1, 100)] + [InlineData(1, 1024)] + public async Task LargeStatusLine_ThrowsException(int? maxResponseHeadersLength, int statusLineLengthEstimate) + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using HttpClientHandler handler = CreateHttpClientHandler(); + + if (maxResponseHeadersLength.HasValue) + { + handler.MaxResponseHeadersLength = maxResponseHeadersLength.Value; + } + + using HttpClient client = CreateHttpClient(handler); + + if (statusLineLengthEstimate < handler.MaxResponseHeadersLength * 1024L) + { + await client.GetAsync(uri); + } + else + { + Exception e = await Assert.ThrowsAsync(() => client.GetAsync(uri)); + if (!IsWinHttpHandler) + { + Assert.Contains((handler.MaxResponseHeadersLength * 1024).ToString(), e.ToString()); + } + } + }, + async server => + { + try + { + await server.AcceptConnectionSendCustomResponseAndCloseAsync($"HTTP/1.1 200 OK{new string('a', statusLineLengthEstimate)}\r\n\r\n"); + } + catch { } + }); + } + + public static IEnumerable TripleBoolValues() => + from trailing in BoolValues + from async in BoolValues + from lineFolds in BoolValues + select new object[] { trailing, async, lineFolds }; + + [Theory] + [MemberData(nameof(TripleBoolValues))] + public async Task LargeHeaders_TrickledOverTime_ProcessedEfficiently(bool trailingHeaders, bool async, bool lineFolds) + { + Memory responsePrefix = Encoding.ASCII.GetBytes(trailingHeaders + ? "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nLong-Header: " + : "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nLong-Header: "); + + bool streamDisposed = false; + bool responseComplete = false; + int readCount = 0; + int fastFillLength = 64 * 1024 * 1024; // 64 MB + + Func, int> readFunc = memory => + { + if (streamDisposed) + { + throw new ObjectDisposedException("Foo"); + } + + if (responseComplete) + { + return 0; + } + + Span buffer = memory.Span; + + if (!responsePrefix.IsEmpty) + { + int toCopy = Math.Min(responsePrefix.Length, buffer.Length); + responsePrefix.Span.Slice(0, toCopy).CopyTo(buffer); + responsePrefix = responsePrefix.Slice(toCopy); + return toCopy; + } + + if (fastFillLength > 0) + { + int toFill = Math.Min(fastFillLength, buffer.Length); + buffer.Slice(0, toFill).Fill((byte)'a'); + fastFillLength -= toFill; + if (lineFolds) + { + for (int i = 0; i < toFill / 10; i++) + { + buffer[i * 10 + 8] = (byte)'\n'; + buffer[i * 10 + 9] = (byte)' '; + } + } + return toFill; + } + + if (++readCount < 500_000) + { + // Slowly trickle data over 500 thousand read calls. + // If the implementation scans the whole buffer after every read, it will have to sift through 32 TB of data. + // As that is not achievable on current hardware within the PassingTestTimeout window, the test would fail. + if (lineFolds && readCount % 10 == 0) + { + buffer[0] = (byte)'\n'; + buffer[1] = (byte)' '; + return 2; + } + else + { + buffer[0] = (byte)'a'; + return 1; + } + } + + responseComplete = true; + + Debug.Assert(buffer.Length >= 4); + return Encoding.ASCII.GetBytes("\r\n\r\n", buffer); + }; + + var responseStream = new DelegateDelegatingStream(Stream.Null) + { + ReadAsyncMemoryFunc = (memory, _) => new ValueTask(readFunc(memory)), + ReadFunc = (array, offset, length) => readFunc(array.AsMemory(offset, length)), + ReadSpanFunc = buffer => + { + byte[] arrayBuffer = new byte[buffer.Length]; + int read = readFunc(arrayBuffer); + arrayBuffer.AsSpan(0, read).CopyTo(buffer); + return read; + }, + DisposeFunc = _ => streamDisposed = true + }; + + using var client = new HttpClient(new SocketsHttpHandler + { + ConnectCallback = (_, _) => new ValueTask(responseStream), + MaxResponseHeadersLength = 1024 * 1024 // 1 GB + }) + { + Timeout = TestHelper.PassingTestTimeout + }; + + var request = new HttpRequestMessage(HttpMethod.Get, "http://foo"); + + using HttpResponseMessage response = async + ? await client.SendAsync(request) + : client.Send(request); + + response.EnsureSuccessStatusCode(); + + HttpHeaders headers = trailingHeaders + ? response.TrailingHeaders + : response.Headers; + Assert.True(headers.NonValidated.Contains("Long-Header")); + } } [ConditionalClass(typeof(SocketsHttpHandler), nameof(SocketsHttpHandler.IsSupported))] @@ -3518,6 +3679,24 @@ public SocketsHttpHandlerTest_HttpClientHandlerTest_Http2(ITestOutputHelper outp public sealed class SocketsHttpHandlerTest_HttpClientHandlerTest_Headers_Http11 : HttpClientHandlerTest_Headers { public SocketsHttpHandlerTest_HttpClientHandlerTest_Headers_Http11(ITestOutputHelper output) : base(output) { } + + [Fact] + public async Task ResponseHeaders_ExtraWhitespace_Trimmed() + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using HttpClient client = CreateHttpClient(); + + using HttpResponseMessage response = await client.GetAsync(uri); + + Assert.True(response.Headers.NonValidated.TryGetValues("foo", out HeaderStringValues value)); + Assert.Equal("bar", Assert.Single(value)); + }, + async server => + { + await server.HandleRequestAsync(headers: new[] { new HttpHeaderData("foo ", " \t bar \r\n ") }); + }); + } } [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]