From ef1c97a957f488559b7167db4d020a0f65edede6 Mon Sep 17 00:00:00 2001 From: Luke Bakken Date: Mon, 20 Nov 2023 15:56:21 -0800 Subject: [PATCH] Ensure that the underlying timer for `Task.Delay` is canceled. Part of the fix to #1425 * Add EnsureCompleted task extensions. * Don't use compiler shenanigans for AsyncRpcContinuation * Add Ignore() extension to ignore the results of a Task --- .../RabbitMQ.Client/client/TaskExtensions.cs | 114 +++++++++++++----- .../RabbitMQ.Client/client/api/ITcpClient.cs | 2 + .../client/api/TcpClientAdapter.cs | 5 +- .../client/impl/AsyncRpcContinuations.cs | 8 +- .../client/impl/ChannelBase.cs | 64 ++++++---- .../RabbitMQ.Client/client/impl/Connection.cs | 5 +- .../client/impl/SocketFrameHandler.cs | 9 +- .../RabbitMQ.Client/client/impl/SslHelper.cs | 4 +- .../TestConnectionRecoveryBase.cs | 3 +- 9 files changed, 142 insertions(+), 72 deletions(-) diff --git a/projects/RabbitMQ.Client/client/TaskExtensions.cs b/projects/RabbitMQ.Client/client/TaskExtensions.cs index cd1daf9e8e..b4e06dd3ea 100644 --- a/projects/RabbitMQ.Client/client/TaskExtensions.cs +++ b/projects/RabbitMQ.Client/client/TaskExtensions.cs @@ -39,7 +39,7 @@ internal static class TaskExtensions { #if !NET6_0_OR_GREATER private static readonly TaskContinuationOptions s_tco = TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously; - private static void continuation(Task t, object s) => t.Exception.Handle(e => true); + private static void IgnoreTaskContinuation(Task t, object s) => t.Exception.Handle(e => true); #endif public static Task TimeoutAfter(this Task task, TimeSpan timeout) @@ -59,60 +59,112 @@ public static Task TimeoutAfter(this Task task, TimeSpan timeout) return DoTimeoutAfter(task, timeout); + // https://github.com/davidfowl/AspNetCoreDiagnosticScenarios/blob/master/AsyncGuidance.md#using-a-timeout static async Task DoTimeoutAfter(Task task, TimeSpan timeout) { - if (task == await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false)) + using (var cts = new CancellationTokenSource()) { + Task delayTask = Task.Delay(timeout, cts.Token); + Task resultTask = await Task.WhenAny(task, delayTask).ConfigureAwait(false); + if (resultTask == delayTask) + { + task.Ignore(); + throw new TimeoutException(); + } + else + { + cts.Cancel(); + } + await task.ConfigureAwait(false); } - else - { - Task supressErrorTask = task.ContinueWith( - continuationAction: continuation, - state: null, - cancellationToken: CancellationToken.None, - continuationOptions: s_tco, - scheduler: TaskScheduler.Default); - throw new TimeoutException(); - } } #endif } - public static async ValueTask TimeoutAfter(this ValueTask task, TimeSpan timeout) + public static async ValueTask TimeoutAfter(this ValueTask valueTask, TimeSpan timeout) { - if (task.IsCompletedSuccessfully) + if (valueTask.IsCompletedSuccessfully) { return; } #if NET6_0_OR_GREATER - Task actualTask = task.AsTask(); - await actualTask.WaitAsync(timeout) + Task task = valueTask.AsTask(); + await task.WaitAsync(timeout) .ConfigureAwait(false); #else - await DoTimeoutAfter(task, timeout) + await DoTimeoutAfter(valueTask, timeout) .ConfigureAwait(false); - async static ValueTask DoTimeoutAfter(ValueTask task, TimeSpan timeout) + // https://github.com/davidfowl/AspNetCoreDiagnosticScenarios/blob/master/AsyncGuidance.md#using-a-timeout + static async ValueTask DoTimeoutAfter(ValueTask valueTask, TimeSpan timeout) { - Task actualTask = task.AsTask(); - if (actualTask == await Task.WhenAny(actualTask, Task.Delay(timeout)).ConfigureAwait(false)) + Task task = valueTask.AsTask(); + using (var cts = new CancellationTokenSource()) { - await actualTask.ConfigureAwait(false); - } - else - { - Task supressErrorTask = actualTask.ContinueWith( - continuationAction: continuation, - state: null, - cancellationToken: CancellationToken.None, - continuationOptions: s_tco, - scheduler: TaskScheduler.Default); - throw new TimeoutException(); + Task delayTask = Task.Delay(timeout, cts.Token); + Task resultTask = await Task.WhenAny(task, delayTask).ConfigureAwait(false); + if (resultTask == delayTask) + { + task.Ignore(); + throw new TimeoutException(); + } + else + { + cts.Cancel(); + } + + await valueTask.ConfigureAwait(false); } } #endif } + + /* + * https://devblogs.microsoft.com/dotnet/configureawait-faq/ + * I'm using GetAwaiter().GetResult(). Do I need to use ConfigureAwait(false)? + * Answer: No + */ + public static void EnsureCompleted(this Task task) + { + task.GetAwaiter().GetResult(); + } + + public static T EnsureCompleted(this Task task) + { + return task.GetAwaiter().GetResult(); + } + + public static T EnsureCompleted(this ValueTask task) + { + return task.GetAwaiter().GetResult(); + } + + public static void EnsureCompleted(this ValueTask task) + { + task.GetAwaiter().GetResult(); + } + +#if !NET6_0_OR_GREATER + // https://github.com/dotnet/runtime/issues/23878 + // https://github.com/dotnet/runtime/issues/23878#issuecomment-1398958645 + public static void Ignore(this Task task) + { + if (task.IsCompleted) + { + _ = task.Exception; + } + else + { + _ = task.ContinueWith( + continuationAction: IgnoreTaskContinuation, + state: null, + cancellationToken: CancellationToken.None, + continuationOptions: s_tco, + scheduler: TaskScheduler.Default); + } + } +#endif } } diff --git a/projects/RabbitMQ.Client/client/api/ITcpClient.cs b/projects/RabbitMQ.Client/client/api/ITcpClient.cs index 6b55246297..cc26560f18 100644 --- a/projects/RabbitMQ.Client/client/api/ITcpClient.cs +++ b/projects/RabbitMQ.Client/client/api/ITcpClient.cs @@ -18,7 +18,9 @@ public interface ITcpClient : IDisposable Socket Client { get; } + // TODO CancellationToken Task ConnectAsync(string host, int port); + // TODO CancellationToken Task ConnectAsync(IPAddress host, int port); NetworkStream GetStream(); diff --git a/projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs b/projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs index 0d0b3797c8..3de1b792d9 100644 --- a/projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs +++ b/projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs @@ -34,10 +34,11 @@ await ConnectAsync(ep, port) .ConfigureAwait(false); } - public virtual Task ConnectAsync(IPAddress ep, int port) + public virtual async Task ConnectAsync(IPAddress ep, int port) { AssertSocket(); - return _sock.ConnectAsync(ep, port); + await _sock.ConnectAsync(ep, port) + .ConfigureAwait(false); } public virtual void Close() diff --git a/projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs b/projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs index 04e25b2e18..0e9b3b949e 100644 --- a/projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs +++ b/projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs @@ -31,7 +31,6 @@ using System; using System.Diagnostics; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client.client.framing; @@ -44,7 +43,6 @@ namespace RabbitMQ.Client.Impl internal abstract class AsyncRpcContinuation : IRpcContinuation, IDisposable { private readonly CancellationTokenSource _cancellationTokenSource; - private readonly ConfiguredTaskAwaitable _taskAwaitable; protected readonly TaskCompletionSource _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -68,13 +66,11 @@ public AsyncRpcContinuation(TimeSpan continuationTimeout) // in the same manner as BlockingCell? } }, useSynchronizationContext: false); - - _taskAwaitable = _tcs.Task.ConfigureAwait(false); } - public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter GetAwaiter() + public Task WaitAsync() { - return _taskAwaitable.GetAwaiter(); + return _tcs.Task; } public abstract void HandleCommand(in IncomingCommand cmd); diff --git a/projects/RabbitMQ.Client/client/impl/ChannelBase.cs b/projects/RabbitMQ.Client/client/impl/ChannelBase.cs index 624fee7b7c..2f72e61266 100644 --- a/projects/RabbitMQ.Client/client/impl/ChannelBase.cs +++ b/projects/RabbitMQ.Client/client/impl/ChannelBase.cs @@ -253,7 +253,8 @@ await ModelSendAsync(method) .ConfigureAwait(false); } - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); await ConsumerDispatcher.WaitForShutdownAsync() @@ -287,6 +288,7 @@ await ConsumerDispatcher.WaitForShutdownAsync() } } + // TODO cancellation tokens internal async ValueTask ConnectionOpenAsync(string virtualHost) { var m = new ConnectionOpen(virtualHost); @@ -317,7 +319,8 @@ await ModelSendAsync(method) // negotiation finishes } - return await k; + return await k.WaitAsync() + .ConfigureAwait(false); } finally { @@ -348,7 +351,8 @@ await ModelSendAsync(method) // negotiation finishes } - return await k; + return await k.WaitAsync() + .ConfigureAwait(false); } finally { @@ -383,7 +387,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return this; } @@ -1088,7 +1093,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1162,7 +1168,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - return await k; + return await k.WaitAsync() + .ConfigureAwait(false); } finally { @@ -1202,7 +1209,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - return await k; + return await k.WaitAsync() + .ConfigureAwait(false); } finally { @@ -1304,7 +1312,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1348,7 +1357,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; @@ -1377,7 +1387,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1410,7 +1421,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1448,7 +1460,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1481,7 +1494,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1524,7 +1538,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - QueueDeclareOk result = await k; + QueueDeclareOk result = await k.WaitAsync() + .ConfigureAwait(false); if (false == passive) { CurrentQueue = result.QueueName; @@ -1550,7 +1565,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1600,7 +1616,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - return await k; + return await k.WaitAsync() + .ConfigureAwait(false); } finally { @@ -1631,7 +1648,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - return await k; + return await k.WaitAsync() + .ConfigureAwait(false); } finally { @@ -1654,7 +1672,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1679,7 +1698,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1704,7 +1724,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } @@ -1729,7 +1750,8 @@ await _rpcSemaphore.WaitAsync() await ModelSendAsync(method) .ConfigureAwait(false); - bool result = await k; + bool result = await k.WaitAsync() + .ConfigureAwait(false); Debug.Assert(result); return; } diff --git a/projects/RabbitMQ.Client/client/impl/Connection.cs b/projects/RabbitMQ.Client/client/impl/Connection.cs index cdb379bc79..1237029a3b 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.cs @@ -213,8 +213,7 @@ internal void TakeOver(Connection other) internal IConnection Open() { - return OpenAsync() - .ConfigureAwait(false).GetAwaiter().GetResult(); + return OpenAsync().EnsureCompleted(); } internal async ValueTask OpenAsync() @@ -526,7 +525,7 @@ internal void Write(RentedMemory frames) ValueTask task = _frameHandler.WriteAsync(frames); if (!task.IsCompletedSuccessfully) { - task.ConfigureAwait(false).GetAwaiter().GetResult(); + task.EnsureCompleted(); } } diff --git a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs index 0a2870e42d..b15287e001 100644 --- a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs +++ b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs @@ -183,7 +183,7 @@ public TimeSpan WriteTimeout public void Close() { - CloseAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + CloseAsync().EnsureCompleted(); } public async ValueTask CloseAsync() @@ -334,12 +334,9 @@ private void ConnectOrFail(ITcpClient socket, IPEndPoint endpoint, TimeSpan time { try { + // this ensures exceptions aren't wrapped in an AggregateException socket.ConnectAsync(endpoint.Address, endpoint.Port) - .TimeoutAfter(timeout) - .ConfigureAwait(false) - // this ensures exceptions aren't wrapped in an AggregateException - .GetAwaiter() - .GetResult(); + .TimeoutAfter(timeout).EnsureCompleted(); } catch (ArgumentException e) { diff --git a/projects/RabbitMQ.Client/client/impl/SslHelper.cs b/projects/RabbitMQ.Client/client/impl/SslHelper.cs index 8f4b2825da..9052e382f8 100644 --- a/projects/RabbitMQ.Client/client/impl/SslHelper.cs +++ b/projects/RabbitMQ.Client/client/impl/SslHelper.cs @@ -53,6 +53,7 @@ private SslHelper(SslOption sslOption) /// /// Upgrade a Tcp stream to an Ssl stream using the TLS options provided. /// + // TODO async public static Stream TcpUpgrade(Stream tcpStream, SslOption options) { var helper = new SslHelper(options); @@ -67,8 +68,9 @@ public static Stream TcpUpgrade(Stream tcpStream, SslOption options) Action TryAuthenticating = (SslOption opts) => { sslStream.AuthenticateAsClientAsync(opts.ServerName, opts.Certs, opts.Version, - opts.CheckCertificateRevocation).GetAwaiter().GetResult(); + opts.CheckCertificateRevocation).EnsureCompleted(); }; + try { // TODO async diff --git a/projects/Test/SequentialIntegration/TestConnectionRecoveryBase.cs b/projects/Test/SequentialIntegration/TestConnectionRecoveryBase.cs index e4617bec97..117a1717db 100644 --- a/projects/Test/SequentialIntegration/TestConnectionRecoveryBase.cs +++ b/projects/Test/SequentialIntegration/TestConnectionRecoveryBase.cs @@ -255,8 +255,7 @@ internal void RestartServerAndWaitForRecovery(AutorecoveringConnection conn) protected bool WaitForConfirms(IChannel m) { using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(4)); - return m.WaitForConfirmsAsync(cts.Token) - .ConfigureAwait(false).GetAwaiter().GetResult(); + return m.WaitForConfirmsAsync(cts.Token).EnsureCompleted(); } protected void WaitForRecovery()