Skip to content

Commit

Permalink
Merge pull request #1426 from rabbitmq/rabbitmq-dotnet-client-1425
Browse files Browse the repository at this point in the history
Ensure that the underlying timer for `Task.Delay` is canceled.
  • Loading branch information
lukebakken authored Nov 22, 2023
2 parents 10a3499 + ef1c97a commit 5b2cd95
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 72 deletions.
114 changes: 83 additions & 31 deletions projects/RabbitMQ.Client/client/TaskExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ internal static class TaskExtensions
{
#if !NET6_0_OR_GREATER
private static readonly TaskContinuationOptions s_tco = TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously;
private static void continuation(Task t, object s) => t.Exception.Handle(e => true);
private static void IgnoreTaskContinuation(Task t, object s) => t.Exception.Handle(e => true);
#endif

public static Task TimeoutAfter(this Task task, TimeSpan timeout)
Expand All @@ -59,60 +59,112 @@ public static Task TimeoutAfter(this Task task, TimeSpan timeout)

return DoTimeoutAfter(task, timeout);

// https://github.com/davidfowl/AspNetCoreDiagnosticScenarios/blob/master/AsyncGuidance.md#using-a-timeout
static async Task DoTimeoutAfter(Task task, TimeSpan timeout)
{
if (task == await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false))
using (var cts = new CancellationTokenSource())
{
Task delayTask = Task.Delay(timeout, cts.Token);
Task resultTask = await Task.WhenAny(task, delayTask).ConfigureAwait(false);
if (resultTask == delayTask)
{
task.Ignore();
throw new TimeoutException();
}
else
{
cts.Cancel();
}

await task.ConfigureAwait(false);
}
else
{
Task supressErrorTask = task.ContinueWith(
continuationAction: continuation,
state: null,
cancellationToken: CancellationToken.None,
continuationOptions: s_tco,
scheduler: TaskScheduler.Default);
throw new TimeoutException();
}
}
#endif
}

public static async ValueTask TimeoutAfter(this ValueTask task, TimeSpan timeout)
public static async ValueTask TimeoutAfter(this ValueTask valueTask, TimeSpan timeout)
{
if (task.IsCompletedSuccessfully)
if (valueTask.IsCompletedSuccessfully)
{
return;
}

#if NET6_0_OR_GREATER
Task actualTask = task.AsTask();
await actualTask.WaitAsync(timeout)
Task task = valueTask.AsTask();
await task.WaitAsync(timeout)
.ConfigureAwait(false);
#else
await DoTimeoutAfter(task, timeout)
await DoTimeoutAfter(valueTask, timeout)
.ConfigureAwait(false);

async static ValueTask DoTimeoutAfter(ValueTask task, TimeSpan timeout)
// https://github.com/davidfowl/AspNetCoreDiagnosticScenarios/blob/master/AsyncGuidance.md#using-a-timeout
static async ValueTask DoTimeoutAfter(ValueTask valueTask, TimeSpan timeout)
{
Task actualTask = task.AsTask();
if (actualTask == await Task.WhenAny(actualTask, Task.Delay(timeout)).ConfigureAwait(false))
Task task = valueTask.AsTask();
using (var cts = new CancellationTokenSource())
{
await actualTask.ConfigureAwait(false);
}
else
{
Task supressErrorTask = actualTask.ContinueWith(
continuationAction: continuation,
state: null,
cancellationToken: CancellationToken.None,
continuationOptions: s_tco,
scheduler: TaskScheduler.Default);
throw new TimeoutException();
Task delayTask = Task.Delay(timeout, cts.Token);
Task resultTask = await Task.WhenAny(task, delayTask).ConfigureAwait(false);
if (resultTask == delayTask)
{
task.Ignore();
throw new TimeoutException();
}
else
{
cts.Cancel();
}

await valueTask.ConfigureAwait(false);
}
}
#endif
}

/*
* https://devblogs.microsoft.com/dotnet/configureawait-faq/
* I'm using GetAwaiter().GetResult(). Do I need to use ConfigureAwait(false)?
* Answer: No
*/
public static void EnsureCompleted(this Task task)
{
task.GetAwaiter().GetResult();
}

public static T EnsureCompleted<T>(this Task<T> task)
{
return task.GetAwaiter().GetResult();
}

public static T EnsureCompleted<T>(this ValueTask<T> task)
{
return task.GetAwaiter().GetResult();
}

public static void EnsureCompleted(this ValueTask task)
{
task.GetAwaiter().GetResult();
}

#if !NET6_0_OR_GREATER
// https://github.com/dotnet/runtime/issues/23878
// https://github.com/dotnet/runtime/issues/23878#issuecomment-1398958645
public static void Ignore(this Task task)
{
if (task.IsCompleted)
{
_ = task.Exception;
}
else
{
_ = task.ContinueWith(
continuationAction: IgnoreTaskContinuation,
state: null,
cancellationToken: CancellationToken.None,
continuationOptions: s_tco,
scheduler: TaskScheduler.Default);
}
}
#endif
}
}
2 changes: 2 additions & 0 deletions projects/RabbitMQ.Client/client/api/ITcpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ public interface ITcpClient : IDisposable

Socket Client { get; }

// TODO CancellationToken
Task ConnectAsync(string host, int port);
// TODO CancellationToken
Task ConnectAsync(IPAddress host, int port);

NetworkStream GetStream();
Expand Down
5 changes: 3 additions & 2 deletions projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ await ConnectAsync(ep, port)
.ConfigureAwait(false);
}

public virtual Task ConnectAsync(IPAddress ep, int port)
public virtual async Task ConnectAsync(IPAddress ep, int port)
{
AssertSocket();
return _sock.ConnectAsync(ep, port);
await _sock.ConnectAsync(ep, port)
.ConfigureAwait(false);
}

public virtual void Close()
Expand Down
8 changes: 2 additions & 6 deletions projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using RabbitMQ.Client.client.framing;
Expand All @@ -44,7 +43,6 @@ namespace RabbitMQ.Client.Impl
internal abstract class AsyncRpcContinuation<T> : IRpcContinuation, IDisposable
{
private readonly CancellationTokenSource _cancellationTokenSource;
private readonly ConfiguredTaskAwaitable<T> _taskAwaitable;

protected readonly TaskCompletionSource<T> _tcs = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);

Expand All @@ -68,13 +66,11 @@ public AsyncRpcContinuation(TimeSpan continuationTimeout)
// in the same manner as BlockingCell?
}
}, useSynchronizationContext: false);

_taskAwaitable = _tcs.Task.ConfigureAwait(false);
}

public ConfiguredTaskAwaitable<T>.ConfiguredTaskAwaiter GetAwaiter()
public Task<T> WaitAsync()
{
return _taskAwaitable.GetAwaiter();
return _tcs.Task;
}

public abstract void HandleCommand(in IncomingCommand cmd);
Expand Down
Loading

0 comments on commit 5b2cd95

Please sign in to comment.