Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert the federation agent and related code to async/await. #7874

Merged
merged 2 commits into from
Jul 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7874.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert the federation agent and related code to async/await.
16 changes: 6 additions & 10 deletions synapse/http/federation/matrix_federation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import urllib
from typing import List

from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
Expand Down Expand Up @@ -236,11 +237,10 @@ def connect(self, protocol_factory):

return run_in_background(self._do_connect, protocol_factory)

@defer.inlineCallbacks
def _do_connect(self, protocol_factory):
async def _do_connect(self, protocol_factory):
first_exception = None

server_list = yield self._resolve_server()
server_list = await self._resolve_server()

for server in server_list:
host = server.host
Expand All @@ -251,7 +251,7 @@ def _do_connect(self, protocol_factory):
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)
result = yield make_deferred_yieldable(
result = await make_deferred_yieldable(
endpoint.connect(protocol_factory)
)

Expand All @@ -271,13 +271,9 @@ def _do_connect(self, protocol_factory):
# to try and if that doesn't work then we'll have an exception.
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))

@defer.inlineCallbacks
def _resolve_server(self):
async def _resolve_server(self) -> List[Server]:
"""Resolves the server name to a list of hosts and ports to attempt to
connect to.

Returns:
Deferred[list[Server]]
"""

if self._parsed_uri.scheme != b"matrix":
Expand All @@ -298,7 +294,7 @@ def _resolve_server(self):
if port or _is_ip_literal(host):
return [Server(host, port or 8448)]

server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)

if server_list:
return server_list
Expand Down
10 changes: 4 additions & 6 deletions synapse/http/federation/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import logging
import random
import time
from typing import List

import attr

from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
Expand Down Expand Up @@ -113,16 +113,14 @@ def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._cache = cache
self._get_time = get_time

@defer.inlineCallbacks
def resolve_service(self, service_name):
async def resolve_service(self, service_name: bytes) -> List[Server]:
"""Look up a SRV record

Args:
service_name (bytes): record to look up

Returns:
Deferred[list[Server]]:
a list of the SRV records, or an empty list if none found
a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())

Expand All @@ -136,7 +134,7 @@ def resolve_service(self, service_name):
return _sort_server_list(servers)

try:
answers, _, _ = yield make_deferred_yieldable(
answers, _, _ = await make_deferred_yieldable(
self._dns_client.lookupService(service_name)
)
except DNSNameError:
Expand Down
51 changes: 30 additions & 21 deletions tests/http/federation/test_matrix_federation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def get_connection_factory():
return test_server_connection_factory


# Once Async Mocks or lambdas are supported this can go away.
def generate_resolve_service(result):
async def resolve_service(_):
return result

return resolve_service


class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
Expand Down Expand Up @@ -373,7 +381,7 @@ def test_get_hostname_bad_cert(self):
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4"

test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
Expand Down Expand Up @@ -456,7 +464,7 @@ def test_get_no_srv_no_well_known(self):
Test the behaviour when the server name has no port, no SRV, and no well-known
"""

self.mock_resolver.resolve_service.side_effect = lambda _: []
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"

test_d = self._make_get_request(b"matrix://testserv/foo/bar")
Expand Down Expand Up @@ -510,7 +518,7 @@ def test_get_well_known(self):
"""Test the behaviour when the .well-known delegates elsewhere
"""

self.mock_resolver.resolve_service.side_effect = lambda _: []
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"

Expand Down Expand Up @@ -572,7 +580,7 @@ def test_get_well_known_redirect(self):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"

Expand Down Expand Up @@ -661,7 +669,7 @@ def test_get_invalid_well_known(self):
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""

self.mock_resolver.resolve_service.side_effect = lambda _: []
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"

test_d = self._make_get_request(b"matrix://testserv/foo/bar")
Expand Down Expand Up @@ -717,7 +725,7 @@ def test_get_well_known_unsigned_cert(self):
# the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA)

self.mock_resolver.resolve_service.side_effect = lambda _: []
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"

config = default_config("test", parse=True)
Expand Down Expand Up @@ -764,9 +772,9 @@ def test_get_hostname_srv(self):
"""
Test the behaviour when there is a single SRV record
"""
self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host=b"srvtarget", port=8443)
]
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"srvtarget", port=8443)]
)
self.reactor.lookups["srvtarget"] = "1.2.3.4"

test_d = self._make_get_request(b"matrix://testserv/foo/bar")
Expand Down Expand Up @@ -819,9 +827,9 @@ def test_get_well_known_srv(self):
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)

self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host=b"srvtarget", port=8443)
]
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"srvtarget", port=8443)]
)

self._handle_well_known_connection(
client_factory,
Expand Down Expand Up @@ -861,7 +869,7 @@ def test_get_well_known_srv(self):
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""

self.mock_resolver.resolve_service.side_effect = lambda _: []
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])

# the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
Expand Down Expand Up @@ -922,9 +930,9 @@ def test_idna_servername(self):
def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars"""

self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
]
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
)
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"

test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
Expand Down Expand Up @@ -1087,11 +1095,12 @@ def test_well_known_cache_with_temp_failure(self):
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.
"""

self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host=b"target.com", port=8443),
Server(host=b"target.com", port=8444),
]
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[
Server(host=b"target.com", port=8443),
Server(host=b"target.com", port=8444),
]
)
self.reactor.lookups["target.com"] = "1.2.3.4"

test_d = self._make_get_request(b"matrix://testserv/foo/bar")
Expand Down
26 changes: 10 additions & 16 deletions tests/http/federation/test_srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from twisted.names import dns, error

from synapse.http.federation.srv_resolver import SrvResolver
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.logging.context import LoggingContext, current_context

from tests import unittest
from tests.utils import MockClock
Expand Down Expand Up @@ -50,13 +50,7 @@ def do_lookup():

with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)

self.assertNoResult(resolve_d)

# should have reset to the sentinel context
self.assertIs(current_context(), SENTINEL_CONTEXT)

result = yield resolve_d
result = yield defer.ensureDeferred(resolve_d)

# should have restored our context
self.assertIs(current_context(), ctx)
Expand Down Expand Up @@ -91,7 +85,7 @@ def test_from_cache_expired_and_dns_fail(self):
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)

servers = yield resolver.resolve_service(service_name)
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))

dns_client_mock.lookupService.assert_called_once_with(service_name)

Expand All @@ -117,7 +111,7 @@ def test_from_cache(self):
dns_client=dns_client_mock, cache=cache, get_time=clock.time
)

servers = yield resolver.resolve_service(service_name)
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))

self.assertFalse(dns_client_mock.lookupService.called)

Expand All @@ -136,7 +130,7 @@ def test_empty_cache(self):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)

with self.assertRaises(error.DNSServerError):
yield resolver.resolve_service(service_name)
yield defer.ensureDeferred(resolver.resolve_service(service_name))

@defer.inlineCallbacks
def test_name_error(self):
Expand All @@ -149,7 +143,7 @@ def test_name_error(self):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)

servers = yield resolver.resolve_service(service_name)
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))

self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
Expand All @@ -166,8 +160,8 @@ def test_disabled_service(self):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)

resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
# Old versions of Twisted don't have an ensureDeferred in failureResultOf.
resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))

# returning a single "." should make the lookup fail with a ConenctError
lookup_deferred.callback(
Expand All @@ -192,8 +186,8 @@ def test_non_srv_answer(self):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)

resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
# Old versions of Twisted don't have an ensureDeferred in successResultOf.
resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))

lookup_deferred.callback(
(
Expand Down