Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Unblock subsequent requests when an async over sync Dns call runs for too long #92863

Closed
wants to merge 9 commits into from
147 changes: 112 additions & 35 deletions src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs
Original file line number Diff line number Diff line change
Expand Up @@ -633,72 +633,149 @@ private static bool LogFailure(object hostNameOrAddress, long? startingTimestamp
return false;
}

/// <summary>Mapping from key to current task in flight for that key.</summary>
private static readonly Dictionary<object, Task> s_tasks = new Dictionary<object, Task>();
/// <summary>Mapping from key to the head of the request queue for that key.</summary>
private static readonly Dictionary<object, DnsRequestWaiter> s_requestQueues = new();

/// <summary>
/// The maximum time a request can block subsequent requests in the queue.
/// </summary>
private static readonly TimeSpan s_maximumWaitTime = TimeSpan.FromSeconds(1);

/// <summary>Queue the function to be invoked asynchronously.</summary>
/// <remarks>
/// Since this is doing synchronous work on a thread pool thread, we want to limit how many threads end up being
/// blocked. We could employ a semaphore to limit overall usage, but a common case is that DNS requests are made
/// for only a handful of endpoints, and a reasonable compromise is to ensure that requests for a given host are
/// serialized. Once the data for that host is cached locally by the OS, the subsequent requests should all complete
/// very quickly, and if the head-of-line request is taking a long time due to the connection to the server, we won't
/// block lots of threads all getting data for that one host. We also still want to issue the request to the OS, rather
/// serialized. We also still want to issue the request to the OS, rather
/// than having all concurrent requests for the same host share the exact same task, so that any shuffling of the results
/// by the OS to enable round robin is still perceived.
/// </remarks>
private static Task<TResult> RunAsync<TResult>(Func<object, long, TResult> func, object key, CancellationToken cancellationToken)
{
long startingTimestamp = NameResolutionTelemetry.Log.BeforeResolution(key);
long startTimestamp = Stopwatch.GetTimestamp();
NameResolutionTelemetry.Log.BeforeResolution(key);

Task<TResult>? task = null;
DnsRequestWaiter current;
lock (s_requestQueues)
{
// Get the queue head for this key, if there are requests in flight.
if (s_requestQueues.TryGetValue(key, out DnsRequestWaiter? head))
{
DnsRequestWaiter? last = null;
DnsRequestWaiter? next = head;

lock (s_tasks)
while (next != null)
{
// Remove long-running requests from the queue and forward the head.
if (next.Elapsed(startTimestamp) > s_maximumWaitTime)
{
next.Complete();
}
last = next;
next = next.Next;
}
Debug.Assert(last is not null);
current = new DnsRequestWaiter(key, startTimestamp, last);

// If Complete() has cleared the head, make 'current' the new head.
if (!s_requestQueues.ContainsKey(key))
{
s_requestQueues[key] = current;
}
}
else
{
current = new DnsRequestWaiter(key, startTimestamp, null);
s_requestQueues[key] = current;
}
}

return current.Run(func, cancellationToken);
}

