Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HTTP/3] Fix NullReferenceException on cancellation #54334

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ public void Dispose()
if (!_disposed)
{
_disposed = true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need _disposed, or does _stream being null now indicate disposal?

_stream.Dispose();
DisposeSyncHelper();
var stream = Interlocked.Exchange(ref _stream, null!);
stream.Dispose();
DisposeSyncHelper(stream);
Comment on lines +84 to +86
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var stream = Interlocked.Exchange(ref _stream, null!);
stream.Dispose();
DisposeSyncHelper(stream);
QuicStream stream = Interlocked.Exchange(ref _stream, null!);
if (stream is not null)
{
stream.Dispose();
DisposeSyncHelper(stream);
}

}
}

Expand All @@ -91,16 +92,15 @@ public async ValueTask DisposeAsync()
if (!_disposed)
{
_disposed = true;
await _stream.DisposeAsync().ConfigureAwait(false);
DisposeSyncHelper();
var stream = Interlocked.Exchange(ref _stream, null!);
await stream.DisposeAsync().ConfigureAwait(false);
DisposeSyncHelper(stream);
Comment on lines +95 to +97
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var stream = Interlocked.Exchange(ref _stream, null!);
await stream.DisposeAsync().ConfigureAwait(false);
DisposeSyncHelper(stream);
QuicStream stream = Interlocked.Exchange(ref _stream, null!);
if (stream is not null)
{
await stream.DisposeAsync().ConfigureAwait(false);
DisposeSyncHelper(stream);
}

}
}

private void DisposeSyncHelper()
private void DisposeSyncHelper(QuicStream stream)
{
_connection.RemoveStream(_stream);
_connection = null!;
_stream = null!;
Interlocked.Exchange(ref _connection, null!).RemoveStream(stream);
Copy link
Member

@stephentoub stephentoub Jun 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we've got sole ownership of the stream now due to the Interlocked in the dispose methods above, do we still need this interlocked? Or is there another caller that could get here outside of Dispose{Async}?

Also, if it is still needed, then presumably there's a race condition around nulling this out, in which case the .RemoveStream should be ?.RemoveStream.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've mostly done that to be sure the change is visible to null checks in Http3RequestStream.HandleReadResponseContentException https://github.com/dotnet/runtime/pull/54334/files#diff-b79affd636899ac98b816c1398f9dca19bc4850a939a3815c09c5a9931094a4cR1113-R1120 In the issue I'm trying to solve, disposal and this exception handling are happening concurrently. Here #52800 (comment) @scalablecory says the change might not be visible otherwise. On the other hand, @wfurt offline told me that we usually don't worry about this for example for setting and checking _disposed flag. I am not sure what should be the best practice here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest simply not nulling these fields out (i.e. _stream and _connection).

I'm not quite sure why you null them out today. Is it because you want to avoid doing stuff like Abort on the connection or AbortRead on the QuicStream if we are disposed? Because as the code stands, this is all racy and you aren't avoiding these calls consistently.

You need to either:
(a) Implement full locking here to prevent the races
(b) Avoid the races entirely by not modifying these. (There's still a race here, but it's in msquic code and they are presumably handling this today.)


_sendBuffer.Dispose();
_recvBuffer.Dispose();
Expand Down Expand Up @@ -1110,18 +1110,26 @@ private void HandleReadResponseContentException(Exception ex, CancellationToken
throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, abortException));
case Http3ConnectionException _:
// A connection-level protocol error has occurred on our stream.
_connection.Abort(ex);
_connection?.Abort(ex);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When can this happen? It seems weird that errors on a single HTTP stream would cause the whole connection to be killed.

throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex));
case OperationCanceledException oce when oce.CancellationToken == cancellationToken:
_stream.AbortWrite((long)Http3ErrorCode.RequestCancelled);
_stream?.AbortRead((long)Http3ErrorCode.RequestCancelled);
ExceptionDispatchInfo.Throw(ex); // Rethrow.
return; // Never reached.
default:
_stream.AbortWrite((long)Http3ErrorCode.InternalError);
_stream?.AbortRead((long)Http3ErrorCode.InternalError);
throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex));
}
}

