Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implement MSC3706: partial state in /send_join response #11967

Merged
merged 3 commits into from
Feb 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11967.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ def read_config(self, config: JsonDict, **kwargs):
self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False
)

# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
91 changes: 81 additions & 10 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -64,7 +65,7 @@
ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
Expand Down Expand Up @@ -571,7 +572,7 @@ async def _on_state_ids_request_compute(
) -> JsonDict:
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}

async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str]
Expand Down Expand Up @@ -645,27 +646,61 @@ async def on_invite_request(
return {"event": ret_pdu.get_pdu_json(time_now)}

async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
self,
origin: str,
content: JsonDict,
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)

prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really sure what the point of the list() was here. It might even be a carry-over from the py2->py3 conversion, where the assumption was that you had to wrap any calls to .keys() or .values() in list() unless you knew otherwise. Anyway, I've removed it while I'm in the area.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems completely reasonable (you can iterate over them multiple times, they have the same truthiness rules as lists, and so on) — I would probably agree with you about the Python 2 suspicion, since I think it behaved quite a bit differently back then.

auth_chain = await self.store.get_auth_chain(room_id, state_ids)
state = await self.store.get_events(state_ids)

state_event_ids: Collection[str]
servers_in_room: Optional[Collection[str]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are basically so that the joining homeserver has some other homeservers to contact to fetch events if we were to just vanish, right?
(I guess also because the joining homeserver might need to send events and for that, it needs to know which servers are in the room — usually this is understood from the membership events but it won't have them all yet under this new scheme.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess also because the joining homeserver might need to send events and for that, it needs to know which servers are in the room

yeah, this is the main driver right now.

if caller_supports_partial_state:
state_event_ids = _get_event_ids_for_partial_state_join(
event, prev_state_ids
)
servers_in_room = await self.state.get_hosts_in_room_at_events(
room_id, event_ids=event.prev_event_ids()
)
else:
state_event_ids = prev_state_ids.values()
servers_in_room = None

auth_chain_event_ids = await self.store.get_auth_chain_ids(
room_id, state_event_ids
)
Comment on lines +674 to +676
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've split the previous call to get_auth_chain into its component get_auth_chain_ids and get_events_as_list, which will give us a chance to filter out the duplicates before we pull them all into a big list.


# if the caller has opted in, we can omit any auth_chain events which are
# already in state_event_ids
if caller_supports_partial_state:
auth_chain_event_ids.difference_update(state_event_ids)

auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
state_events = await self.store.get_events_as_list(state_event_ids)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_events_as_list is equivalent to get_events, except that it doesn't build a dict of the results. Since we only ever used the .values of the result, the dict-building was entirely redundant. (The only thing it could plausibly have done was dedup the results, but since state_event_ids is a set anyway, there couldn't have been any dups in the first place)


# we try to do all the async stuff before this point, so that time_now is as
# accurate as possible.
time_now = self._clock.time_msec()
event_json = event.get_pdu_json()
return {
event_json = event.get_pdu_json(time_now)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may as well fix this to use the same "time_now" timestamp as state and auth_chain while we're passing.

resp = {
# TODO Remove the unstable prefix when servers have updated.
"org.matrix.msc3083.v2.event": event_json,
"event": event_json,
"state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
"state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
}

if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)

return resp

async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str
) -> Dict[str, Any]:
Expand Down Expand Up @@ -1339,3 +1374,39 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict:
# error.
logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))


def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
) -> Collection[str]:
"""Calculate state to be returned in a partial_state send_join

Args:
join_event: the join event being send_joined
prev_state_ids: the event ids of the state before the join

Returns:
the event ids to be returned
"""

# return all non-member events
state_event_ids = {
event_id
for (event_type, state_key), event_id in prev_state_ids.items()
if event_type != EventTypes.Member
}

# we also need the current state of the current user (it's going to
# be an auth event for the new join, so we may as well return it)
current_membership_event_id = prev_state_ids.get(
(EventTypes.Member, join_event.state_key)
)
if current_membership_event_id is not None:
state_event_ids.add(current_membership_event_id)

# TODO: return a few more members:
# - those with invites
# - those that are kicked? / banned
Comment on lines +1408 to +1410
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interested in more rationale/elaboration here, if you have any more thoughts.
I guess I can see invites and bans because those are useful for validating other users' join events later.
However, due to e.g. annoying spammers, kicks and bans could possibly be a large amount of the state events for some rooms.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, MSC2775 lists a whole bunch of members we might want to return. I'm unconvinced by some of them, so this is here just as a thing to return to.


return state_event_ids
20 changes: 19 additions & 1 deletion synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):

PREFIX = FEDERATION_V2_PREFIX

def __init__(
    self,
    hs: "HomeServer",
    authenticator: Authenticator,
    ratelimiter: FederationRateLimiter,
    server_name: str,
):
    """Initialise the servlet and remember whether MSC3706 is switched on.

    Args:
        hs: the homeserver; its experimental config supplies the
            `msc3706_enabled` flag.
        authenticator: passed through unchanged to the base servlet.
        ratelimiter: passed through unchanged to the base servlet.
        server_name: passed through unchanged to the base servlet.
    """
    super().__init__(hs, authenticator, ratelimiter, server_name)

    # Read the flag once at construction time; on_PUT consults it to decide
    # whether to honour the partial_state query parameter.
    experimental_config = hs.config.experimental
    self._msc3706_enabled = experimental_config.msc3706_enabled

