diff --git a/CHANGES/4013.bugfix b/CHANGES/4013.bugfix new file mode 100644 index 00000000000..1793f2137b5 --- /dev/null +++ b/CHANGES/4013.bugfix @@ -0,0 +1 @@ +Fixed race conditions in _resolve_host caching and throttling when tracing is enabled. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index fa1b23df982..0f474de6b04 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -134,6 +134,7 @@ Jian Zeng Jinkyu Yi Joel Watts Jon Nabozny +Jonas Obrist Joongi Kim Josep Cugat Joshu Coats diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 706bc99adb7..0099f467551 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -778,24 +778,28 @@ async def _resolve_host(self, if (key in self._cached_hosts) and \ (not self._cached_hosts.expired(key)): + # get result early, before any await (#4014) + result = self._cached_hosts.next_addrs(key) if traces: for trace in traces: await trace.send_dns_cache_hit(host) - - return self._cached_hosts.next_addrs(key) + return result if key in self._throttle_dns_events: + # get event early, before any await (#4014) + event = self._throttle_dns_events[key] if traces: for trace in traces: await trace.send_dns_cache_hit(host) - await self._throttle_dns_events[key].wait() + await event.wait() else: + # update dict early, before any await (#4014) + self._throttle_dns_events[key] = \ + EventResultOrError(self._loop) if traces: for trace in traces: await trace.send_dns_cache_miss(host) - self._throttle_dns_events[key] = \ - EventResultOrError(self._loop) try: if traces: diff --git a/tests/test_connector.py b/tests/test_connector.py index e40559d0c6f..03f5d8bd830 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -18,8 +18,9 @@ from aiohttp import client, web from aiohttp.client import ClientRequest, ClientTimeout from aiohttp.client_reqrep import ConnectionKey -from aiohttp.connector import Connection, _DNSCacheTable +from aiohttp.connector import Connection, TCPConnector, _DNSCacheTable from aiohttp.helpers import PY_37 +from aiohttp.locks import EventResultOrError from aiohttp.test_utils import make_mocked_coro, unused_port from aiohttp.tracing import Trace from conftest import needs_unix @@ -2257,3 +2258,32 @@ def test_next_addrs_single(self, dns_cache_table) -> None: addrs = dns_cache_table.next_addrs('foo') assert addrs == ['127.0.0.1'] + + +async def test_connector_cache_trace_race(): + class DummyTracer: + async def send_dns_cache_hit(self, *args, **kwargs): + connector._cached_hosts.remove(("", 0)) + + token = object() + connector = TCPConnector() + connector._cached_hosts.add(("", 0), [token]) + + traces = [DummyTracer()] + assert await connector._resolve_host("", 0, traces) == [token] + + +async def test_connector_throttle_trace_race(loop): + key = ("", 0) + token = object() + + class DummyTracer: + async def send_dns_cache_hit(self, *args, **kwargs): + event = connector._throttle_dns_events.pop(key) + event.set() + connector._cached_hosts.add(key, [token]) + + connector = TCPConnector() + connector._throttle_dns_events[key] = EventResultOrError(loop) + traces = [DummyTracer()] + assert await connector._resolve_host("", 0, traces) == [token]