Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

keep MsQuicConnection alive when streams are pending #52800

Merged
merged 9 commits into from
Jun 10, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ internal static class MsQuicStatusCodes
internal static uint InternalError => OperatingSystem.IsWindows() ? Windows.InternalError : Posix.InternalError;
internal static uint InvalidState => OperatingSystem.IsWindows() ? Windows.InvalidState : Posix.InvalidState;
internal static uint HandshakeFailure => OperatingSystem.IsWindows() ? Windows.HandshakeFailure : Posix.HandshakeFailure;
internal static uint UserCanceled => OperatingSystem.IsWindows() ? Windows.UserCanceled : Posix.UserCanceled;

// TODO return better error messages here.
public static string GetError(uint status) => OperatingSystem.IsWindows() ? Windows.GetError(status) : Posix.GetError(status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ internal sealed class MsQuicConnection : QuicConnectionProvider

private readonly State _state = new State();
private GCHandle _stateHandle;
private bool _disposed;
private int _disposed;

private IPEndPoint? _localEndPoint;
private readonly EndPoint _remoteEndPoint;
Expand All @@ -44,7 +44,6 @@ private sealed class State
{
public SafeMsQuicConnectionHandle Handle = null!; // set inside of MsQuicConnection ctor.

// These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown).
public MsQuicConnection? Connection;

// TODO: only allocate these when there is an outstanding connect/shutdown.
Expand All @@ -53,6 +52,7 @@ private sealed class State

public bool Connected;
public long AbortErrorCode = -1;
public int StreamCount;
wfurt marked this conversation as resolved.
Show resolved Hide resolved

// Queue for accepted streams.
// Backlog limit is managed by MsQuic so it can be unbounded here.
Expand Down Expand Up @@ -87,6 +87,8 @@ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, Saf
_stateHandle.Free();
throw;
}

_state.Connection = this;
}

// constructor for outbound connections
Expand Down Expand Up @@ -121,6 +123,8 @@ public MsQuicConnection(QuicClientConnectionOptions options)
_stateHandle.Free();
throw;
}

_state.Connection = this;
}

internal override IPEndPoint? LocalEndPoint => _localEndPoint;
Expand All @@ -142,7 +146,6 @@ private static uint HandleEventConnected(State state, ref ConnectionEvent connec
Debug.Assert(state.Connection != null);
state.Connection._localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref inetAddress);
state.Connection.SetNegotiatedAlpn(connectionEvent.Data.Connected.NegotiatedAlpn, connectionEvent.Data.Connected.NegotiatedAlpnLength);
state.Connection = null;

state.Connected = true;
state.ConnectTcs.SetResult(MsQuicStatusCodes.Success);
Expand Down Expand Up @@ -185,12 +188,56 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent
return MsQuicStatusCodes.Success;
}

public void RemoveStream(MsQuicStream stream)
{
lock (_state)
{
_state.StreamCount--;
}

if (_state.Connection == null && _state.StreamCount == 0)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
_state?.Handle?.Dispose();
if (_stateHandle.IsAllocated) _stateHandle.Free();
}
}

private bool TryAddStream(SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags)
{
lock (_state)
{
var stream = new MsQuicStream(this, streamHandle, flags);
// once stream is created, it will call RemoveStream on disposal.
_state.StreamCount++;
if (_state.Connection != null && _state.AcceptQueue.Writer.TryWrite(stream))
{
return true;
}
else
{
stream.Dispose();
return false;
}
}
}

private static uint HandleEventNewStream(State state, ref ConnectionEvent connectionEvent)
{
MsQuicConnection? connection = state.Connection;

if (connection == null)
{
return MsQuicStatusCodes.UserCanceled;
}

var streamHandle = new SafeMsQuicStreamHandle(connectionEvent.Data.PeerStreamStarted.Stream);
var stream = new MsQuicStream(streamHandle, connectionEvent.Data.PeerStreamStarted.Flags);
if (!connection.TryAddStream(streamHandle, connectionEvent.Data.PeerStreamStarted.Flags))
{
// This will call StreamCloseDelegate and free the stream.
// We will return Success to the MsQuic to prevent double free.
streamHandle.Dispose();
}

state.AcceptQueue.Writer.TryWrite(stream);
return MsQuicStatusCodes.Success;
}

