Add additional type hints to the storage module. (#8980)
clokep committed Dec 30, 2020
1 parent b859189 commit 637282b
Showing 12 changed files with 224 additions and 148 deletions.
1 change: 1 addition & 0 deletions changelog.d/8980.misc
@@ -0,0 +1 @@
Add type hints to the base storage code.
10 changes: 10 additions & 0 deletions mypy.ini
@@ -70,6 +70,9 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/__init__.py,
synapse/storage/_base.py,
synapse/storage/background_updates.py,
synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/pusher.py,
@@ -78,8 +81,15 @@
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,
synapse/storage/keys.py,
synapse/storage/persist_events.py,
synapse/storage/prepare_database.py,
synapse/storage/purge_events.py,
synapse/storage/push_rule.py,
synapse/storage/relations.py,
synapse/storage/roommember.py,
synapse/storage/state.py,
synapse/storage/types.py,
synapse/storage/util,
synapse/streams,
synapse/types.py,
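Adding these storage modules to the files list means mypy now checks them on every run. A hypothetical illustration, not part of this commit, of the kind of misuse the new annotations catch (the Storage constructor typed later in this diff expects a HomeServer):

from synapse.storage import Storage

# Passing a plain string where the annotated constructor expects a
# "HomeServer" is reported by the mypy run configured above.
storage = Storage("not a homeserver", stores=None)  # error: incompatible argument types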
4 changes: 1 addition & 3 deletions synapse/handlers/initial_sync.py
@@ -323,9 +323,7 @@ async def _room_initial_sync_parted(
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
room_state = await self.state_store.get_state_for_events([member_event_id])

room_state = room_state[member_event_id]
room_state = await self.state_store.get_state_for_event(member_event_id)

limit = pagin_config.limit if pagin_config else None
if limit is None:
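The refactor above collapses a two-step lookup (fetch state for a list of events, then index by the single event ID) into one call. A minimal sketch of the equivalence, assuming a state store exposing both helpers as used in the diff:

# Illustrative only; "state_store" stands in for the handler's self.state_store.
async def state_for_member_event_old(state_store, member_event_id: str):
    # Old pattern: a map of event_id -> state, indexed by its only key.
    room_state_by_event = await state_store.get_state_for_events([member_event_id])
    return room_state_by_event[member_event_id]

async def state_for_member_event_new(state_store, member_event_id: str):
    # New pattern: ask for the single event's state directly.
    return await state_store.get_state_for_event(member_event_id)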
2 changes: 1 addition & 1 deletion synapse/handlers/sync.py
@@ -554,7 +554,7 @@ async def get_state_after_event(
event.event_id, state_filter=state_filter
)
if event.is_state():
state_ids = state_ids.copy()
state_ids = dict(state_ids)
state_ids[(event.type, event.state_key)] = event.event_id
return state_ids
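The switch from .copy() to dict(...) fits the new storage annotations: a read-only Mapping (such as a frozendict) is not guaranteed to provide a .copy() that returns a mutable dict, whereas dict(...) always yields one. A small sketch, with an illustrative StateMap alias standing in for the real type:

from typing import Dict, Mapping, Tuple

# Illustrative alias: (event type, state key) -> event ID.
StateMap = Mapping[Tuple[str, str], str]

def with_state_event(
    state_ids: StateMap, etype: str, state_key: str, event_id: str
) -> Dict[Tuple[str, str], str]:
    # dict() copies any Mapping, including immutable implementations,
    # while .copy() is only part of the concrete dict interface.
    new_state = dict(state_ids)
    new_state[(etype, state_key)] = event_id
    return new_state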

9 changes: 7 additions & 2 deletions synapse/storage/__init__.py
@@ -27,21 +27,26 @@
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
from typing import TYPE_CHECKING

from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage

__all__ = ["DataStores", "DataStore"]
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer


__all__ = ["Databases", "DataStore"]


class Storage:
"""The high level interfaces for talking to various storage layers.
"""

def __init__(self, hs, stores: Databases):
def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
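The quoted "HomeServer" annotation works together with the TYPE_CHECKING guard added above: the import only runs while mypy analyses the file, so no circular import between synapse.storage and synapse.app.homeserver is introduced at runtime. A minimal sketch of the pattern with a simplified constructor:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported for type checking; never executed at runtime.
    from synapse.app.homeserver import HomeServer

class StorageSketch:
    def __init__(self, hs: "HomeServer") -> None:
        # The string annotation is resolved by mypy but never evaluated
        # at runtime, so the import cycle never materialises.
        self.hs = hs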
36 changes: 25 additions & 11 deletions synapse/storage/_base.py
@@ -17,14 +17,18 @@
import logging
import random
from abc import ABCMeta
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Iterable, Optional, Union

from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
from synapse.storage.types import Connection
from synapse.types import Collection, StreamToken, get_domain_from_id
from synapse.util import json_decoder

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
self.rand = random.SystemRandom()

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: StreamToken,
rows: Iterable[Any],
) -> None:
pass

def _invalidate_state_caches(self, room_id, members_changed):
def _invalidate_state_caches(
self, room_id: str, members_changed: Iterable[str]
) -> None:
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
Args:
room_id (str): Room where state changed
members_changed (iterable[str]): The user_ids of members that have
changed
room_id: Room where state changed
members_changed: The user_ids of members that have changed
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
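The loop above deduplicates the changed members by their home server before invalidating the per-host cache entries; get_domain_from_id extracts the server name from a Matrix user ID. Illustrative only:

from synapse.types import get_domain_from_id

members_changed = ["@alice:example.org", "@bob:example.org", "@carol:other.net"]
hosts = {get_domain_from_id(u) for u in members_changed}
# hosts == {"example.org", "other.net"}: one invalidation per host, not per member.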
@@ -64,7 +75,7 @@ def _invalidate_state_caches(self, room_id, members_changed):

def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
):
) -> None:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -88,12 +99,15 @@ def _attempt_to_invalidate_cache(
cache.invalidate(tuple(key))


def db_to_json(db_content):
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
Take some data from a database row and return a JSON-decoded object.
Args:
db_content (memoryview|buffer|bytes|bytearray|unicode)
db_content: The JSON-encoded contents from the database.
Returns:
The object decoded from JSON.
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
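The new Union annotation spells out what the comment above describes: psycopg2 on Python 3 can return a memoryview that must be converted to bytes before decoding. A rough sketch of the behaviour, using the standard json module in place of synapse.util.json_decoder:

import json
from typing import Any, Union

def db_to_json_sketch(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
    # psycopg2 on Python 3 returns memoryview objects; cast to bytes first.
    if isinstance(db_content, memoryview):
        db_content = bytes(db_content)
    # bytes and bytearray are decoded to str before JSON parsing.
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf-8")
    return json.loads(db_content)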