diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs index 1be1364220a24..cf83b893ff739 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs @@ -20,26 +20,32 @@ public sealed class Http3LoopbackServer : GenericLoopbackServer public override Uri Address => new Uri($"https://{_listener.ListenEndPoint}/"); - public Http3LoopbackServer(QuicImplementationProvider quicImplementationProvider = null, GenericLoopbackOptions options = null) + public Http3LoopbackServer(QuicImplementationProvider quicImplementationProvider = null, Http3Options options = null) { - options ??= new GenericLoopbackOptions(); + options ??= new Http3Options(); _cert = Configuration.Certificates.GetServerCertificate(); - var sslOpts = new SslServerAuthenticationOptions + var listenerOptions = new QuicListenerOptions() { - EnabledSslProtocols = options.SslProtocols, - ApplicationProtocols = new List + ListenEndPoint = new IPEndPoint(options.Address, 0), + ServerAuthenticationOptions = new SslServerAuthenticationOptions { - new SslApplicationProtocol("h3-31"), - new SslApplicationProtocol("h3-30"), - new SslApplicationProtocol("h3-29") + EnabledSslProtocols = options.SslProtocols, + ApplicationProtocols = new List + { + new SslApplicationProtocol("h3-31"), + new SslApplicationProtocol("h3-30"), + new SslApplicationProtocol("h3-29") + }, + ServerCertificate = _cert, + ClientCertificateRequired = false }, - ServerCertificate = _cert, - ClientCertificateRequired = false + MaxUnidirectionalStreams = options.MaxUnidirectionalStreams, + MaxBidirectionalStreams = options.MaxBidirectionalStreams, }; - _listener = new QuicListener(quicImplementationProvider ?? QuicImplementationProviders.Default, new IPEndPoint(options.Address, 0), sslOpts); + _listener = new QuicListener(quicImplementationProvider ?? QuicImplementationProviders.Default, listenerOptions); } public override void Dispose() @@ -82,7 +88,7 @@ public Http3LoopbackServerFactory(QuicImplementationProvider quicImplementationP public override GenericLoopbackServer CreateServer(GenericLoopbackOptions options = null) { - return new Http3LoopbackServer(_quicImplementationProvider, options); + return new Http3LoopbackServer(_quicImplementationProvider, CreateOptions(options)); } public override async Task CreateServerAsync(Func funcAsync, int millisecondsTimeout = 60000, GenericLoopbackOptions options = null) @@ -97,5 +103,29 @@ public override Task CreateConnectionAsync(Socket soc // This method is always unacceptable to call for HTTP/3. throw new NotImplementedException("HTTP/3 does not operate over a Socket."); } + + private static Http3Options CreateOptions(GenericLoopbackOptions options) + { + Http3Options http3Options = new Http3Options(); + if (options != null) + { + http3Options.Address = options.Address; + http3Options.UseSsl = options.UseSsl; + http3Options.SslProtocols = options.SslProtocols; + http3Options.ListenBacklog = options.ListenBacklog; + } + return http3Options; + } + } + public class Http3Options : GenericLoopbackOptions + { + public int MaxUnidirectionalStreams {get; set; } + + public int MaxBidirectionalStreams {get; set; } + public Http3Options() + { + MaxUnidirectionalStreams = 100; + MaxBidirectionalStreams = 100; + } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs index ca6dd5df9fc83..51f65a875fe54 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs @@ -49,11 +49,6 @@ internal sealed class Http3Connection : HttpConnectionBase, IDisposable private int _haveServerQpackDecodeStream; private int _haveServerQpackEncodeStream; - // Manages MAX_STREAM count from server. - private long _maximumRequestStreams; - private long _requestStreamsRemaining; - private readonly Queue> _waitingRequests = new Queue>(); - // A connection-level error will abort any future operations. private Exception? _abortException; @@ -87,8 +82,6 @@ public Http3Connection(HttpConnectionPool pool, HttpAuthority? origin, HttpAutho string altUsedValue = altUsedDefaultPort ? authority.IdnHost : authority.IdnHost + ":" + authority.Port.ToString(Globalization.CultureInfo.InvariantCulture); _altUsedEncodedHeader = QPack.QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReferenceToArray(KnownHeaders.AltUsed.Name, altUsedValue); - _maximumRequestStreams = _requestStreamsRemaining = connection.GetRemoteAvailableBidirectionalStreamCount(); - // Errors are observed via Abort(). _ = SendSettingsAsync(); @@ -166,45 +159,34 @@ public override async Task SendAsync(HttpRequestMessage req { Debug.Assert(async); - // Wait for an available stream (based on QUIC MAX_STREAMS) if there isn't one available yet. - - TaskCompletionSourceWithCancellation? waitForAvailableStreamTcs = null; - - lock (SyncObj) - { - long remaining = _requestStreamsRemaining; - - if (remaining > 0) - { - _requestStreamsRemaining = remaining - 1; - } - else - { - waitForAvailableStreamTcs = new TaskCompletionSourceWithCancellation(); - _waitingRequests.Enqueue(waitForAvailableStreamTcs); - } - } - - if (waitForAvailableStreamTcs != null) - { - await waitForAvailableStreamTcs.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false); - } - // Allocate an active request - QuicStream? quicStream = null; Http3RequestStream? requestStream = null; + ValueTask waitTask = default; try { - lock (SyncObj) + while (true) { - if (_connection != null) + lock (SyncObj) { - quicStream = _connection.OpenBidirectionalStream(); - requestStream = new Http3RequestStream(request, this, quicStream); - _activeRequests.Add(quicStream, requestStream); + if (_connection == null) + { + break; + } + + if (_connection.GetRemoteAvailableBidirectionalStreamCount() > 0) + { + quicStream = _connection.OpenBidirectionalStream(); + requestStream = new Http3RequestStream(request, this, quicStream); + _activeRequests.Add(quicStream, requestStream); + break; + } + waitTask = _connection.WaitForAvailableBidirectionalStreamsAsync(cancellationToken); } + + // Wait for an available stream (based on QUIC MAX_STREAMS) if there isn't one available yet. + await waitTask.ConfigureAwait(false); } if (quicStream == null) @@ -212,8 +194,6 @@ public override async Task SendAsync(HttpRequestMessage req throw new HttpRequestException(SR.net_http_request_aborted, null, RequestRetryType.RetryOnConnectionFailure); } - // 0-byte write to force QUIC to allocate a stream ID. - await quicStream.WriteAsync(Array.Empty(), cancellationToken).ConfigureAwait(false); requestStream!.StreamId = quicStream.StreamId; bool goAway; @@ -246,76 +226,6 @@ public override async Task SendAsync(HttpRequestMessage req } } - /// - /// Waits for MAX_STREAMS to be raised by the server. - /// - private Task WaitForAvailableRequestStreamAsync(CancellationToken cancellationToken) - { - TaskCompletionSourceWithCancellation tcs; - - lock (SyncObj) - { - long remaining = _requestStreamsRemaining; - - if (remaining > 0) - { - _requestStreamsRemaining = remaining - 1; - return Task.CompletedTask; - } - - tcs = new TaskCompletionSourceWithCancellation(); - _waitingRequests.Enqueue(tcs); - } - - // Note: cancellation on connection shutdown is handled in CancelWaiters. - return tcs.WaitWithCancellationAsync(cancellationToken).AsTask(); - } - - /// - /// Cancels any waiting SendAsync calls. - /// - /// Requires to be held. - private void CancelWaiters() - { - Debug.Assert(Monitor.IsEntered(SyncObj)); - - while (_waitingRequests.TryDequeue(out TaskCompletionSourceWithCancellation? tcs)) - { - tcs.TrySetException(new HttpRequestException(SR.net_http_request_aborted, null, RequestRetryType.RetryOnConnectionFailure)); - } - } - - // TODO: how do we get this event? -> HandleEventStreamsAvailable reports currently available Uni/Bi streams - private void OnMaximumStreamCountIncrease(long newMaximumStreamCount) - { - lock (SyncObj) - { - if (newMaximumStreamCount <= _maximumRequestStreams) - { - return; - } - - IncreaseRemainingStreamCount(newMaximumStreamCount - _maximumRequestStreams); - _maximumRequestStreams = newMaximumStreamCount; - } - } - - private void IncreaseRemainingStreamCount(long delta) - { - Debug.Assert(Monitor.IsEntered(SyncObj)); - Debug.Assert(delta > 0); - - _requestStreamsRemaining += delta; - - while (_requestStreamsRemaining != 0 && _waitingRequests.TryDequeue(out TaskCompletionSourceWithCancellation? tcs)) - { - if (tcs.TrySetResult(true)) - { - --_requestStreamsRemaining; - } - } - } - /// /// Aborts the connection with an error. /// @@ -358,7 +268,6 @@ internal Exception Abort(Exception abortException) _connectionClosedTask = _connection.CloseAsync((long)connectionResetErrorCode).AsTask(); } - CancelWaiters(); CheckForShutdown(); } @@ -396,7 +305,6 @@ private void OnServerGoAway(long lastProcessedStreamId) } } - CancelWaiters(); CheckForShutdown(); } @@ -414,8 +322,6 @@ public void RemoveStream(QuicStream stream) bool removed = _activeRequests.Remove(stream); Debug.Assert(removed == true); - IncreaseRemainingStreamCount(1); - if (ShuttingDown) { CheckForShutdown(); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs index 0732797672f82..0a193684e4ce7 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs @@ -79,10 +79,12 @@ public async Task ClientSettingsReceived_Success(int headerSizeLimit) } [Theory] + [InlineData(10)] [InlineData(100)] + [InlineData(1000)] public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit) { - using Http3LoopbackServer server = CreateHttp3LoopbackServer(); + using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit }); Task serverTask = Task.Run(async () => { @@ -100,7 +102,7 @@ public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit) for (int i = 0; i < streamLimit + 1; ++i) { - using HttpRequestMessage request = new() + HttpRequestMessage request = new() { Method = HttpMethod.Get, RequestUri = server.Address, @@ -114,6 +116,162 @@ public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit) await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000); } + [Theory] + [InlineData(10)] + [InlineData(100)] + [InlineData(1000)] + public async Task SendStreamLimitRequestsConcurrently_Succeeds(int streamLimit) + { + using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit }); + + Task serverTask = Task.Run(async () => + { + using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + for (int i = 0; i < streamLimit; ++i) + { + using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await stream.HandleRequestAsync(); + } + }); + + Task clientTask = Task.Run(async () => + { + using HttpClient client = CreateHttpClient(); + + var tasks = new Task[streamLimit]; + Parallel.For(0, streamLimit, i => + { + HttpRequestMessage request = new() + { + Method = HttpMethod.Get, + RequestUri = server.Address, + Version = HttpVersion30, + VersionPolicy = HttpVersionPolicy.RequestVersionExact + }; + + tasks[i] = client.SendAsync(request); + }); + + var responses = await Task.WhenAll(tasks); + foreach (var response in responses) + { + response.Dispose(); + } + }); + + await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000); + } + + [Theory] + [InlineData(10)] + [InlineData(100)] + [InlineData(1000)] + public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int streamLimit) + { + // This combination leads to a hang manifesting in CI only. Disabling it until there's more time to investigate. + // [ActiveIssue("https://github.com/dotnet/runtime/issues/53688")] + if (streamLimit == 10 && this.UseQuicImplementationProvider == QuicImplementationProviders.Mock) + { + return; + } + + using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit }); + var lastRequestContentStarted = new TaskCompletionSource(); + + Task serverTask = Task.Run(async () => + { + // Read the first streamLimit requests, keep the streams open to make the last one wait. + using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + var streams = new Http3LoopbackStream[streamLimit]; + for (int i = 0; i < streamLimit; ++i) + { + Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + var body = await stream.ReadRequestDataAsync(); + streams[i] = stream; + } + + // Make the last request running independently. + var lastRequest = Task.Run(async () => { + using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await stream.HandleRequestAsync(); + }); + + // All the initial streamLimit streams are still opened so the last request cannot started yet. + Assert.False(lastRequestContentStarted.Task.IsCompleted); + + // Reply to the first streamLimit requests. + for (int i = 0; i < streamLimit; ++i) + { + await streams[i].SendResponseAsync(); + streams[i].Dispose(); + // After the first request is fully processed, the last request should unblock and get processed. + if (i == 0) + { + await lastRequestContentStarted.Task; + } + } + await lastRequest; + }); + + Task clientTask = Task.Run(async () => + { + using HttpClient client = CreateHttpClient(); + + // Fire out the first streamLimit requests in parallel, no waiting for the responses yet. + var countdown = new CountdownEvent(streamLimit); + var tasks = new Task[streamLimit]; + Parallel.For(0, streamLimit, i => + { + HttpRequestMessage request = new() + { + Method = HttpMethod.Post, + RequestUri = server.Address, + Version = HttpVersion30, + VersionPolicy = HttpVersionPolicy.RequestVersionExact, + Content = new StreamContent(new DelegateStream( + canReadFunc: () => true, + readFunc: (buffer, offset, count) => + { + countdown.Signal(); + return 0; + })) + }; + + tasks[i] = client.SendAsync(request); + }); + + // Wait for the first streamLimit request to get started. + countdown.Wait(); + + // Fire out the last request, that should wait until the server fully handles at least one request. + HttpRequestMessage last = new() + { + Method = HttpMethod.Post, + RequestUri = server.Address, + Version = HttpVersion30, + VersionPolicy = HttpVersionPolicy.RequestVersionExact, + Content = new StreamContent(new DelegateStream( + canReadFunc: () => true, + readFunc: (buffer, offset, count) => + { + lastRequestContentStarted.SetResult(); + return 0; + })) + }; + var lastTask = client.SendAsync(last); + + // Wait for all requests to finish. Whether the last request was pending is checked on the server side. + var responses = await Task.WhenAll(tasks); + foreach (var response in responses) + { + response.Dispose(); + } + await lastTask; + }); + + await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000); + } + [Fact] [ActiveIssue("https://github.com/dotnet/runtime/issues/53090")] public async Task ReservedFrameType_Throws() diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs index dc9d2b57efead..e68bb1768c4bf 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs @@ -39,9 +39,9 @@ protected static HttpClientHandler CreateHttpClientHandler(Version useVersion = return handler; } - protected Http3LoopbackServer CreateHttp3LoopbackServer() + protected Http3LoopbackServer CreateHttp3LoopbackServer(Http3Options options = default) { - return new Http3LoopbackServer(UseQuicImplementationProvider); + return new Http3LoopbackServer(UseQuicImplementationProvider, options); } protected HttpClientHandler CreateHttpClientHandler() => CreateHttpClientHandler(UseVersion, UseQuicImplementationProvider); @@ -84,7 +84,7 @@ protected static LoopbackServerFactory GetFactoryForVersion(Version useVersion, internal class VersionHttpClientHandler : HttpClientHandler { private readonly Version _useVersion; - + public VersionHttpClientHandler(Version useVersion) { _useVersion = useVersion; @@ -107,7 +107,7 @@ protected override Task SendAsync(HttpRequestMessage reques { request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; } - + return base.SendAsync(request, cancellationToken); } diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index 4f4abcf7e5236..be6df9a7c479b 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -27,10 +27,12 @@ public QuicConnection(System.Net.Quic.QuicClientConnectionOptions options) { } public System.Threading.Tasks.ValueTask CloseAsync(long errorCode, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public System.Threading.Tasks.ValueTask ConnectAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public void Dispose() { } - public long GetRemoteAvailableBidirectionalStreamCount() { throw null; } - public long GetRemoteAvailableUnidirectionalStreamCount() { throw null; } + public int GetRemoteAvailableBidirectionalStreamCount() { throw null; } + public int GetRemoteAvailableUnidirectionalStreamCount() { throw null; } public System.Net.Quic.QuicStream OpenBidirectionalStream() { throw null; } public System.Net.Quic.QuicStream OpenUnidirectionalStream() { throw null; } + public System.Threading.Tasks.ValueTask WaitForAvailableBidirectionalStreamsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Threading.Tasks.ValueTask WaitForAvailableUnidirectionalStreamsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } } public partial class QuicConnectionAbortedException : System.Net.Quic.QuicException { @@ -73,8 +75,8 @@ public partial class QuicOptions { public QuicOptions() { } public System.TimeSpan IdleTimeout { get { throw null; } set { } } - public long MaxBidirectionalStreams { get { throw null; } set { } } - public long MaxUnidirectionalStreams { get { throw null; } set { } } + public int MaxBidirectionalStreams { get { throw null; } set { } } + public int MaxUnidirectionalStreams { get { throw null; } set { } } } public sealed partial class QuicStream : System.IO.Stream { @@ -101,8 +103,8 @@ public override void Flush() { } public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; } public override void SetLength(long value) { } public void Shutdown() { } - public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override void Write(byte[] buffer, int offset, int count) { } public override void Write(System.ReadOnlySpan buffer) { } public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs index 9a24bc1436c2f..3a876d25bcb64 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Net; using System.Net.Security; +using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -20,11 +21,16 @@ internal sealed class MockConnection : QuicConnectionProvider private object _syncObject = new object(); private long _nextOutboundBidirectionalStream; private long _nextOutboundUnidirectionalStream; + private readonly int _maxUnidirectionalStreams; + private readonly int _maxBidirectionalStreams; private ConnectionState? _state; + internal PeerStreamLimit? LocalStreamLimit => _isClient ? _state?._clientStreamLimit : _state?._serverStreamLimit; + internal PeerStreamLimit? RemoteStreamLimit => _isClient ? _state?._serverStreamLimit : _state?._clientStreamLimit; + // Constructor for outbound connections - internal MockConnection(EndPoint? remoteEndPoint, SslClientAuthenticationOptions? sslClientAuthenticationOptions, IPEndPoint? localEndPoint = null) + internal MockConnection(EndPoint? remoteEndPoint, SslClientAuthenticationOptions? sslClientAuthenticationOptions, IPEndPoint? localEndPoint = null, int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100) { if (remoteEndPoint is null) { @@ -43,6 +49,8 @@ internal MockConnection(EndPoint? remoteEndPoint, SslClientAuthenticationOptions _sslClientAuthenticationOptions = sslClientAuthenticationOptions; _nextOutboundBidirectionalStream = 0; _nextOutboundUnidirectionalStream = 2; + _maxUnidirectionalStreams = maxUnidirectionalStreams; + _maxBidirectionalStreams = maxBidirectionalStreams; // _state is not initialized until ConnectAsync } @@ -129,7 +137,10 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d } // TODO: deal with protocol negotiation - _state = new ConnectionState(_sslClientAuthenticationOptions!.ApplicationProtocols![0]); + _state = new ConnectionState(_sslClientAuthenticationOptions!.ApplicationProtocols![0]) + { + _clientStreamLimit = new PeerStreamLimit(_maxUnidirectionalStreams, _maxBidirectionalStreams) + }; if (!listener.TryConnect(_state)) { throw new QuicException("Connection refused"); @@ -138,8 +149,41 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d return ValueTask.CompletedTask; } + internal override ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default) + { + PeerStreamLimit? streamLimit = RemoteStreamLimit; + if (streamLimit is null) + { + throw new InvalidOperationException("Not connected"); + } + + return streamLimit.Unidirectional.WaitForAvailableStreams(cancellationToken); + } + + internal override ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default) + { + PeerStreamLimit? streamLimit = RemoteStreamLimit; + if (streamLimit is null) + { + throw new InvalidOperationException("Not connected"); + } + + return streamLimit.Bidirectional.WaitForAvailableStreams(cancellationToken); + } + internal override QuicStreamProvider OpenUnidirectionalStream() { + PeerStreamLimit? streamLimit = RemoteStreamLimit; + if (streamLimit is null) + { + throw new InvalidOperationException("Not connected"); + } + + if (!streamLimit.Unidirectional.TryIncrement()) + { + throw new QuicException("No available unidirectional stream"); + } + long streamId; lock (_syncObject) { @@ -152,6 +196,17 @@ internal override QuicStreamProvider OpenUnidirectionalStream() internal override QuicStreamProvider OpenBidirectionalStream() { + PeerStreamLimit? streamLimit = RemoteStreamLimit; + if (streamLimit is null) + { + throw new InvalidOperationException("Not connected"); + } + + if (!streamLimit.Bidirectional.TryIncrement()) + { + throw new QuicException("No available bidirectional stream"); + } + long streamId; lock (_syncObject) { @@ -174,12 +229,30 @@ internal MockStream OpenStream(long streamId, bool bidirectional) Channel streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel; streamChannel.Writer.TryWrite(streamState); - return new MockStream(streamState, true); + return new MockStream(this, streamState, true); } - internal override long GetRemoteAvailableUnidirectionalStreamCount() => long.MaxValue; + internal override int GetRemoteAvailableUnidirectionalStreamCount() + { + PeerStreamLimit? streamLimit = RemoteStreamLimit; + if (streamLimit is null) + { + throw new InvalidOperationException("Not connected"); + } + + return streamLimit.Unidirectional.AvailableCount; + } + + internal override int GetRemoteAvailableBidirectionalStreamCount() + { + PeerStreamLimit? streamLimit = RemoteStreamLimit; + if (streamLimit is null) + { + throw new InvalidOperationException("Not connected"); + } - internal override long GetRemoteAvailableBidirectionalStreamCount() => long.MaxValue; + return streamLimit.Bidirectional.AvailableCount; + } internal override async ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default) { @@ -196,7 +269,7 @@ internal override async ValueTask AcceptStreamAsync(Cancella try { MockStream.StreamState streamState = await streamChannel.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); - return new MockStream(streamState, false); + return new MockStream(this, streamState, false); } catch (ChannelClosedException) { @@ -251,6 +324,14 @@ private void Dispose(bool disposing) Channel streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel; streamChannel.Writer.Complete(); } + + + PeerStreamLimit? streamLimit = LocalStreamLimit; + if (streamLimit is not null) + { + streamLimit.Unidirectional.CloseWaiters(); + streamLimit.Bidirectional.CloseWaiters(); + } } // TODO: free unmanaged resources (unmanaged objects) and override a finalizer below. @@ -271,11 +352,77 @@ public override void Dispose() GC.SuppressFinalize(this); } + internal sealed class StreamLimit + { + public readonly int MaxCount; + + private int _actualCount; + // Since this is mock, we don't need to be conservative with the allocations. + // We keep the TCSes allocated all the time for the simplicity of the code. + private TaskCompletionSource _availableTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly object _syncRoot = new object(); + + public StreamLimit(int maxCount) + { + MaxCount = maxCount; + } + + public int AvailableCount => MaxCount - _actualCount; + + public void Decrement() + { + lock (_syncRoot) + { + --_actualCount; + if (!_availableTcs.Task.IsCompleted) + { + _availableTcs.SetResult(); + _availableTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + } + + public bool TryIncrement() + { + lock (_syncRoot) + { + if (_actualCount < MaxCount) + { + ++_actualCount; + return true; + } + return false; + } + } + + public ValueTask WaitForAvailableStreams(CancellationToken cancellationToken) + => new ValueTask(_availableTcs.Task.WaitAsync(cancellationToken)); + + public void CloseWaiters() + => _availableTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException())); + } + + internal class PeerStreamLimit + { + public readonly StreamLimit Unidirectional; + public readonly StreamLimit Bidirectional; + + public PeerStreamLimit(int maxUnidirectional, int maxBidirectional) + { + Unidirectional = new StreamLimit(maxUnidirectional); + Bidirectional = new StreamLimit(maxBidirectional); + } + } + internal sealed class ConnectionState { public readonly SslApplicationProtocol _applicationProtocol; public Channel _clientInitiatedStreamChannel; public Channel _serverInitiatedStreamChannel; + + public PeerStreamLimit? _clientStreamLimit; + public PeerStreamLimit? _serverStreamLimit; + public long _clientErrorCode; public long _serverErrorCode; public bool _closed; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockImplementationProvider.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockImplementationProvider.cs index 03b53613f1514..a46b1691bed6e 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockImplementationProvider.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockImplementationProvider.cs @@ -16,7 +16,11 @@ internal override QuicListenerProvider CreateListener(QuicListenerOptions option internal override QuicConnectionProvider CreateConnection(QuicClientConnectionOptions options) { - return new MockConnection(options.RemoteEndPoint, options.ClientAuthenticationOptions, options.LocalEndPoint); + return new MockConnection(options.RemoteEndPoint, + options.ClientAuthenticationOptions, + options.LocalEndPoint, + options.MaxUnidirectionalStreams, + options.MaxBidirectionalStreams); } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockListener.cs index 826746ad69700..48ebf8a06035d 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockListener.cs @@ -69,6 +69,7 @@ internal override async ValueTask AcceptConnectionAsync( // Returns false if backlog queue is full. internal bool TryConnect(MockConnection.ConnectionState state) { + state._serverStreamLimit = new MockConnection.PeerStreamLimit(_options.MaxUnidirectionalStreams, _options.MaxBidirectionalStreams); return _listenQueue.Writer.TryWrite(state); } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs index 14ecead9a7f88..ec02d88afae91 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs @@ -14,11 +14,13 @@ internal sealed class MockStream : QuicStreamProvider { private bool _disposed; private readonly bool _isInitiator; + private readonly MockConnection _connection; private readonly StreamState _streamState; - internal MockStream(StreamState streamState, bool isInitiator) + internal MockStream(MockConnection connection, StreamState streamState, bool isInitiator) { + _connection = connection; _streamState = streamState; _isInitiator = isInitiator; } @@ -170,7 +172,6 @@ internal override void AbortWrite(long errorCode) WriteStreamBuffer?.EndWrite(); } - internal override ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) { CheckDisposed(); @@ -192,6 +193,15 @@ internal override void Shutdown() // This seems to mean shutdown send, in particular, not both. WriteStreamBuffer?.EndWrite(); + + if (_streamState._inboundStreamBuffer is null) // unidirectional stream + { + _connection.LocalStreamLimit!.Unidirectional.Decrement(); + } + else + { + _connection.LocalStreamLimit!.Bidirectional.Decrement(); + } } private void CheckDisposed() diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index 66c8add4be60e..76b3693fd1fa5 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -51,6 +51,12 @@ internal sealed class State public readonly TaskCompletionSource ConnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); public readonly TaskCompletionSource ShutdownTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + // Note that there's no such thing as resetable TCS, so we cannot reuse the same instance after we've set the result. + // We also cannot use solutions like ManualResetValueTaskSourceCore, since we can have multiple waiters on the same TCS. + // As a result, we allocate a new TCS when needed, which is when someone explicitely asks for them in WaitForAvailableStreamsAsync. + public TaskCompletionSource? NewUnidirectionalStreamsAvailable; + public TaskCompletionSource? NewBidirectionalStreamsAvailable; + public bool Connected; public long AbortErrorCode = -1; @@ -192,6 +198,26 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent // Stop accepting new streams. state.AcceptQueue.Writer.Complete(); + + // Stop notifying about available streams. + TaskCompletionSource? unidirectionalTcs = null; + TaskCompletionSource? bidirectionalTcs = null; + lock (state) + { + unidirectionalTcs = state.NewBidirectionalStreamsAvailable; + bidirectionalTcs = state.NewBidirectionalStreamsAvailable; + state.NewUnidirectionalStreamsAvailable = null; + state.NewBidirectionalStreamsAvailable = null; + } + + if (unidirectionalTcs is not null) + { + unidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException())); + } + if (bidirectionalTcs is not null) + { + bidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException())); + } return MsQuicStatusCodes.Success; } @@ -206,6 +232,32 @@ private static uint HandleEventNewStream(State state, ref ConnectionEvent connec private static uint HandleEventStreamsAvailable(State state, ref ConnectionEvent connectionEvent) { + TaskCompletionSource? unidirectionalTcs = null; + TaskCompletionSource? bidirectionalTcs = null; + lock (state) + { + if (connectionEvent.Data.StreamsAvailable.UniDirectionalCount > 0) + { + unidirectionalTcs = state.NewUnidirectionalStreamsAvailable; + state.NewUnidirectionalStreamsAvailable = null; + } + + if (connectionEvent.Data.StreamsAvailable.BiDirectionalCount > 0) + { + bidirectionalTcs = state.NewBidirectionalStreamsAvailable; + state.NewBidirectionalStreamsAvailable = null; + } + } + + if (unidirectionalTcs is not null) + { + unidirectionalTcs.SetResult(); + } + if (bidirectionalTcs is not null) + { + bidirectionalTcs.SetResult(); + } + return MsQuicStatusCodes.Success; } @@ -329,24 +381,82 @@ internal override async ValueTask AcceptStreamAsync(Cancella return stream; } + internal override ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default) + { + TaskCompletionSource? tcs = _state.NewUnidirectionalStreamsAvailable; + if (tcs is null) + { + lock (_state) + { + if (_state.NewUnidirectionalStreamsAvailable is null) + { + if (_state.ShutdownTcs.Task.IsCompleted) + { + throw new QuicOperationAbortedException(); + } + + if (GetRemoteAvailableUnidirectionalStreamCount() > 0) + { + return ValueTask.CompletedTask; + } + + _state.NewUnidirectionalStreamsAvailable = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + tcs = _state.NewUnidirectionalStreamsAvailable; + } + } + + return new ValueTask(tcs.Task.WaitAsync(cancellationToken)); + } + + internal override ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default) + { + TaskCompletionSource? tcs = _state.NewBidirectionalStreamsAvailable; + if (tcs is null) + { + lock (_state) + { + if (_state.NewBidirectionalStreamsAvailable is null) + { + if (_state.ShutdownTcs.Task.IsCompleted) + { + throw new QuicOperationAbortedException(); + } + + if (GetRemoteAvailableBidirectionalStreamCount() > 0) + { + return ValueTask.CompletedTask; + } + + _state.NewBidirectionalStreamsAvailable = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + tcs = _state.NewBidirectionalStreamsAvailable; + } + } + + return new ValueTask(tcs.Task.WaitAsync(cancellationToken)); + } + internal override QuicStreamProvider OpenUnidirectionalStream() { ThrowIfDisposed(); + return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL); } internal override QuicStreamProvider OpenBidirectionalStream() { ThrowIfDisposed(); + return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.NONE); } - internal override long GetRemoteAvailableUnidirectionalStreamCount() + internal override int GetRemoteAvailableUnidirectionalStreamCount() { return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_UNIDI_STREAM_COUNT); } - internal override long GetRemoteAvailableBidirectionalStreamCount() + internal override int GetRemoteAvailableBidirectionalStreamCount() { return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_BIDI_STREAM_COUNT); } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 6102235f7e1fe..c270d1242295b 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -64,7 +64,6 @@ private sealed class State // Set once writes have been shutdown. public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public ShutdownState ShutdownState; // Set once stream have been shutdown. @@ -124,7 +123,7 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F QuicExceptionHelpers.ThrowIfFailed(status, "Failed to open stream to peer."); - status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.ASYNC); + status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.FAIL_BLOCKED); QuicExceptionHelpers.ThrowIfFailed(status, "Could not start stream."); } catch @@ -490,6 +489,7 @@ internal override async ValueTask ShutdownCompleted(CancellationToken cancellati internal override void Shutdown() { ThrowIfDisposed(); + // it is ok to send shutdown several times, MsQuic will ignore it StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); } @@ -590,7 +590,7 @@ private static uint HandleEvent(State state, ref StreamEvent evt) // Stream has started. // Will only be done for outbound streams (inbound streams have already started) case QUIC_STREAM_EVENT_TYPE.START_COMPLETE: - return HandleStartComplete(state); + return HandleEventStartComplete(state); // Received data on the stream case QUIC_STREAM_EVENT_TYPE.RECEIVE: return HandleEventRecv(state, ref evt); @@ -676,7 +676,7 @@ private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt) return MsQuicStatusCodes.Success; } - private static uint HandleStartComplete(State state) + private static uint HandleEventStartComplete(State state) { bool shouldComplete = false; lock (state) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicConnectionProvider.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicConnectionProvider.cs index 9425833413589..e5153c3fae4b2 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicConnectionProvider.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicConnectionProvider.cs @@ -16,13 +16,17 @@ internal abstract class QuicConnectionProvider : IDisposable internal abstract ValueTask ConnectAsync(CancellationToken cancellationToken = default); + internal abstract ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default); + + internal abstract ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default); + internal abstract QuicStreamProvider OpenUnidirectionalStream(); internal abstract QuicStreamProvider OpenBidirectionalStream(); - internal abstract long GetRemoteAvailableUnidirectionalStreamCount(); + internal abstract int GetRemoteAvailableUnidirectionalStreamCount(); - internal abstract long GetRemoteAvailableBidirectionalStreamCount(); + internal abstract int GetRemoteAvailableBidirectionalStreamCount(); internal abstract ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index fd21c4116f7b6..e91913f9d7bac 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -67,6 +67,18 @@ internal QuicConnection(QuicConnectionProvider provider) /// public ValueTask ConnectAsync(CancellationToken cancellationToken = default) => _provider.ConnectAsync(cancellationToken); + /// + /// Waits for available unidirectional stream capacity to be announced by the peer. If any capacity is available, returns immediately. + /// + /// + public ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default) => _provider.WaitForAvailableUnidirectionalStreamsAsync(cancellationToken); + + /// + /// Waits for available bidirectional stream capacity to be announced by the peer. If any capacity is available, returns immediately. + /// + /// + public ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default) => _provider.WaitForAvailableBidirectionalStreamsAsync(cancellationToken); + /// /// Create an outbound unidirectional stream. /// @@ -95,11 +107,11 @@ internal QuicConnection(QuicConnectionProvider provider) /// /// Gets the maximum number of bidirectional streams that can be made to the peer. /// - public long GetRemoteAvailableUnidirectionalStreamCount() => _provider.GetRemoteAvailableUnidirectionalStreamCount(); + public int GetRemoteAvailableUnidirectionalStreamCount() => _provider.GetRemoteAvailableUnidirectionalStreamCount(); /// /// Gets the maximum number of unidirectional streams that can be made to the peer. /// - public long GetRemoteAvailableBidirectionalStreamCount() => _provider.GetRemoteAvailableBidirectionalStreamCount(); + public int GetRemoteAvailableBidirectionalStreamCount() => _provider.GetRemoteAvailableBidirectionalStreamCount(); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicOptions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicOptions.cs index 86dd644aaac1c..3d02ee3fd0199 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicOptions.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicOptions.cs @@ -19,14 +19,14 @@ public class QuicOptions /// Default is 100. /// // TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using. - public long MaxBidirectionalStreams { get; set; } = 100; + public int MaxBidirectionalStreams { get; set; } = 100; /// /// Limit on the number of unidirectional streams the remote peer connection can create on an open connection. /// Default is 100. /// // TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using. - public long MaxUnidirectionalStreams { get; set; } = 100; + public int MaxUnidirectionalStreams { get; set; } = 100; /// /// Idle timeout for connections, after which the connection will be closed. diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 0e2a227b0ef1d..443f759d0fbe8 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -95,6 +95,56 @@ public async Task ConnectWithCertificateChain() await clientTask; } + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/52048")] + public async Task WaitForAvailableUnidirectionStreamsAsyncWorks() + { + using QuicListener listener = CreateQuicListener(maxUnidirectionalStreams: 1); + using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); + + ValueTask clientTask = clientConnection.ConnectAsync(); + using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); + await clientTask; + + // No stream openned yet, should return immediately. + Assert.True(clientConnection.WaitForAvailableUnidirectionalStreamsAsync().IsCompletedSuccessfully); + + // Open one stream, should wait till it closes. + QuicStream stream = clientConnection.OpenUnidirectionalStream(); + ValueTask waitTask = clientConnection.WaitForAvailableUnidirectionalStreamsAsync(); + Assert.False(waitTask.IsCompleted); + Assert.Throws(() => clientConnection.OpenUnidirectionalStream()); + + // Close the stream, the waitTask should finish as a result. + stream.Dispose(); + await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/52048")] + public async Task WaitForAvailableBidirectionStreamsAsyncWorks() + { + using QuicListener listener = CreateQuicListener(maxBidirectionalStreams: 1); + using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); + + ValueTask clientTask = clientConnection.ConnectAsync(); + using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); + await clientTask; + + // No stream openned yet, should return immediately. + Assert.True(clientConnection.WaitForAvailableBidirectionalStreamsAsync().IsCompletedSuccessfully); + + // Open one stream, should wait till it closes. + QuicStream stream = clientConnection.OpenBidirectionalStream(); + ValueTask waitTask = clientConnection.WaitForAvailableBidirectionalStreamsAsync(); + Assert.False(waitTask.IsCompleted); + Assert.Throws(() => clientConnection.OpenBidirectionalStream()); + + // Close the stream, the waitTask should finish as a result. + stream.Dispose(); + await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + } + [Fact] [OuterLoop("May take several seconds")] public async Task SetListenerTimeoutWorksWithSmallTimeout() @@ -234,7 +284,7 @@ public async Task CallDifferentWriteMethodsWorks() int res = await serverStream.ReadAsync(memory); Assert.Equal(12, res); ReadOnlyMemory> romrom = new ReadOnlyMemory>(new ReadOnlyMemory[] { helloWorld, helloWorld }); - + await clientStream.WriteAsync(romrom); res = await serverStream.ReadAsync(memory); @@ -254,7 +304,7 @@ await RunClientServer( { var acceptTask = serverConnection.AcceptStreamAsync(); await serverConnection.CloseAsync(errorCode: 0); - // make sure + // make sure await Assert.ThrowsAsync(() => acceptTask.AsTask()); }); } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs index ad7b74c12e887..6c3670bbc30bb 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs @@ -92,7 +92,7 @@ public SslServerAuthenticationOptions GetSslServerAuthenticationOptions() ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate() }; } - + protected abstract QuicImplementationProvider Provider { get; } protected override async Task CreateConnectedStreamsAsync() diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index 027d0adb258ad..ee7501868beba 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -53,16 +53,30 @@ internal QuicConnection CreateQuicConnection(IPEndPoint endpoint) return new QuicConnection(ImplementationProvider, endpoint, GetSslClientAuthenticationOptions()); } - internal QuicListener CreateQuicListener() + internal QuicListener CreateQuicListener(int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100) { - return CreateQuicListener(new IPEndPoint(IPAddress.Loopback, 0)); + var options = new QuicListenerOptions() + { + ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0), + ServerAuthenticationOptions = GetSslServerAuthenticationOptions(), + MaxUnidirectionalStreams = maxUnidirectionalStreams, + MaxBidirectionalStreams = maxBidirectionalStreams + }; + return CreateQuicListener(options); } internal QuicListener CreateQuicListener(IPEndPoint endpoint) { - return new QuicListener(ImplementationProvider, endpoint, GetSslServerAuthenticationOptions()); + var options = new QuicListenerOptions() + { + ListenEndPoint = endpoint, + ServerAuthenticationOptions = GetSslServerAuthenticationOptions() + }; + return CreateQuicListener(options); } + private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options); + internal async Task RunClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) { using QuicListener listener = CreateQuicListener();