Expand Down Expand Up @@ -326,13 +373,25 @@ internal override async ValueTask<QuicStreamProvider> AcceptStreamAsync(Cancella
internal override QuicStreamProvider OpenUnidirectionalStream()
{
ThrowIfDisposed();
return new MsQuicStream(_state.Handle, QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
lock (_state)
{
var stream = new MsQuicStream(this, _state.Handle, QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
// once stream is created, it will call RemoveStream on disposal.
_state.StreamCount++;
wfurt marked this conversation as resolved.
Show resolved Hide resolved
return stream;
}
}

internal override QuicStreamProvider OpenBidirectionalStream()
{
ThrowIfDisposed();
return new MsQuicStream(_state.Handle, QUIC_STREAM_OPEN_FLAGS.NONE);
lock (_state)
{
var stream = new MsQuicStream(this, _state.Handle, QUIC_STREAM_OPEN_FLAGS.NONE);
// once stream is created, it will call RemoveStream on disposal.
_state.StreamCount++;
return stream;
}
}

internal override long GetRemoteAvailableUnidirectionalStreamCount()
Expand Down Expand Up @@ -394,8 +453,6 @@ private ValueTask ShutdownAsync(
QUIC_CONNECTION_SHUTDOWN_FLAGS Flags,
long ErrorCode)
{
// Store the connection into the GCHandle'd state to prevent GC if user calls ShutdownAsync and gets rid of all references to the MsQuicConnection.
Debug.Assert(_state.Connection == null);
_state.Connection = this;

try
Expand Down Expand Up @@ -476,16 +533,38 @@ public override void Dispose()
Dispose(false);
}

private async Task FlushAcceptQueue()
{
try {
// Writer may or may not be completed.
_state.AcceptQueue.Writer.Complete();
} catch { };
wfurt marked this conversation as resolved.
Show resolved Hide resolved

await foreach (MsQuicStream item in _state.AcceptQueue.Reader.ReadAllAsync())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
await foreach (MsQuicStream item in _state.AcceptQueue.Reader.ReadAllAsync())
await foreach (MsQuicStream item in _state.AcceptQueue.Reader.ReadAllAsync().ConfigureAwait(false))

{
item.Dispose();
}
}

private void Dispose(bool disposing)
{
if (_disposed)
int disposed = Interlocked.Exchange(ref _disposed, 1);
if (disposed == 1)
{
return;
}
wfurt marked this conversation as resolved.
Show resolved Hide resolved

_state?.Handle?.Dispose();
if (_stateHandle.IsAllocated) _stateHandle.Free();
_disposed = true;
_state.Connection = null;
FlushAcceptQueue().GetAwaiter().GetResult();

lock (_state)
{
if (_state.StreamCount == 0)
{
_state?.Handle?.Dispose();
if (_stateHandle.IsAllocated) _stateHandle.Free();
}
}
}

// TODO: this appears abortive and will cause prior successfully shutdown and closed streams to drop data.
Expand All @@ -499,7 +578,7 @@ internal override ValueTask CloseAsync(long errorCode, CancellationToken cancell

private void ThrowIfDisposed()
{
if (_disposed)
if (_disposed == 1)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
throw new ObjectDisposedException(nameof(MsQuicStream));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ internal sealed class MsQuicStream : QuicStreamProvider

private volatile bool _disposed;

private MsQuicConnection? _connection;

private sealed class State
{
public SafeMsQuicStreamHandle Handle = null!; // set in ctor.
Expand Down Expand Up @@ -71,7 +73,7 @@ private sealed class State
}

// inbound.
internal MsQuicStream(SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags)
internal MsQuicStream(MsQuicConnection connection, SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags)
{
_state.Handle = streamHandle;
_canRead = true;
Expand All @@ -91,10 +93,12 @@ internal MsQuicStream(SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAG
_stateHandle.Free();
throw;
}

_connection = connection;
}

// outbound.
internal MsQuicStream(SafeMsQuicConnectionHandle connection, QUIC_STREAM_OPEN_FLAGS flags)
internal MsQuicStream(MsQuicConnection connection, SafeMsQuicConnectionHandle connectionHandle, QUIC_STREAM_OPEN_FLAGS flags)
{
Debug.Assert(connection != null);

Expand All @@ -105,7 +109,7 @@ internal MsQuicStream(SafeMsQuicConnectionHandle connection, QUIC_STREAM_OPEN_FL
try
{
uint status = MsQuicApi.Api.StreamOpenDelegate(
connection,
connectionHandle,
flags,
s_streamDelegate,
GCHandle.ToIntPtr(_stateHandle),
Expand All @@ -122,6 +126,8 @@ internal MsQuicStream(SafeMsQuicConnectionHandle connection, QUIC_STREAM_OPEN_FL
_stateHandle.Free();
throw;
}

_connection = connection;
}

internal override bool CanRead => _canRead;
Expand Down Expand Up @@ -284,7 +290,6 @@ internal override async ValueTask<int> ReadAsync(Memory<byte> destination, Cance
{
shouldComplete = true;
}

state.ReadState = ReadState.Aborted;
}

Expand Down Expand Up @@ -503,6 +508,8 @@ private void Dispose(bool disposing)
Marshal.FreeHGlobal(_state.SendQuicBuffers);
if (_stateHandle.IsAllocated) _stateHandle.Free();
CleanupSendState(_state);
_connection?.RemoveStream(this);
_connection = null;
}

private void EnableReceive()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,6 @@ await RunClientServer(
await (new[] { t1, t2 }).WhenAllOrAnyFailed(millisecondsTimeout: 1000000);
}

[ActiveIssue("https://github.com/dotnet/runtime/issues/52048")]
[Fact]
public async Task ManagedAVE_MinimalFailingTest()
{
Expand All @@ -370,6 +369,32 @@ async Task GetStreamIdWithoutStartWorks()
// TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer
}

await GetStreamIdWithoutStartWorks().WaitAsync(TimeSpan.FromSeconds(15));

GC.Collect();
}

[Fact]
public async Task DisposingConnection_OK()
{
async Task GetStreamIdWithoutStartWorks()
{
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;

using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);

// Dispose all connections before the streams;
clientConnection.Dispose();
serverConnection.Dispose();
listener.Dispose();
}

await GetStreamIdWithoutStartWorks();

GC.Collect();
Expand Down