Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix client reader sharding tests (#7853)
Browse files Browse the repository at this point in the history
* Fix client reader sharding tests

* Newsfile

* Fix typing

* Update changelog.d/7853.misc

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>

* Move mocking of http_client to tests

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
  • Loading branch information
erikjohnston and clokep committed Jul 15, 2020
1 parent b11450d commit f13061d
Show file tree
Hide file tree
Showing 7 changed files with 300 additions and 174 deletions.
1 change: 1 addition & 0 deletions changelog.d/7853.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for handling registration requests across multiple client reader workers.
24 changes: 23 additions & 1 deletion synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
IReactorPluggableNameResolver,
IResolutionReceiver,
)
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
Expand Down Expand Up @@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
return False


_EPSILON = 0.00000001


def _make_scheduler(reactor):
"""Makes a schedular suitable for a Cooperator using the given reactor.
(This is effectively just a copy from `twisted.internet.task`)
"""

def _scheduler(x):
return reactor.callLater(_EPSILON, x)

return _scheduler


class IPBlacklistingResolver(object):
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
Expand Down Expand Up @@ -212,6 +228,10 @@ def __init__(
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)

# We use this for our body producers to ensure that they use the correct
# reactor.
self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))

self.user_agent = self.user_agent.encode("ascii")

if self._ip_blacklist:
Expand Down Expand Up @@ -292,7 +312,9 @@ def request(self, method, uri, data=None, headers=None):
try:
body_producer = None
if data is not None:
body_producer = QuieterFileBodyProducer(BytesIO(data))
body_producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator,
)

request_deferred = treq.request(
method,
Expand Down
5 changes: 5 additions & 0 deletions synapse/server.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
import synapse.http.matrixfederationclient
import synapse.notifier
import synapse.push.pusherpool
import synapse.replication.tcp.client
Expand Down Expand Up @@ -143,3 +144,7 @@ class HomeServer(object):
pass
def get_replication_streams(self) -> Dict[str, Stream]:
pass
def get_http_client(
self,
) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
pass
168 changes: 165 additions & 3 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Any, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

import attr

Expand All @@ -26,16 +26,17 @@
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.replication.http import streams
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest
from tests.server import FakeTransport
from tests.server import FakeTransport, render

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -180,6 +181,159 @@ def assert_request_is_get_repl_stream_updates(
self.assertEqual(request.method, b"GET")


class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""

servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]

def setUp(self):
super().setUp()

# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()

store = self.hs.get_datastore()
self.database = store.db

self.reactor.lookups["testserv"] = "1.2.3.4"

self._worker_hs_to_resource = {}

# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
self.reactor.add_tcp_client_callback(
"1.2.3.4", 8765, self._handle_http_replication_attempt
)

def create_test_json_resource(self):
"""Overrides `HomeserverTestCase.create_test_json_resource`.
"""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
# subclassses.

resource = ReplicationRestResource(self.hs)

for servlet in self.servlets:
servlet(self.hs, resource)

return resource

def make_worker_hs(
self, worker_app: str, extra_config: dict = {}, **kwargs
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
Args:
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
useful to e.g. pass some mocks for things like `http_client`
Returns:
The new worker HomeServer instance.
"""

config = self._get_worker_hs_config()
config["worker_app"] = worker_app
config.update(extra_config)

worker_hs = self.setup_test_homeserver(
homeserverToUse=GenericWorkerServer,
config=config,
reactor=self.reactor,
**kwargs
)

store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool

repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)

client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)

server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)

# Set up a resource for the worker
resource = ReplicationRestResource(self.hs)

for servlet in self.servlets:
servlet(worker_hs, resource)

self._worker_hs_to_resource[worker_hs] = resource

return worker_hs

def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config

def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
render(request, self._worker_hs_to_resource[worker_hs], self.reactor)

def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()

def _handle_http_replication_attempt(self):
"""Handles a connection attempt to the master replication HTTP
listener.
"""

# We should have at least one outbound connection attempt, where the
# last is one to the HTTP repication IP/port.
clients = self.reactor.tcpClients
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765)

# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)

request_factory = OneShotRequestFactory()

# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self.site

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)

server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
)
channel.makeConnection(server_to_client_transport)

# Note: at this point we've wired everything up, but we need to return
# before the data starts flowing over the connections as this is called
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.


class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""

Expand Down Expand Up @@ -241,6 +395,14 @@ def unregisterProducer(self):
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()

def checkPersistence(self, request, version):
"""Check whether the connection can be re-used
"""
# We hijack this to always say no for ease of wiring stuff up in
# `handle_http_replication_attempt`.
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False


class _PullToPushProducer:
"""A push producer that wraps a pull producer.
Expand Down
59 changes: 11 additions & 48 deletions tests/replication/test_client_reader_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,26 @@
import logging

from synapse.api.constants import LoginType
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.client.v2_alpha import register

from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, render
from tests.server import FakeChannel

logger = logging.getLogger(__name__)


class ClientReaderTestCase(unittest.HomeserverTestCase):
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Base class for tests of the replication streams"""

servlets = [
register.register_servlets,
]
servlets = [register.register_servlets]

def prepare(self, reactor, clock, hs):
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()

store = hs.get_datastore()
self.database = store.db

self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker

self.reactor.lookups["testserv"] = "1.2.3.4"

def make_worker_hs(self, extra_config={}):
config = self._get_worker_hs_config()
config.update(extra_config)

worker_hs = self.setup_test_homeserver(
homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor,
)

store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool

# Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource.
resource = JsonResource(self.hs)

for servlet in self.servlets:
servlet(worker_hs, resource)

# Essentially HomeserverTestCase.render.
def _render(request):
render(request, self.resource, self.reactor)

return worker_hs, _render

def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
Expand All @@ -82,14 +45,14 @@ def _get_worker_hs_config(self) -> dict:
def test_register_single_worker(self):
"""Test that registration works when using a single client reader worker.
"""
_, worker_render = self.make_worker_hs()
worker_hs = self.make_worker_hs("synapse.app.client_reader")

request_1, channel_1 = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
worker_render(request_1)
self.render_on_worker(worker_hs, request_1)
self.assertEqual(request_1.code, 401)

# Grab the session
Expand All @@ -99,7 +62,7 @@ def test_register_single_worker(self):
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel
worker_render(request_2)
self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)

# We're given a registered user.
Expand All @@ -108,15 +71,15 @@ def test_register_single_worker(self):
def test_register_multi_worker(self):
"""Test that registration works when using multiple client reader workers.
"""
_, worker_render_1 = self.make_worker_hs()
_, worker_render_2 = self.make_worker_hs()
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")

request_1, channel_1 = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
worker_render_1(request_1)
self.render_on_worker(worker_hs_1, request_1)
self.assertEqual(request_1.code, 401)

# Grab the session
Expand All @@ -126,7 +89,7 @@ def test_register_multi_worker(self):
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel
worker_render_2(request_2)
self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)

# We're given a registered user.
Expand Down
Loading

0 comments on commit f13061d

Please sign in to comment.