private void CancelResponseContentRead()
{
if (Volatile.Read(ref _responseDataPayloadRemaining) != -1) // -1 indicates EOS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an inherent race between setting this and checking it in ReadNextDataFrameAsync.

That means you need to handle the case where the AbortRead happens after we check this in ReadNextDataFrameAsync, but before we try to actually read the frame.

As such, I would suggest not even trying to set _responseDataPayloadRemaining here. Just do AbortRead and let the next attempt to read from the underlying stream fail, and handle it as appropriate.

{
_stream.AbortRead((long)Http3ErrorCode.RequestCancelled);
}
}

private async ValueTask<bool> ReadNextDataFrameAsync(HttpResponseMessage response, CancellationToken cancellationToken)
{
if (_responseDataPayloadRemaining == -1)
Expand Down Expand Up @@ -1157,7 +1165,7 @@ private async ValueTask<bool> ReadNextDataFrameAsync(HttpResponseMessage respons
// End of stream.
CopyTrailersToResponseMessage(response);

_responseDataPayloadRemaining = -1; // Set to -1 to indicate EOS.
Volatile.Write(ref _responseDataPayloadRemaining, -1); // Set to -1 to indicate EOS.
return false;
}
}
Expand Down Expand Up @@ -1193,6 +1201,7 @@ protected override void Dispose(bool disposing)
{
if (disposing)
{
_stream.CancelResponseContentRead();
// This will remove the stream from the connection properly.
_stream.Dispose();
}
Expand All @@ -1215,6 +1224,7 @@ public override async ValueTask DisposeAsync()
{
if (_stream != null)
{
_stream.CancelResponseContentRead();
await _stream.DisposeAsync().ConfigureAwait(false);
_stream = null!;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,72 @@ public async Task Public_Interop_Upgrade_Success(string uri)
}
}

[ConditionalFact(nameof(IsMsQuicSupported))]
public async Task ResponseCancellationViaBothDisposeAndCancellationToken_Success()
{
if (UseQuicImplementationProvider != QuicImplementationProviders.MsQuic)
{
return;
}

using Http3LoopbackServer server = CreateHttp3LoopbackServer();

var pauseServerTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
var pauseClientTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);

Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
HttpRequestData request = await stream.ReadRequestDataAsync().ConfigureAwait(false);

int contentLength = 2*1024;
var headers = new List<HttpHeaderData>();
headers.Append(new HttpHeaderData("Content-Length", contentLength.ToString(CultureInfo.InvariantCulture)));

await stream.SendResponseHeadersAsync(HttpStatusCode.OK, headers).ConfigureAwait(false);
await stream.SendDataFrameAsync(new byte[1024]).ConfigureAwait(false);

await pauseServerTcs.Task.WaitAsync(TimeSpan.FromSeconds(10));

var ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => stream.SendDataFrameAsync(new byte[1024]));

await stream.ShutdownSendAsync().ConfigureAwait(false);
pauseClientTcs.SetResult(true);
});

Task clientTask = Task.Run(async () =>
{
using HttpClient client = CreateHttpClient();

using HttpRequestMessage request = new()
{
Method = HttpMethod.Get,
RequestUri = server.Address,
Version = HttpVersion30,
VersionPolicy = HttpVersionPolicy.RequestVersionExact
};
HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead).WaitAsync(TimeSpan.FromSeconds(10));

Stream stream = await response.Content.ReadAsStreamAsync();

int bytesRead = await stream.ReadAsync(new byte[1024]);
Assert.Equal(1024, bytesRead);

var cts = new CancellationTokenSource();

cts.Token.Register(() => response.Dispose());
cts.CancelAfter(200);

await Assert.ThrowsAsync<OperationCanceledException>(() => stream.ReadAsync(new byte[1024], cancellationToken: cts.Token).AsTask());

pauseServerTcs.SetResult(true);
await pauseClientTcs.Task.WaitAsync(TimeSpan.FromSeconds(3));
});

