Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Use StateFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson committed Dec 12, 2022
1 parent a71808b commit 2d490c6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 108 deletions.
53 changes: 6 additions & 47 deletions synapse/config/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,71 +13,30 @@
# limitations under the License.

import logging
from typing import Any, Container, Dict, Iterable, Mapping, Optional, Set, Tuple, Type

import attr
from typing import Any, Iterable, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
from synapse.config._util import validate_config
from synapse.types import JsonDict
from synapse.types.state import StateFilter

logger = logging.getLogger(__name__)


@attr.s(auto_attribs=True)
class StateKeyFilter(Container[str]):
"""A simpler version of StateFilter which ignores event types.
Represents an optional constraint that state_keys must belong to a given set of
strings called `options`. An empty set of `options` means that there are no
restrictions.
"""

options: Set[str]

@classmethod
def any(cls: Type["StateKeyFilter"]) -> "StateKeyFilter":
return cls(set())

@classmethod
def only(cls: Type["StateKeyFilter"], state_key: str) -> "StateKeyFilter":
return cls({state_key})

def __contains__(self, state_key: object) -> bool:
return not self.options or state_key in self.options

def add(self, state_key: Optional[str]) -> None:
if state_key is None:
self.options = set()
elif self.options:
self.options.add(state_key)


class ApiConfig(Config):
section = "api"

room_prejoin_state: Mapping[str, StateKeyFilter]
room_prejoin_state: StateFilter
track_puppetted_users_ips: bool

def read_config(self, config: JsonDict, **kwargs: Any) -> None:
validate_config(_MAIN_SCHEMA, config, ())
self.room_prejoin_state = self._build_prejoin_state(config)
self.room_prejoin_state = StateFilter.from_types(
self._get_prejoin_state_entries(config)
)
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)

def _build_prejoin_state(self, config: JsonDict) -> Dict[str, StateKeyFilter]:
prejoin_events = {}
for event_type, state_key in self._get_prejoin_state_entries(config):
if event_type not in prejoin_events:
if state_key is None:
filter = StateKeyFilter.any()
else:
filter = StateKeyFilter.only(state_key)
prejoin_events[event_type] = filter
else:
prejoin_events[event_type].add(state_key)
return prejoin_events

def _get_prejoin_state_entries(
self, config: JsonDict
) -> Iterable[Tuple[str, Optional[str]]]:
Expand Down
38 changes: 14 additions & 24 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import threading
import weakref
from enum import Enum, auto
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Set,
Expand All @@ -46,7 +46,6 @@
RoomVersion,
RoomVersions,
)
from synapse.config.api import StateKeyFilter
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
Expand Down Expand Up @@ -77,6 +76,7 @@
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
Expand Down Expand Up @@ -880,7 +880,7 @@ def _get_events_from_local_cache(
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
state_keys_to_include: Mapping[str, StateKeyFilter],
state_keys_to_include: StateFilter,
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
Expand All @@ -902,31 +902,21 @@ async def get_stripped_room_state_from_event_context(
Returns:
A list of dictionaries, each representing a stripped state event from the room.
"""
current_state_ids = await context.get_current_state_ids()
if membership_user_id:
types = chain(
state_keys_to_include.to_types(),
[(EventTypes.Member, membership_user_id)],
)
filter = StateFilter.from_types(types)
else:
filter = state_keys_to_include
selected_state_ids = await context.get_current_state_ids(filter)

# We know this event is not an outlier, so this must be
# non-None.
assert current_state_ids is not None

def should_include(t: str, s: str) -> bool:
if t in state_keys_to_include and s in state_keys_to_include[t]:
return True
if (
membership_user_id
and t == EventTypes.Member
and s == membership_user_id
):
return True
return False

# The state to include
state_to_include_ids = [
e_id
for (event_type, state_key), e_id in current_state_ids.items()
if should_include(event_type, state_key)
]
assert selected_state_ids is not None

state_to_include = await self.get_events(state_to_include_ids)
state_to_include = await self.get_events(selected_state_ids.values())

return [
{
Expand Down
67 changes: 30 additions & 37 deletions tests/config/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,32 @@
import yaml

from synapse.config import ConfigError
from synapse.config.api import ApiConfig, StateKeyFilter

DEFAULT_PREJOIN_STATE = {
"m.room.join_rules": StateKeyFilter.only(""),
"m.room.canonical_alias": StateKeyFilter.only(""),
"m.room.avatar": StateKeyFilter.only(""),
"m.room.encryption": StateKeyFilter.only(""),
"m.room.name": StateKeyFilter.only(""),
"m.room.create": StateKeyFilter.only(""),
"m.room.topic": StateKeyFilter.only(""),
from synapse.config.api import ApiConfig
from synapse.types.state import StateFilter

DEFAULT_PREJOIN_STATE_PAIRS = {
("m.room.join_rules", ""),
("m.room.canonical_alias", ""),
("m.room.avatar", ""),
("m.room.encryption", ""),
("m.room.name", ""),
("m.room.create", ""),
("m.room.topic", ""),
}


class TestRoomPrejoinState(StdlibTestCase):
def test_state_key_filter(self) -> None:
"""Sanity check the StateKeyFilter class."""
s = StateKeyFilter.only("foo")
self.assertIn("foo", s)
self.assertNotIn("bar", s)
self.assertNotIn("baz", s)
s.add("bar")
self.assertIn("foo", s)
self.assertIn("bar", s)
self.assertNotIn("baz", s)

s = StateKeyFilter.any()
self.assertIn("foo", s)
self.assertIn("bar", s)
self.assertIn("baz", s)
s.add("bar")
self.assertIn("foo", s)
self.assertIn("bar", s)
self.assertIn("baz", s)

def read_config(self, source: str) -> ApiConfig:
config = ApiConfig()
config.read_config(yaml.safe_load(source))
return config

def test_no_prejoin_state(self) -> None:
config = self.read_config("foo: bar")
self.assertEqual(config.room_prejoin_state, DEFAULT_PREJOIN_STATE)
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
)

def test_disable_default_event_types(self) -> None:
config = self.read_config(
Expand All @@ -53,7 +37,7 @@ def test_disable_default_event_types(self) -> None:
disable_default_event_types: true
"""
)
self.assertEqual(config.room_prejoin_state, {})
self.assertEqual(config.room_prejoin_state, StateFilter.none())

def test_event_without_state_key(self) -> None:
config = self.read_config(
Expand All @@ -64,7 +48,8 @@ def test_event_without_state_key(self) -> None:
- foo
"""
)
self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()})
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])

def test_event_with_specific_state_key(self) -> None:
config = self.read_config(
Expand All @@ -75,7 +60,11 @@ def test_event_with_specific_state_key(self) -> None:
- [foo, bar]
"""
)
self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.only("bar")})
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()),
{("foo", "bar")},
)

def test_repeated_event_with_specific_state_key(self) -> None:
config = self.read_config(
Expand All @@ -87,8 +76,10 @@ def test_repeated_event_with_specific_state_key(self) -> None:
- [foo, baz]
"""
)
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
config.room_prejoin_state, {"foo": StateKeyFilter({"bar", "baz"})}
set(config.room_prejoin_state.concrete_types()),
{("foo", "bar"), ("foo", "baz")},
)

def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
Expand All @@ -101,7 +92,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
- foo
"""
)
self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()})
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])

config = self.read_config(
"""
Expand All @@ -112,7 +104,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
- [foo, bar]
"""
)
self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()})
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])

def test_bad_event_type_entry_raises(self) -> None:
with self.assertRaises(ConfigError):
Expand Down

0 comments on commit 2d490c6

Please sign in to comment.