Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #8163/006fbe03 backport][3.9] Avoid creating a task to do DNS resolution if there is no throttle #8172

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES/8163.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Improved the DNS resolution performance on cache hit
-- by :user:`bdraco`.

This is achieved by avoiding an :mod:`asyncio` task creation
in this case.
50 changes: 36 additions & 14 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ def clear_dns_cache(
async def _resolve_host(
self, host: str, port: int, traces: Optional[List["Trace"]] = None
) -> List[Dict[str, Any]]:
"""Resolve host and return list of addresses."""
if is_ip_address(host):
return [
{
Expand Down Expand Up @@ -852,8 +853,7 @@ async def _resolve_host(
return res

key = (host, port)

if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)):
if key in self._cached_hosts and not self._cached_hosts.expired(key):
# get result early, before any await (#4014)
result = self._cached_hosts.next_addrs(key)

Expand All @@ -862,6 +862,39 @@ async def _resolve_host(
await trace.send_dns_cache_hit(host)
return result

#
# If multiple connectors are resolving the same host, we wait
# for the first one to resolve and then use the result for all of them.
# We use a throttle event to ensure that we only resolve the host once
# and then use the result for all the waiters.
#
# In this case we need to create a task to ensure that we can shield
# the task from cancellation as cancelling this lookup should not cancel
# the underlying lookup or else the cancel event will get broadcast to
# all the waiters across all connections.
#
resolved_host_task = asyncio.create_task(
self._resolve_host_with_throttle(key, host, port, traces)
)
try:
return await asyncio.shield(resolved_host_task)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

resolved_host_task.add_done_callback(drop_exception)
raise

async def _resolve_host_with_throttle(
self,
key: Tuple[str, int],
host: str,
port: int,
traces: Optional[List["Trace"]],
) -> List[Dict[str, Any]]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
# get event early, before any await (#4014)
event = self._throttle_dns_events[key]
Expand Down Expand Up @@ -1163,22 +1196,11 @@ async def _create_direct_connection(
host = host.rstrip(".") + "."
port = req.port
assert port is not None
host_resolved = asyncio.ensure_future(
self._resolve_host(host, port, traces=traces), loop=self._loop
)
try:
# Cancelling this lookup should not cancel the underlying lookup
# or else the cancel event will get broadcast to all the waiters
# across all connections.
hosts = await asyncio.shield(host_resolved)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

host_resolved.add_done_callback(drop_exception)
raise
hosts = await self._resolve_host(host, port, traces=traces)
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
raise
Expand Down
6 changes: 6 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ async def test_tcp_connector_dns_throttle_requests(loop, dns_response) -> None:
loop.create_task(conn._resolve_host("localhost", 8080))
loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
await asyncio.sleep(0)
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)


Expand All @@ -778,6 +779,9 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> Non
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert r1.exception() == e
assert r2.exception() == e

Expand All @@ -792,6 +796,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
loop.create_task(conn._resolve_host("localhost", 8080))
f = loop.create_task(conn._resolve_host("localhost", 8080))

await asyncio.sleep(0)
await asyncio.sleep(0)
await conn.close()

Expand Down Expand Up @@ -956,6 +961,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
await asyncio.sleep(0)
await asyncio.sleep(0)
on_dns_cache_hit.assert_called_once_with(
session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost")
)
Expand Down
Loading