Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add additional type hints to the storage. #8980

Merged
merged 14 commits into from
Dec 30, 2020
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ files =
synapse/spam_checker_api,
synapse/state,
synapse/storage/__init__.py,
synapse/storage/_base.py,
synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/pusher.py,
Expand Down
32 changes: 21 additions & 11 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
import logging
import random
from abc import ABCMeta
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Iterable, Optional, Union

from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
from synapse.storage.types import Connection
from synapse.types import Collection, StreamToken, get_domain_from_id
from synapse.util import json_decoder

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


Expand All @@ -36,24 +40,27 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
self.rand = random.SystemRandom()

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: StreamToken, rows
clokep marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
pass

def _invalidate_state_caches(self, room_id, members_changed):
def _invalidate_state_caches(
self, room_id: str, members_changed: Iterable[str]
) -> None:
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.

Args:
room_id (str): Room where state changed
members_changed (iterable[str]): The user_ids of members that have
changed
room_id: Room where state changed
members_changed: The user_ids of members that have changed
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
Expand All @@ -64,7 +71,7 @@ def _invalidate_state_caches(self, room_id, members_changed):

def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
):
) -> None:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
Expand All @@ -88,12 +95,15 @@ def _attempt_to_invalidate_cache(
cache.invalidate(tuple(key))


def db_to_json(db_content):
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
Take some data from a database row and return a JSON-decoded object.

Args:
db_content (memoryview|buffer|bytes|bytearray|unicode)
db_content: The JSON-encoded contents from the database.

Returns:
The object decoded from JSON.
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
Expand Down