private sealed class DnsRequestWaiter : TaskCompletionSource
{
private long _startTimestamp;
private Task _previousTask;
public DnsRequestWaiter? Next;
private object _key;
private CancellationToken _cancellationToken;
private object? _func;

public DnsRequestWaiter(object key, long start, DnsRequestWaiter? previous)
{
// Get the previous task for this key, if there is one.
s_tasks.TryGetValue(key, out Task? prevTask);
prevTask ??= Task.CompletedTask;
_key = key;
_startTimestamp = start;
if (previous != null)
{
_previousTask = previous.Task;
previous.Next = this;
}
else
{
_previousTask = Task.CompletedTask;
}
}

// Invoke the function in a queued work item when the previous task completes. Note that some callers expect the
// returned task to have the key as the task's AsyncState.
task = prevTask.ContinueWith(delegate
public Task<TResult> Run<TResult>(Func<object, long, TResult> func, CancellationToken cancellationToken)
{
_cancellationToken = cancellationToken;
_func = func;
Task<TResult> task = _previousTask.ContinueWith(static (_, s) =>
{
Debug.Assert(!Monitor.IsEntered(s_tasks));
DnsRequestWaiter self = (DnsRequestWaiter)s!;
Debug.Assert(self._func is not null);
Func<object, long, TResult> func = (Func<object, long, TResult>)self._func;

try
{
return func(key, startingTimestamp);
using (self._cancellationToken.UnsafeRegister(s => ((DnsRequestWaiter)s!).Complete(), self))
{
return func(self._key, self._startTimestamp);
}
}
finally
{
// When the work is done, remove this key/task pair from the dictionary if this is still the current task.
// Because the work item is created and stored into both the local and the dictionary while the lock is
// held, and since we take the same lock here, inside this lock it's guaranteed to see the changes
// made by the call site.
lock (s_tasks)
{
((ICollection<KeyValuePair<object, Task>>)s_tasks).Remove(new KeyValuePair<object, Task>(key!, task!));
}
self.Complete();
}
}, key, cancellationToken, TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default);
}, this, cancellationToken, TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default);

// If it's possible the task may end up getting canceled, it won't have a chance to remove itself from
// the dictionary if it is canceled, so use a separate continuation to do so.
// If the token can be canceled, the task may be canceled before its body ever runs, in which case it never gets
// a chance to call Complete() and AfterResolution(); use a separate continuation to do so.
if (cancellationToken.CanBeCanceled)
{
task.ContinueWith((task, key) =>
_previousTask.ContinueWith(static (_, s) =>
{
lock (s_tasks)
DnsRequestWaiter self = (DnsRequestWaiter)s!;
self.Complete();
NameResolutionTelemetry.Log.AfterResolution(self._key, self._startTimestamp, false);
}, this, CancellationToken.None, TaskContinuationOptions.OnlyOnCanceled | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
}

return task;
}

internal void Complete()
{
if (TrySetResult())
{
lock (s_requestQueues)
{
if (Next != null)
{
// Forward the head for this key to the next request.
s_requestQueues[_key] = Next;
}
else
{
((ICollection<KeyValuePair<object, Task>>)s_tasks).Remove(new KeyValuePair<object, Task>(key!, task));
// No more requests in flight, remove the key from s_requestQueues.
s_requestQueues.Remove(_key);
}
}, key, CancellationToken.None, TaskContinuationOptions.OnlyOnCanceled | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
}
}

// Finally, store the task into the dictionary as the current task for this key.
s_tasks[key] = task;
}

return task;
/// <summary>Gets the time elapsed between this request's start timestamp and <paramref name="currentTimestamp"/>.</summary>
/// <param name="currentTimestamp">The timestamp to measure up to, in the same units as the stored start timestamp.</param>
/// <returns>The elapsed interval as computed by <see cref="Stopwatch.GetElapsedTime(long, long)"/>.</returns>
public TimeSpan Elapsed(long currentTimestamp) => Stopwatch.GetElapsedTime(_startTimestamp, currentTimestamp);
}

private static SocketException CreateException(SocketError error, int nativeError) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public long BeforeResolution(object hostNameOrAddress)
public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, bool successful)
{
Debug.Assert(startingTimestamp.HasValue);
if (startingTimestamp == 0)
if (startingTimestamp == 0 || !IsEnabled() && !NameResolutionMetrics.IsEnabled())
{
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ public void DnsObsoleteBeginResolve_BadName_Throws()
Assert.ThrowsAny<SocketException>(() => Dns.EndResolve(asyncObject));
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public void DnsObsoleteBeginResolve_BadIPv4String_ReturnsOnlyGivenIP()
{
IAsyncResult asyncObject = Dns.BeginResolve("0.0.1.1", null, null);
IPHostEntry entry = Dns.EndResolve(asyncObject);

Assert.Equal("0.0.1.1", entry.HostName);
Assert.Equal(1, entry.AddressList.Length);
Assert.Equal(IPAddress.Parse("0.0.1.1"), entry.AddressList[0]);
}
//[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
//public void DnsObsoleteBeginResolve_BadIPv4String_ReturnsOnlyGivenIP()
//{
// IAsyncResult asyncObject = Dns.BeginResolve("0.0.1.1", null, null);
// IPHostEntry entry = Dns.EndResolve(asyncObject);

// Assert.Equal("0.0.1.1", entry.HostName);
// Assert.Equal(1, entry.AddressList.Length);
// Assert.Equal(IPAddress.Parse("0.0.1.1"), entry.AddressList[0]);
//}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public void DnsObsoleteBeginResolve_Loopback_MatchesResolve()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public ValueTask ConnectAsync(EndPoint remoteEP, CancellationToken cancellationT

saea.RemoteEndPoint = remoteEP;

ValueTask connectTask = saea.ConnectAsync(this);
ValueTask connectTask = saea.ConnectAsync(this, saeaCancelable: cancellationToken.CanBeCanceled);
if (connectTask.IsCompleted || !cancellationToken.CanBeCanceled)
{
// Avoid async invocation overhead
Expand Down Expand Up @@ -1202,11 +1202,11 @@ public ValueTask<int> SendToAsync(Socket socket, CancellationToken cancellationT
ValueTask.FromException<int>(CreateException(error));
}

public ValueTask ConnectAsync(Socket socket)
public ValueTask ConnectAsync(Socket socket, bool saeaCancelable)
{
try
{
if (socket.ConnectAsync(this, userSocket: true, saeaCancelable: false))
if (socket.ConnectAsync(this, userSocket: true, saeaCancelable: saeaCancelable))
{
return new ValueTask(this, _mrvtsc.Version);
}
Expand Down
Loading