diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 084227dad528..85c5cd281f0d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -104,7 +104,7 @@ @attr.s(slots=True) class _NewEventInfo: - """Holds information about a received event, ready for passing to _handle_new_events + """Holds information about a received event, ready for passing to _auth_and_persist_events Attributes: event: the received event @@ -808,8 +808,10 @@ async def _process_received_pdu( logger.debug("Processing event: %s", event) try: - context, auth_events = await self._prep_event(event, state=state) - await self._handle_new_event( + context, auth_events = await self._calculate_event_context( + event, state=state + ) + await self._auth_and_persist_event( origin, event, context, auth_events=auth_events, state=state ) except AuthError as e: @@ -1014,7 +1016,9 @@ async def backfill( ) if ev_infos: - await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) + await self._auth_and_persist_events( + dest, room_id, ev_infos, backfilled=True + ) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1027,12 +1031,12 @@ async def backfill( # non-outliers assert not event.internal_metadata.is_outlier() - context, context_auth_events = await self._prep_event(event) + context, context_auth_events = await self._calculate_event_context(event) # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - await self._handle_new_event( + await self._auth_and_persist_event( dest, event, context, auth_events=context_auth_events, backfilled=True ) @@ -1368,7 +1372,7 @@ async def get_event(event_id: str): event_infos.append(_NewEventInfo(event, None, auth)) - await self._handle_new_events( + await self._auth_and_persist_events( destination, room_id, event_infos, @@ -1675,15 +1679,15 @@ async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict: event.internal_metadata.send_on_behalf_of = origin # Calculate the event context and persist the event. - context, auth_events = await self._prep_event( + context, auth_events = await self._calculate_event_context( event, state=None, auth_events=None ) - context = await self._handle_new_event( + context = await self._auth_and_persist_event( origin, event, context, auth_events=auth_events, backfilled=False ) logger.debug( - "on_send_join_request: After _handle_new_event: %s, sigs: %s", + "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -1892,11 +1896,13 @@ async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None: event.internal_metadata.outlier = False - context, auth_events = await self._prep_event(event) - await self._handle_new_event(origin, event, context, auth_events=auth_events) + context, auth_events = await self._calculate_event_context(event) + await self._auth_and_persist_event( + origin, event, context, auth_events=auth_events + ) logger.debug( - "on_send_leave_request: After _handle_new_event: %s, sigs: %s", + "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -2004,7 +2010,7 @@ async def get_persisted_pdu( async def get_min_depth_for_context(self, context: str) -> int: return await self.store.get_min_depth(context) - async def _handle_new_event( + async def _auth_and_persist_event( self, origin: str, event: EventBase, @@ -2014,12 +2020,15 @@ async def _handle_new_event( backfilled: bool = False, ) -> EventContext: """ - Process an event. + Process an event by performing auth checks and then persisting to the database. Args: origin: The host the event originates from. event: The event itself. - context: The event context. + context: + The event context. + + NB that this function potentially modifies it. state: The state events used to auth the event. auth_events: Map from (event_type, state_key) to event @@ -2033,7 +2042,7 @@ async def _handle_new_event( Returns: The event context. """ - context = await self.do_auth( + context = await self._check_event_auth( origin, event, context, @@ -2063,7 +2072,7 @@ async def _handle_new_event( return context - async def _handle_new_events( + async def _auth_and_persist_events( self, origin: str, room_id: str, @@ -2081,12 +2090,12 @@ async def _handle_new_events( async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res, auth_events = await self._prep_event( + res, auth_events = await self._calculate_event_context( event, state=ev_info.state, auth_events=ev_info.auth_events, ) - res = await self.do_auth( + res = await self._check_event_auth( origin, event, res, @@ -2224,14 +2233,14 @@ async def _persist_auth_tree( room_id, [(event, new_event_context)] ) - async def _prep_event( + async def _calculate_event_context( self, event: EventBase, state: Optional[Iterable[EventBase]] = None, auth_events: Optional[MutableStateMap[EventBase]] = None, ) -> Tuple[EventContext, MutableStateMap[EventBase]]: """ - Prepare an event for sending over federation. + Calculate the context and auth events for a given event. Args: event: The event itself. @@ -2381,7 +2390,7 @@ async def on_get_missing_events( return missing_events - async def do_auth( + async def _check_event_auth( self, origin: str, event: EventBase, @@ -2391,11 +2400,15 @@ async def do_auth( backfilled: bool, ) -> EventContext: """ + Checks whether an event should be rejected (for failing auth checks). Args: origin: The host the event originates from. event: The event itself. - context: The event context. + context: + The event context. + + NB that this function potentially modifies it. state: The state events to calculate the event context from. This is ignored if context is provided. auth_events: @@ -2408,8 +2421,9 @@ async def do_auth( Also NB that this function adds entries to it. backfilled: True if the event was backfilled. + Returns: - updated context object + The updated context object. """ room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -2455,7 +2469,7 @@ async def _update_auth_events_and_context_for_auth( context: EventContext, auth_events: MutableStateMap[EventBase], ) -> EventContext: - """Helper for do_auth. See there for docs. + """Helper for _check_event_auth. See there for docs. Checks whether a given event has the expected auth events. If it doesn't then we talk to the remote server to compare state to see if @@ -2535,10 +2549,14 @@ async def _update_auth_events_and_context_for_auth( e.internal_metadata.outlier = True logger.debug( - "do_auth %s missing_auth: %s", event.event_id, e.event_id + "_check_event_auth %s missing_auth: %s", + event.event_id, + e.event_id, + ) + context, auth = await self._calculate_event_context( + e, auth_events=auth ) - context, auth = await self._prep_event(e, auth_events=auth) - await self._handle_new_event( + await self._auth_and_persist_event( origin, e, context, auth_events=auth ) diff --git a/tests/test_federation.py b/tests/test_federation.py index 7943b1e3fefa..07bc3a21edbb 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -76,7 +76,7 @@ def setUp(self): ) self.handler = self.homeserver.get_federation_handler() - self.handler.do_auth = ( + self.handler._check_event_auth = ( lambda origin, event, context, state, auth_events, backfilled: succeed( context )