
Handle replication commands synchronously where possible #7876

Merged
merged 13 commits on Jul 27, 2020
40 changes: 20 additions & 20 deletions synapse/replication/tcp/handler.py
@@ -34,6 +34,7 @@
 from twisted.internet.protocol import ReconnectingClientFactory

 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
 from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
@@ -155,7 +156,7 @@ def __init__(self, hs):
         # When POSITION or RDATA commands arrive, we stick them in a queue and process
         # them in order in a separate background process.

-        # the streams which are currently being processed by _unsafe_process_stream
+        # the streams which are currently being processed by _unsafe_process_queue
         self._processing_streams = set()  # type: Set[str]

         # for each stream, a queue of commands that are awaiting processing, and the
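For orientation, the two pieces of bookkeeping used here can be pictured as follows. This is an illustrative sketch, not part of the diff: the deque-based initialisation and the exact type parameters are assumptions based on the comments above.

    from collections import deque
    from typing import Deque, Dict, Set, Tuple, Union

    # streams whose queue is currently being drained by a background process
    _processing_streams: Set[str] = set()

    # for each stream, the commands awaiting processing, each paired with the
    # connection it arrived on
    _command_queues_by_stream: Dict[
        str, Deque[Tuple[Union["RdataCommand", "PositionCommand"], "AbstractConnection"]]
    ] = {}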
@@ -188,7 +189,7 @@ def __init__(self, hs):
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()

-    async def _add_command_to_stream_queue(
+    def _add_command_to_stream_queue(
         self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing
@@ -202,33 +203,32 @@ async def _add_command_to_stream_queue(
             logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
             return

-        # if we're already processing this stream, stick the new command in the
-        # queue, and we're done.
+        queue.append((cmd, conn))
+
+        # if we're already processing this stream, there's nothing more to do:
+        # the new entry on the queue will get picked up in due course
         if stream_name in self._processing_streams:
-            queue.append((cmd, conn))
             return

-        # otherwise, process the new command.
-
-        # arguably we should start off a new background process here, but nothing
-        # will be too upset if we don't return for ages, so let's save the overhead
-        # and use the existing logcontext.
+        # fire off a background process to start processing the queue.
+        run_as_background_process(
+            "process-replication-data", self._unsafe_process_queue, stream_name
+        )
+
+    async def _unsafe_process_queue(self, stream_name: str):
+        """Processes the command queue for the given stream, until it is empty
+
+        Does not check if there is already a thread processing the queue, hence "unsafe"
+        """
         self._processing_streams.add(stream_name)
Review comment (Member):
    Can we add another guard here to test stream_name in self._processing_streams in case there is a race? I don't think there should be, but it requires understanding how run_as_background_process works.

Reply (Member, author):
    It's a pattern we use pretty widely across the codebase. I've added an assertion.
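The assertion mentioned in the reply is not visible in this hunk; presumably it sits at the top of _unsafe_process_queue and looks something like this (placement assumed from the discussion):

        # a second drainer for the same stream would indicate a bug in the
        # queueing logic in _add_command_to_stream_queue
        assert stream_name not in self._processing_streams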

         try:
-            # might as well skip the queue for this one, since it must be empty
-            assert not queue
-            await self._process_command(cmd, conn, stream_name)
-
-            # now process any other commands that have built up while we were
-            # dealing with that one.
+            queue = self._command_queues_by_stream.get(stream_name)
             while queue:
                 cmd, conn = queue.popleft()
                 try:
                     await self._process_command(cmd, conn, stream_name)
                 except Exception:
                     logger.exception("Failed to handle command %s", cmd)
-
         finally:
             self._processing_streams.discard(stream_name)
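Taken together, the hunk above implements a classic "queue, then drain in the background" pattern: every command is appended to its stream's queue, and at most one drainer runs per stream. The following self-contained asyncio sketch shows the same shape outside of Synapse; the class name StreamCommandQueuer is hypothetical, and asyncio.create_task stands in for run_as_background_process (with one caveat noted in the comments).

    import asyncio
    import logging
    from collections import deque
    from typing import Any, Deque, Dict, Set

    logger = logging.getLogger(__name__)


    class StreamCommandQueuer:
        """Per-stream command queues, each drained by at most one task."""

        def __init__(self) -> None:
            self._processing_streams: Set[str] = set()
            self._queues: Dict[str, Deque[Any]] = {}

        def add_command(self, stream_name: str, cmd: Any) -> None:
            # always enqueue first, so the command cannot be dropped
            self._queues.setdefault(stream_name, deque()).append(cmd)

            # if a drainer is already running for this stream, the new
            # entry will get picked up in due course
            if stream_name in self._processing_streams:
                return

            # Mark the stream busy *before* scheduling: unlike Twisted's
            # run_as_background_process, which starts the coroutine eagerly,
            # asyncio.create_task defers execution, so marking it afterwards
            # would leave a window where a second drainer could be scheduled.
            self._processing_streams.add(stream_name)
            asyncio.create_task(self._drain_queue(stream_name))

        async def _drain_queue(self, stream_name: str) -> None:
            try:
                queue = self._queues.get(stream_name)
                while queue:
                    cmd = queue.popleft()
                    try:
                        await self._process_command(stream_name, cmd)
                    except Exception:
                        logger.exception("Failed to handle command %s", cmd)
            finally:
                self._processing_streams.discard(stream_name)

        async def _process_command(self, stream_name: str, cmd: Any) -> None:
            await asyncio.sleep(0)  # stand-in for the real per-command work
            print(f"{stream_name}: processed {cmd!r}")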

@@ -387,7 +387,7 @@ async def _handle_user_ip(self, cmd: UserIpCommand):
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)

-    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -401,7 +401,7 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         # 2. so we don't race with getting a POSITION command and fetching
         #    missing RDATA.

-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)

     async def _process_rdata(
         self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
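on_RDATA can become a plain def because it now only enqueues the command and returns; nothing in it needs to be awaited. A caller can support both synchronous and coroutine handlers with a dispatch shim like the one below. This is a generic sketch, not the actual protocol-layer code from this PR:

    import inspect
    from typing import Any, Callable


    async def call_handler(handler: Callable[..., Any], *args: Any) -> None:
        # invoke the registered on_<COMMAND> handler; synchronous handlers
        # (like the new on_RDATA/on_POSITION) run inline and return None,
        # while coroutine handlers are awaited as before
        res = handler(*args)
        if inspect.isawaitable(res):
            await res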
@@ -478,14 +478,14 @@ async def on_rdata(
                 stream_name, instance_name, token, rows
             )

-    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return

         logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())

-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)

     async def _process_position(
         self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
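Assuming the StreamCommandQueuer sketch shown earlier, the resulting ordering guarantee can be demonstrated like this: commands on one stream are processed strictly in arrival order, while distinct streams drain independently.

    async def main() -> None:
        queuer = StreamCommandQueuer()
        for i in range(3):
            queuer.add_command("events", f"RDATA events {i}")
            queuer.add_command("typing", f"RDATA typing {i}")
        # give the drainer tasks a chance to run to completion
        await asyncio.sleep(0.1)


    asyncio.run(main())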