Skip to content

Commit

Permalink
Fix #4013 _resolve_host race conditions (#4014)
Browse files Browse the repository at this point in the history
  • Loading branch information
ojii authored and asvetlov committed Aug 31, 2019
1 parent 8ee3e3d commit 010caab
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGES/4013.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed race conditions in _resolve_host caching and throttling when tracing is enabled.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ Jian Zeng
Jinkyu Yi
Joel Watts
Jon Nabozny
Jonas Obrist
Joongi Kim
Josep Cugat
Joshu Coats
Expand Down
14 changes: 9 additions & 5 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,24 +778,28 @@ async def _resolve_host(self,

if (key in self._cached_hosts) and \
(not self._cached_hosts.expired(key)):
# get result early, before any await (#4014)
result = self._cached_hosts.next_addrs(key)

if traces:
for trace in traces:
await trace.send_dns_cache_hit(host)

return self._cached_hosts.next_addrs(key)
return result

if key in self._throttle_dns_events:
# get event early, before any await (#4014)
event = self._throttle_dns_events[key]
if traces:
for trace in traces:
await trace.send_dns_cache_hit(host)
await self._throttle_dns_events[key].wait()
await event.wait()
else:
# update dict early, before any await (#4014)
self._throttle_dns_events[key] = \
EventResultOrError(self._loop)
if traces:
for trace in traces:
await trace.send_dns_cache_miss(host)
self._throttle_dns_events[key] = \
EventResultOrError(self._loop)
try:

if traces:
Expand Down
32 changes: 31 additions & 1 deletion tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from aiohttp import client, web
from aiohttp.client import ClientRequest, ClientTimeout
from aiohttp.client_reqrep import ConnectionKey
from aiohttp.connector import Connection, _DNSCacheTable
from aiohttp.connector import Connection, TCPConnector, _DNSCacheTable
from aiohttp.helpers import PY_37
from aiohttp.locks import EventResultOrError
from aiohttp.test_utils import make_mocked_coro, unused_port
from aiohttp.tracing import Trace
from conftest import needs_unix
Expand Down Expand Up @@ -2257,3 +2258,32 @@ def test_next_addrs_single(self, dns_cache_table) -> None:

addrs = dns_cache_table.next_addrs('foo')
assert addrs == ['127.0.0.1']


async def test_connector_cache_trace_race():
class DummyTracer:
async def send_dns_cache_hit(self, *args, **kwargs):
connector._cached_hosts.remove(("", 0))

token = object()
connector = TCPConnector()
connector._cached_hosts.add(("", 0), [token])

traces = [DummyTracer()]
assert await connector._resolve_host("", 0, traces) == [token]


async def test_connector_throttle_trace_race(loop):
key = ("", 0)
token = object()

class DummyTracer:
async def send_dns_cache_hit(self, *args, **kwargs):
event = connector._throttle_dns_events.pop(key)
event.set()
connector._cached_hosts.add(key, [token])

connector = TCPConnector()
connector._throttle_dns_events[key] = EventResultOrError(loop)
traces = [DummyTracer()]
assert await connector._resolve_host("", 0, traces) == [token]

0 comments on commit 010caab

Please sign in to comment.