This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Don't pull out state in compute_event_context for unconflicted state #13267

Merged
merged 8 commits on Jul 14, 2022
1 change: 1 addition & 0 deletions changelog.d/13267.misc
@@ -0,0 +1 @@
Don't pull out state in `compute_event_context` for unconflicted state.
7 changes: 6 additions & 1 deletion synapse/handlers/message.py
@@ -1444,7 +1444,12 @@ async def cache_joined_hosts_for_event(
if state_entry.state_group in self._external_cache_joined_hosts_updates:
return

joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
state = await state_entry.get_state(
self._storage_controllers.state, StateFilter.all()
)
joined_hosts = await self.store.get_joined_hosts(
event.room_id, state, state_entry
)

# Note that the expiry times must be larger than the expiry time in
# _external_cache_joined_hosts_updates.
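
The call above now fetches the state map explicitly via `state_entry.get_state(..., StateFilter.all())` and hands it to `get_joined_hosts` alongside the entry. Below is a simplified, self-contained stand-in for what the filter argument expresses (it is not Synapse's real `StateFilter`): joined-host calculation needs the full map, hence `StateFilter.all()`, but a caller could in principle ask for a narrower slice.

from typing import Dict, Optional, Set, Tuple

# A state map is keyed by (event type, state_key) and maps to an event ID.
StateMap = Dict[Tuple[str, str], str]

def filter_state(state: StateMap, types: Optional[Set[str]] = None) -> StateMap:
    """Return all of `state` (the analogue of StateFilter.all()), or only the
    entries whose event type is listed in `types`."""
    if types is None:
        return dict(state)
    return {key: event_id for key, event_id in state.items() if key[0] in types}

state = {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@alice:test"): "$alice_join",
}
print(filter_state(state))                      # everything
print(filter_state(state, {"m.room.member"}))   # membership entries only
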
117 changes: 67 additions & 50 deletions synapse/state/__init__.py
@@ -31,7 +31,6 @@
Sequence,
Set,
Tuple,
Union,
)

import attr
@@ -47,13 +46,15 @@
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore

logger = logging.getLogger(__name__)
@@ -83,17 +84,20 @@ def _gen_state_id() -> str:


class _StateCacheEntry:
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
__slots__ = ["state", "state_group", "prev_group", "delta_ids"]

def __init__(
self,
state: StateMap[str],
state: Optional[StateMap[str]],
state_group: Optional[int],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
if state is None and state_group is None:
raise Exception("Either state or state group must be not None")

# A map from (type, state_key) to event_id.
self.state = frozendict(state)
self.state = frozendict(state) if state is not None else None

# the ID of a state group if one and only one is involved.
# otherwise, None.
@@ -102,20 +106,30 @@ def __init__(
self.prev_group = prev_group
self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None

# The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state
# group each time we resolve state, so we generate a separate one that
# isn't persisted and is used solely for caches.
# `state_id` is either a state_group (and so an int) or a string. This
# ensures we don't accidentally persist a state_id as a state_group
if state_group:
self.state_id: Union[str, int] = state_group
else:
self.state_id = _gen_state_id()
async def get_state(
self,
state_storage: "StateStorageController",
state_filter: Optional["StateFilter"] = None,
) -> StateMap[str]:
"""Get the state map for this entry, either from the in-memory state or
looking up the state group in the DB.
"""

if self.state is not None:
return self.state

assert self.state_group is not None

return await state_storage.get_state_ids_for_group(
self.state_group, state_filter
)

def __len__(self) -> int:
return len(self.state)
# The len is used to estimate how large this cache entry is, for
# cache eviction purposes, which is why it's fine to return 1 if
# `self.state` is None.

return len(self.state) if self.state else 1


class StateHandler:
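
A minimal, runnable sketch of the lazy `_StateCacheEntry.get_state` pattern above (the class and storage here are illustrative stand-ins, not Synapse APIs): the entry carries either a materialised state map or just a state group ID, and only looks the state up when a caller actually asks for it.

import asyncio
from typing import Dict, Optional, Tuple

StateMap = Dict[Tuple[str, str], str]

# Stand-in for state-group storage: group ID -> state map.
FAKE_GROUPS: Dict[int, StateMap] = {
    42: {("m.room.member", "@alice:test"): "$alice_join"},
}

class LazyStateCacheEntry:
    def __init__(self, state: Optional[StateMap], state_group: Optional[int]):
        if state is None and state_group is None:
            raise Exception("Either state or state group must be not None")
        self.state = state
        self.state_group = state_group

    async def get_state(self) -> StateMap:
        # Use the in-memory map if we have one, otherwise fetch the group's
        # state (a dict lookup stands in for the database here).
        if self.state is not None:
            return self.state
        assert self.state_group is not None
        return FAKE_GROUPS[self.state_group]

    def __len__(self) -> int:
        # Only a cache-size estimate, so 1 is fine when nothing is materialised.
        return len(self.state) if self.state else 1

async def main() -> None:
    entry = LazyStateCacheEntry(state=None, state_group=42)
    print(len(entry))               # 1 -- no state map held in memory
    print(await entry.get_state())  # fetched on demand

asyncio.run(main())
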
@@ -153,7 +167,7 @@ async def get_current_state_ids(
"""
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state
return await ret.get_state(self._state_storage_controller, StateFilter.all())

async def get_current_users_in_room(
self, room_id: str, latest_event_ids: List[str]
@@ -177,7 +191,8 @@ async def get_current_users_in_room(

logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_users_from_state(room_id, state, entry)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
@@ -192,7 +207,8 @@ async def get_hosts_in_room_at_events(
The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
return await self.store.get_joined_hosts(room_id, entry)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_hosts(room_id, state, entry)

async def compute_event_context(
self,
@@ -227,10 +243,19 @@ async def compute_event_context(
#
if state_ids_before_event:
# if we're given the state before the event, then we use that
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
entry = None

# .. though we need to get a state group for it.
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=state_ids_before_event,
)
)

else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -264,36 +289,27 @@
await_full_state=False,
)

state_ids_before_event = entry.state
state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids

#
# make sure that we have a state group at that point. If it's not a state event,
# that will be the state group for the new event. If it *is* a state event,
# it might get rejected (in which case we'll need to persist it with the
# previous state group)
#

if not state_group_before_event:
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
current_state_ids=state_ids_before_event,
# We make sure that we have a state group assigned to the state.
if entry.state_group is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
current_state_ids=state_ids_before_event,
)
)
)

# Assign the new state group to the cached state entry.
#
# Note that this can race in that we could generate multiple state
# groups for the same state entry, but that is just inefficient
# rather than dangerous.
if entry and entry.state_group is None:
entry.state_group = state_group_before_event
else:
state_group_before_event = entry.state_group
state_ids_before_event = None

#
# now if it's not a state event, we're done
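
The race noted in the comment above ("multiple state groups for the same state entry ... inefficient rather than dangerous") can be made concrete with a small, self-contained sketch (illustrative names, not Synapse code): two writers both see that the entry has no group, both store one, and the loser's group is simply never referenced.

import asyncio
import itertools
from typing import Dict, Optional, Tuple

StateMap = Dict[Tuple[str, str], str]

_next_group = itertools.count(100)
STORED_GROUPS: Dict[int, StateMap] = {}

async def store_state_group(state: StateMap) -> int:
    # Stand-in for persisting a state group; allocates a fresh ID per call.
    await asyncio.sleep(0)  # yield so the two callers interleave
    group_id = next(_next_group)
    STORED_GROUPS[group_id] = state
    return group_id

class Entry:
    def __init__(self) -> None:
        self.state: StateMap = {("m.room.create", ""): "$create"}
        self.state_group: Optional[int] = None

async def ensure_group(entry: Entry) -> int:
    if entry.state_group is None:       # both tasks can pass this check...
        entry.state_group = await store_state_group(entry.state)
    return entry.state_group            # ...so two groups get stored

async def main() -> None:
    entry = Entry()
    await asyncio.gather(ensure_group(entry), ensure_group(entry))
    # Both stored groups hold identical state, so whichever the entry ends up
    # pointing at is correct; the only cost is the redundant group.
    print(sorted(STORED_GROUPS), "entry uses", entry.state_group)

asyncio.run(main())
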
@@ -313,6 +329,10 @@
#
# otherwise, we'll need to create a new state group for after the event
#
if state_ids_before_event is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)

key = (event.type, event.state_key)
if key in state_ids_before_event:
@@ -372,17 +392,14 @@ async def resolve_state_groups_for_events(
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
state = await self._state_storage_controller.get_state_for_groups(
state_group_ids_set
)
(
prev_group,
delta_ids,
) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)
return _StateCacheEntry(
state=state[state_group_id],
state=None,
state_group=state_group_id,
prev_group=prev_group,
delta_ids=delta_ids,
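
The heart of the change is the hunk above: when every forward extremity resolves to the same state group, the entry is built with `state=None`, so the state map is not pulled out of the database at resolution time. A rough, self-contained sketch of that fast path (illustrative names; real Synapse runs the room's state resolution algorithm on the conflicted branch):

from typing import Callable, Dict, Optional, Set, Tuple

StateMap = Dict[Tuple[str, str], str]

class SketchEntry:
    """Holds either a materialised state map or just a state group ID."""
    def __init__(self, state: Optional[StateMap], state_group: Optional[int]):
        self.state = state
        self.state_group = state_group

def resolve_groups(
    state_group_ids: Set[int],
    load_group: Callable[[int], StateMap],
) -> SketchEntry:
    if len(state_group_ids) == 1:
        # Unconflicted: every forward extremity shares one state group, so
        # defer loading the state map until someone actually asks for it.
        (state_group_id,) = state_group_ids
        return SketchEntry(state=None, state_group=state_group_id)
    # Conflicted: the groups must be loaded and resolved here, so the state
    # map is materialised. (A naive "last group wins" merge stands in for the
    # real state resolution algorithm.)
    resolved: StateMap = {}
    for group_id in sorted(state_group_ids):
        resolved.update(load_group(group_id))
    return SketchEntry(state=resolved, state_group=None)
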
4 changes: 3 additions & 1 deletion synapse/storage/controllers/__init__.py
@@ -43,4 +43,6 @@ def __init__(self, hs: "HomeServer", stores: Databases):

self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorageController(hs, stores)
self.persistence = EventsPersistenceStorageController(
hs, stores, self.state
)
12 changes: 10 additions & 2 deletions synapse/storage/controllers/persist_events.py
@@ -48,9 +48,11 @@
from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
@@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
current state and forward extremity changes.
"""

def __init__(self, hs: "HomeServer", stores: Databases):
def __init__(
self,
hs: "HomeServer",
stores: Databases,
state_controller: StateStorageController,
):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.
@@ -325,6 +332,7 @@ def __init__(self, hs: "HomeServer", stores: Databases):
self._process_event_persist_queue_task
)
self._state_resolution_handler = hs.get_state_resolution_handler()
self._state_controller = state_controller

async def _process_event_persist_queue_task(
self,
@@ -504,7 +512,7 @@ async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
state_res_store=StateResolutionStore(self.main_store),
)

return res.state
return await res.get_state(self._state_controller, StateFilter.all())

async def _persist_event_batch(
self, _room_id: str, task: _PersistEventsTask
35 changes: 8 additions & 27 deletions synapse/storage/databases/main/roommember.py
@@ -31,7 +31,6 @@

from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
@@ -780,26 +779,8 @@ async def get_mutual_rooms_between_users(

return shared_room_ids or frozenset()

async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()

current_state_ids = await context.get_current_state_ids()
assert current_state_ids is not None
assert state_group is not None
return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)

async def get_joined_users_from_state(
self, room_id: str, state_entry: "_StateCacheEntry"
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
@@ -812,18 +793,17 @@ async def get_joined_users_from_state(
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
room_id, state_group, state, context=state_entry
)

@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
@cached(num_args=2, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
cache_context: _CacheContext,
event: Optional[EventBase] = None,
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
context: Optional["_StateCacheEntry"] = None,
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
@@ -1017,7 +997,7 @@ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
)

async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
@@ -1030,14 +1010,15 @@ async def get_joined_hosts(
assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
room_id, state_group, state, state_entry=state_entry
)

@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self,
room_id: str,
state_group: Union[object, int],
state: StateMap[str],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
@@ -1093,7 +1074,7 @@ async def _get_joined_hosts(
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
joined_users = await self.get_joined_users_from_state(
room_id, state_entry
room_id, state, state_entry
)

cache.hosts_to_joined_users = {}
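
The `object()` trick in `get_joined_users_from_state` and `get_joined_hosts` above deserves a note: the cached helpers key on `state_group`, and an entry that has no group yet must never collide with another such entry. A fresh `object()` compares equal only to itself, so each group-less lookup gets its own cache key. A small runnable illustration (not Synapse code):

from typing import Dict, Optional, Tuple, Union

CacheKey = Tuple[str, Union[object, int]]
fake_cache: Dict[CacheKey, str] = {}

def lookup(room_id: str, state_group: Optional[int]) -> str:
    # Mirror the pattern above: replace a missing state group with a unique
    # sentinel so it can never hit a previously cached result.
    key_group: Union[object, int] = state_group if state_group else object()
    key = (room_id, key_group)
    if key not in fake_cache:
        fake_cache[key] = f"computed for {room_id} with group {key_group}"
    return fake_cache[key]

lookup("!room:test", 42)
lookup("!room:test", 42)      # hits the cached entry for group 42
lookup("!room:test", None)
lookup("!room:test", None)    # recomputed: each None lookup gets a new key
print("cache entries:", len(fake_cache))  # 3
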