Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Optimize filter_events_for_client for faster /messages - v1 #14494

1 change: 1 addition & 0 deletions changelog.d/14494.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed-up `/messages` with `filter_events_for_client` optimizations.
52 changes: 51 additions & 1 deletion synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import (
Expand Down Expand Up @@ -182,6 +182,56 @@ def _get_state_groups_from_groups(

return self.stores.state._get_state_groups_from_groups(groups, state_filter)

@trace
@tag_args
async def _get_state_for_events_when_filtering_for_client(
    self, event_ids: Collection[str], user_id_viewing_events: str
) -> Dict[str, StateMap[EventBase]]:
    """Get the state at each event that is necessary to filter
    them before being displayed to clients from the perspective of the
    `user_id_viewing_events`. Will fetch `m.room.history_visibility` and
    `m.room.member` event of `user_id_viewing_events`.

    Args:
        event_ids: List of event IDs that will be displayed to the client
        user_id_viewing_events: User ID that will be viewing these events

    Returns:
        Dict of event_id to state map.
    """
    set_tag(
        SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
        str(len(event_ids)),
    )

    # Since we're making decisions based on the state, we need to wait.
    # NOTE(review): is unconditionally forcing `await_full_state=True`
    # accurate here? It is suspected to be the source of the
    # `TestPartialStateJoin` Complement failures -- confirm before relying
    # on it.
    await_full_state = True

    event_to_groups = await self.get_state_group_for_events(
        event_ids, await_full_state=await_full_state
    )

    # Resolve each distinct state group to the minimal state map needed for
    # client-side filtering (history visibility + the viewer's membership).
    groups = set(event_to_groups.values())
    group_to_state = await self.stores.state._get_state_for_client_filtering(
        groups, user_id_viewing_events
    )

    # Fetch the actual state events referenced by the per-group state maps.
    state_event_map = await self.stores.main.get_events(
        [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
        get_prev_content=False,
    )

    # Map each event back through its state group to the fetched state
    # events, silently dropping any event IDs that `get_events` did not
    # return.
    event_to_state = {
        event_id: {
            k: state_event_map[v]
            for k, v in group_to_state[group].items()
            if v in state_event_map
        }
        for event_id, group in event_to_groups.items()
    }

    return {event: event_to_state[event] for event in event_ids}

@trace
async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
Expand Down
194 changes: 173 additions & 21 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
)

import attr

from synapse.api.constants import EventTypes
from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
Expand All @@ -29,9 +40,11 @@
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -158,6 +171,8 @@ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
)

@cancellable
@trace
@tag_args
Comment on lines 173 to +175
Copy link
Contributor Author

@MadLittleMods MadLittleMods Nov 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should @cancellable be the innermost or outermost decorator?

I am guessing outermost but perhaps it doesn't matter

async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
Expand All @@ -171,6 +186,11 @@ async def _get_state_groups_from_groups(
Returns:
Dict of state group to state map.
"""
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "groups.length",
str(len(groups)),
)

results: Dict[int, StateMap[str]] = {}

chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
Expand Down Expand Up @@ -237,45 +257,177 @@ def _get_state_for_group_using_cache(

return state_filter.filter_state(state_dict_ids), not missing_types

@cancellable
async def _get_state_groups_from_cache(
    self, state_groups: Iterable[int], state_filter: StateFilter
) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
    """Given a `state_filter`, pull out the relevant cached state groups that match
    the filter.

    Args:
        state_groups: List of state group IDs to fetch from the cache
        state_filter: The relevant StateFilter to pull against

    Returns:
        A map from each state_group ID to the complete/incomplete state map (filled
        in by cached values) and the set of incomplete state_groups that still need
        to be filled in.
    """
    # Membership events live in a separate cache from the rest of the
    # state, so split the filter accordingly.
    member_filter, non_member_filter = state_filter.get_member_split()

    # Now we look them up in the member and non-member caches
    (
        non_member_state,
        incomplete_groups_nm,
    ) = self._get_state_for_groups_using_cache(
        state_groups, self._state_group_cache, state_filter=non_member_filter
    )

    (member_state, incomplete_groups_m) = self._get_state_for_groups_using_cache(
        state_groups, self._state_group_members_cache, state_filter=member_filter
    )

    # Merge the two partial results into a single state map per group.
    state = dict(non_member_state)
    for state_group in state_groups:
        state[state_group].update(member_state[state_group])

    # We may have only got one or none of the events for the group so mark those as
    # incomplete that need fetching from the database.
    incomplete_groups = incomplete_groups_m | incomplete_groups_nm

    return (state, incomplete_groups)

@cancellable
@trace
@tag_args
async def _get_state_for_client_filtering(
    self, state_group_ids: Iterable[int], user_id_viewing_events: str
) -> Dict[int, MutableStateMap[str]]:
    """Get a state map for each state group ID provided that is necessary to filter
    the corresponding events before being displayed to clients from the perspective
    of the `user_id_viewing_events`.

    Args:
        state_group_ids: The state groups to fetch
        user_id_viewing_events: User ID that will be viewing the events that
            correspond to the state groups

    Returns:
        Dict of state_group ID to state map.
    """

    def _get_state_for_client_filtering_txn(
        txn: LoggingTransaction, groups: Iterable[int]
    ) -> Mapping[int, MutableStateMap[str]]:
        # Walk the state-group edge chain back from a group and return the
        # single most recent row (highest-numbered state group) for one
        # (type, state_key) pair.
        sql = """
            WITH RECURSIVE sgs(state_group) AS (
                VALUES(CAST(? AS bigint))
                UNION ALL
                SELECT prev_state_group FROM state_group_edges e, sgs s
                WHERE s.state_group = e.state_group
            )
            SELECT
                type, state_key, event_id
            FROM state_groups_state
            WHERE
                state_group IN (
                    SELECT state_group FROM sgs
                )
                AND (type = ? AND state_key = ?)
            ORDER BY
                -- Use the latest state in the chain (highest numbered state_group in the chain)
                state_group DESC
            LIMIT 1
        """

        results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
        for group in groups:
            # TODO(review): can these two per-group queries (and ideally the
            # queries across all groups) be batched to cut database
            # round-trips? A UNION-based approach was suggested in review.
            row_info_list: List[Tuple] = []
            txn.execute(sql, (group, EventTypes.RoomHistoryVisibility, ""))
            history_vis_info = txn.fetchone()
            if history_vis_info is not None:
                row_info_list.append(history_vis_info)

            txn.execute(sql, (group, EventTypes.Member, user_id_viewing_events))
            membership_info = txn.fetchone()
            if membership_info is not None:
                row_info_list.append(membership_info)

            for row in row_info_list:
                typ, state_key, event_id = row
                # Intern the strings: these keys repeat heavily across groups.
                key = (intern_string(typ), intern_string(state_key))
                results[group][key] = event_id

        return results

    # Craft a StateFilter to use with the cache
    state_filter_for_cache_lookup = StateFilter.from_types(
        (
            (EventTypes.RoomHistoryVisibility, ""),
            (EventTypes.Member, user_id_viewing_events),
        )
    )
    (
        results_from_cache,
        incomplete_groups,
    ) = await self._get_state_groups_from_cache(
        state_group_ids, state_filter_for_cache_lookup
    )

    # Snapshot the cache sequence numbers *before* hitting the database so
    # concurrent invalidations are not lost when we insert below.
    cache_sequence_nm = self._state_group_cache.sequence
    cache_sequence_m = self._state_group_members_cache.sequence

    # Help the cache hit ratio by expanding the filter a bit. This is
    # loop-invariant, so compute it once up-front rather than per batch.
    state_filter_for_cache_insertion = state_filter_for_cache_lookup.return_expanded()

    results = results_from_cache
    for batch in batch_iter(incomplete_groups, 100):
        group_to_state_mapping = await self.db_pool.runInteraction(
            "_get_state_for_client_filtering_txn",
            _get_state_for_client_filtering_txn,
            batch,
        )

        # Now let's update the caches
        group_to_state_dict: Dict[int, StateMap[str]] = dict(group_to_state_mapping)
        self._insert_into_cache(
            group_to_state_dict,
            state_filter_for_cache_insertion,
            cache_seq_num_members=cache_sequence_m,
            cache_seq_num_non_members=cache_sequence_nm,
        )

        results.update(group_to_state_mapping)

    return results

@cancellable
@trace
@tag_args
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key

Args:
groups: list of state groups for which we want
to get the state.
state_filter: The state filter used to fetch state
from the database.
Returns:
Dict of state group to state map.
"""
state_filter = state_filter or StateFilter.all()
(
results_from_cache,
incomplete_groups,
) = await self._get_state_groups_from_cache(groups, state_filter)

# Now fetch any missing groups from the database
state = results_from_cache
if not incomplete_groups:
return state

Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,16 @@ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
The new state filter.
"""
type_dict: Dict[str, Optional[Set[str]]] = {}
for typ, s in types:
for typ, state_key in types:
if typ in type_dict:
if type_dict[typ] is None:
continue

if s is None:
if state_key is None:
type_dict[typ] = None
continue

type_dict.setdefault(typ, set()).add(s) # type: ignore
type_dict.setdefault(typ, set()).add(state_key) # type: ignore

return StateFilter(
types=frozendict(
Expand Down
Loading