await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}

/// <summary>
/// These are public interop test servers for various QUIC and HTTP/3 implementations,
/// taken from https://github.com/quicwg/base-drafts/wiki/Implementations
Expand Down
1 change: 1 addition & 0 deletions src/libraries/System.Net.Quic/src/System.Net.Quic.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<Compile Include="System\Net\Quic\Implementations\MsQuic\Internal\*.cs" />
<Compile Include="System\Net\Quic\Implementations\MsQuic\Interop\MsQuicAlpnHelper.cs" />
<Compile Include="System\Net\Quic\Implementations\MsQuic\Interop\MsQuicEnums.cs" />
<Compile Include="System\Net\Quic\Implementations\MsQuic\Interop\MsQuicLogHelper.cs" />
<Compile Include="System\Net\Quic\Implementations\MsQuic\Interop\MsQuicNativeMethods.cs" />
<Compile Include="System\Net\Quic\Implementations\MsQuic\Interop\MsQuicStatusCodes.cs" />
<Compile Include="System\Net\Quic\Implementations\MsQuic\Interop\MsQuicStatusHelper.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices;

namespace System.Net.Quic.Implementations.MsQuic.Internal
{
internal static class MsQuicLogHelper
{
internal static string GetLogId(SafeMsQuicStreamHandle handle)
{
return $"[strm][0x{GetIntPtrHex(handle)}]";
}

internal static string GetLogId(SafeMsQuicConnectionHandle handle)
{
return $"[conn][0x{GetIntPtrHex(handle)}]";
}

private static string GetIntPtrHex(SafeHandle handle)
{
return handle.DangerousGetHandle().ToString("X11");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ internal sealed class State
SingleWriter = true,
});

public string LogId = null!; // set in ctor.

public void RemoveStream(MsQuicStream stream)
{
bool releaseHandles;
Expand Down Expand Up @@ -149,9 +151,11 @@ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, Saf
throw;
}

_state.LogId = MsQuicLogHelper.GetLogId(_state.Handle);

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(_state, $"[Connection#{_state.GetHashCode()}] inbound connection created");
NetEventSource.Info(_state, $"{_state.LogId} inbound connection created");
}
}

Expand Down Expand Up @@ -188,9 +192,11 @@ public MsQuicConnection(QuicClientConnectionOptions options)
throw;
}

_state.LogId = MsQuicLogHelper.GetLogId(_state.Handle);

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(_state, $"[Connection#{_state.GetHashCode()}] outbound connection created");
NetEventSource.Info(_state, $"{_state.LogId} outbound connection created");
}
}

Expand Down Expand Up @@ -393,18 +399,18 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti
{
bool success = connection._remoteCertificateValidationCallback(connection, certificate, chain, sslPolicyErrors);
if (!success && NetEventSource.Log.IsEnabled())
NetEventSource.Error(state, $"[Connection#{state.GetHashCode()}] remote certificate rejected by verification callback");
NetEventSource.Error(state, $"{state.LogId} remote certificate rejected by verification callback");
return success ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
}

if (NetEventSource.Log.IsEnabled())
NetEventSource.Info(state, $"[Connection#{state.GetHashCode()}] certificate validation for '${certificate?.Subject}' finished with ${sslPolicyErrors}");
NetEventSource.Info(state, $"{state.LogId} certificate validation for '${certificate?.Subject}' finished with ${sslPolicyErrors}");

return (sslPolicyErrors == SslPolicyErrors.None) ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
}
catch (Exception ex)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"[Connection#{state.GetHashCode()}] certificate validation failed ${ex.Message}");
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"{state.LogId} certificate validation failed ${ex.Message}");
}