async def on_PUT(
self,
origin: str,
Expand All @@ -422,7 +432,15 @@ async def on_PUT(
) -> Tuple[int, JsonDict]:
# TODO(paul): assert that event_id parsed from path actually
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't even know who Paul is! :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's @leonerd :). He did some contracting for us back in the pre-NVL days. Which tells us how long this particular TODO has gone undone :'(

# match those given in content
result = await self.handler.on_send_join_request(origin, content, room_id)

partial_state = False
if self._msc3706_enabled:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state
)
return 200, result


Expand Down
12 changes: 6 additions & 6 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def get_auth_chain_ids(
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
) -> Set[str]:
"""Get auth events for given event_ids. The events *must* be state events.

Args:
Expand All @@ -130,7 +130,7 @@ async def get_auth_chain_ids(
include_given: include the given events in result

Returns:
list of event_ids
set of event_ids
"""

# Check if we have indexed the room so we can use the chain cover
Expand Down Expand Up @@ -159,7 +159,7 @@ async def get_auth_chain_ids(

def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""

# First we look up the chain ID/sequence numbers for the given events.
Expand Down Expand Up @@ -272,11 +272,11 @@ def _get_auth_chain_ids_using_cover_index_txn(
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)

return list(results)
return results

def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
) -> Set[str]:
"""Calculates the auth chain IDs.

This is used when we don't have a cover index for the room.
Expand Down Expand Up @@ -331,7 +331,7 @@ def _get_auth_chain_ids_txn(
front = new_front
results.update(front)

return list(results)
return results

async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
Expand Down
148 changes: 148 additions & 0 deletions tests/federation/test_federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@

from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest
from tests.unittest import override_config


class FederationServerTests(unittest.FederatingHomeserverTestCase):
Expand Down Expand Up @@ -152,6 +161,145 @@ def test_needs_to_be_in_room(self):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")


class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
    """Tests for the federation v2 `/send_join` endpoint.

    Each test runs against a room created by `prepare`, which has two local
    members (kermit, the creator, and fozzie). A user on the other server then
    joins via the `/make_join` + `/send_join` handshake.
    """

    servlets = [
        admin.register_servlets,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Create the room under test and join a second local user to it."""
        super().prepare(reactor, clock, hs)

        # create the room
        creator_user_id = self.register_user("kermit", "test")
        tok = self.login("kermit", "test")
        self._room_id = self.helper.create_room_as(
            room_creator=creator_user_id, tok=tok
        )

        # a second member on the origin HS
        second_member_user_id = self.register_user("fozzie", "bear")
        tok2 = self.login("fozzie", "bear")
        self.helper.join(self._room_id, second_member_user_id, tok=tok2)

    def _make_join(self, user_id: str) -> JsonDict:
        """Call `/make_join` for `user_id` and return the JSON response body.

        Fails the test if the request does not return a 200.
        """
        channel = self.make_signed_federation_request(
            "GET",
            f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
            f"?ver={DEFAULT_ROOM_VERSION}",
        )
        self.assertEqual(channel.code, 200, channel.json_body)
        return channel.json_body

    def _make_signed_join_event(self, user_id: str) -> JsonDict:
        """Fetch a join event via `/make_join` and sign it as the other server.

        Returns:
            the event dict, with hashes and signatures added, ready to be used
            as the body of a `/send_join` request.
        """
        join_event_dict = self._make_join(user_id)["event"]
        # add_hashes_and_signatures updates the dict in place.
        add_hashes_and_signatures(
            KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
            join_event_dict,
            signature_name=self.OTHER_SERVER_NAME,
            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
        )
        return join_event_dict

    def test_send_join(self) -> None:
        """happy-path test of send_join"""
        joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
        join_event_dict = self._make_signed_join_event(joining_user)

        channel = self.make_signed_federation_request(
            "PUT",
            f"/_matrix/federation/v2/send_join/{self._room_id}/x",
            content=join_event_dict,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # we should get complete room state back
        returned_state = [
            (ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
        ]
        self.assertCountEqual(
            returned_state,
            [
                ("m.room.create", ""),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
                ("m.room.history_visibility", ""),
                ("m.room.member", "@kermit:test"),
                ("m.room.member", "@fozzie:test"),
                # nb: *not* the joining user
            ],
        )

        # also check the auth chain
        returned_auth_chain_events = [
            (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
        ]
        self.assertCountEqual(
            returned_auth_chain_events,
            [
                ("m.room.create", ""),
                ("m.room.member", "@kermit:test"),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
            ],
        )

        # the room should show that the new user is a member
        r = self.get_success(
            self.hs.get_state_handler().get_current_state(self._room_id)
        )
        self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

    @override_config({"experimental_features": {"msc3706_enabled": True}})
    def test_send_join_partial_state(self) -> None:
        """When MSC3706 support is enabled, /send_join should return partial state"""
        joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
        join_event_dict = self._make_signed_join_event(joining_user)

        channel = self.make_signed_federation_request(
            "PUT",
            f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
            content=join_event_dict,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # expect a reduced room state
        returned_state = [
            (ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
        ]
        self.assertCountEqual(
            returned_state,
            [
                ("m.room.create", ""),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
                ("m.room.history_visibility", ""),
            ],
        )

        # the auth chain should not include anything already in "state"
        returned_auth_chain_events = [
            (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
        ]
        self.assertCountEqual(
            returned_auth_chain_events,
            [
                ("m.room.member", "@kermit:test"),
            ],
        )

        # the room should show that the new user is a member
        r = self.get_success(
            self.hs.get_state_handler().get_current_state(self._room_id)
        )
        self.assertEqual(r[("m.room.member", joining_user)].membership, "join")


def _create_acl_event(content):
return make_event_from_dict(
{
Expand Down
Loading