return MsQuicStatusCodes.InternalError;
Expand Down Expand Up @@ -596,7 +602,7 @@ private static uint NativeCallbackHandler(

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(state, $"[Connection#{state.GetHashCode()}] received event {connectionEvent.Type}");
NetEventSource.Info(state, $"{state.LogId} received event {connectionEvent.Type}");
}

try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ private sealed class State

// Set once stream have been shutdown.
public readonly TaskCompletionSource ShutdownCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

public string LogId = null!; // set in ctor.
}

// inbound.
Expand Down Expand Up @@ -99,13 +101,14 @@ internal MsQuicStream(MsQuicConnection.State connectionState, SafeMsQuicStreamHa
}

_state.ConnectionState = connectionState;


_state.LogId = MsQuicLogHelper.GetLogId(_state.Handle);
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(
_state,
$"[Stream#{_state.GetHashCode()}] inbound {(_canWrite ? "bi" : "uni")}directional stream created " +
$"in Connection#{_state.ConnectionState.GetHashCode()}.");
$"{_state.LogId} inbound {(_canWrite ? "bi" : "uni")}directional stream created " +
$"in {_state.ConnectionState.LogId}.");
}
}

Expand Down Expand Up @@ -147,12 +150,13 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F

_state.ConnectionState = connectionState;

_state.LogId = MsQuicLogHelper.GetLogId(_state.Handle);
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(
_state,
$"[Stream#{_state.GetHashCode()}] outbound {(_canRead ? "bi" : "uni")}directional stream created " +
$"in Connection#{_state.ConnectionState.GetHashCode()}.");
$"{_state.LogId} outbound {(_canRead ? "bi" : "uni")}directional stream created " +
$"in {_state.ConnectionState.LogId}.");
}
}

Expand Down Expand Up @@ -262,6 +266,10 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
if (_state.SendState == SendState.Aborted)
{
cancellationToken.ThrowIfCancellationRequested();
if (_state.SendErrorCode != -1)
{
throw new QuicStreamAbortedException(_state.SendErrorCode);
}
throw new OperationCanceledException(SR.net_quic_sending_aborted);
}
else if (_state.SendState == SendState.ConnectionClosed)
Expand Down Expand Up @@ -306,7 +314,7 @@ internal override async ValueTask<int> ReadAsync(Memory<byte> destination, Cance

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(_state, $"[Stream#{_state.GetHashCode()}] reading into Memory of '{destination.Length}' bytes.");
NetEventSource.Info(_state, $"{_state.LogId} reading into Memory of '{destination.Length}' bytes.");
}

lock (_state)
Expand Down Expand Up @@ -583,7 +591,7 @@ private void Dispose(bool disposing)

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(_state, $"[Stream#{_state.GetHashCode()}] disposed");
NetEventSource.Info(_state, $"{_state.LogId} disposed");
}
}

Expand All @@ -605,7 +613,7 @@ private static uint HandleEvent(State state, ref StreamEvent evt)
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] received event {evt.Type}");
NetEventSource.Info(state, $"{state.LogId} received event {evt.Type}");
}

try
Expand Down Expand Up @@ -757,7 +765,7 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt
lock (state)
{
// This event won't occur within the middle of a receive.
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] completing resettable event source.");
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.LogId} completing resettable event source.");

if (state.ReadState == ReadState.None)
{
Expand Down Expand Up @@ -829,7 +837,7 @@ private static uint HandleEventPeerSendShutdown(State state)
lock (state)
{
// This event won't occur within the middle of a receive.
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] completing resettable event source.");
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.LogId} completing resettable event source.");

if (state.ReadState == ReadState.None)
{
Expand Down Expand Up @@ -1119,7 +1127,7 @@ private static uint HandleEventConnectionClose(State state)
long errorCode = state.ConnectionState.AbortErrorCode;
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] handling Connection#{state.ConnectionState.GetHashCode()} close" +
NetEventSource.Info(state, $"{state.LogId} handling {state.ConnectionState.LogId} close" +
(errorCode != -1 ? $" with code {errorCode}" : ""));
}

Expand Down