Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Adds misc missing type hints #11953

Merged
merged 5 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11953.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints.
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ disallow_untyped_defs = True
[mypy-synapse.crypto.*]
disallow_untyped_defs = True

[mypy-synapse.event_auth]
disallow_untyped_defs = True

[mypy-synapse.events.*]
disallow_untyped_defs = True

Expand All @@ -166,6 +169,9 @@ disallow_untyped_defs = True
[mypy-synapse.module_api.*]
disallow_untyped_defs = True

[mypy-synapse.notifier]
disallow_untyped_defs = True

[mypy-synapse.push.*]
disallow_untyped_defs = True

Expand Down
4 changes: 3 additions & 1 deletion synapse/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,9 @@ def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -
return default


def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
def _verify_third_party_invite(
event: EventBase, auth_events: StateMap[EventBase]
) -> bool:
"""
Validates that the invite event is authorized by a previous third-party invite.

Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,9 @@ async def _exchange_code(self, code: str) -> Token:
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
raw_headers = {
raw_headers: Dict[str, str] = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"User-Agent": self._http_client.user_agent.decode("ascii"),
"Accept": "application/json",
}

Expand Down
11 changes: 5 additions & 6 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,21 +322,20 @@ def __init__(
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
self._extra_treq_args = treq_args or {}

self.user_agent = hs.version_string
self.clock = hs.get_clock()

user_agent = hs.version_string
if hs.config.server.user_agent_suffix:
self.user_agent = "%s %s" % (
self.user_agent,
user_agent = "%s %s" % (
user_agent,
hs.config.server.user_agent_suffix,
)
self.user_agent = user_agent.encode("ascii")

# We use this for our body producers to ensure that they use the correct
# reactor.
self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))

self.user_agent = self.user_agent.encode("ascii")

if self._ip_blacklist:
# If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding.
Expand Down
3 changes: 1 addition & 2 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,11 @@ def __init__(self, hs: "HomeServer", tls_client_options_factory):
user_agent = hs.version_string
if hs.config.server.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix)
user_agent = user_agent.encode("ascii")

federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
user_agent.encode("ascii"),
hs.config.server.federation_ip_range_whitelist,
hs.config.server.federation_ip_range_blacklist,
)
Expand Down
43 changes: 24 additions & 19 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import logging
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Collection,
Expand All @@ -32,7 +33,6 @@

from twisted.internet import defer

import synapse.server
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase
Expand All @@ -53,6 +53,9 @@
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

notified_events_counter = Counter("synapse_notifier_notified_events", "")
Expand Down Expand Up @@ -82,7 +85,7 @@ class _NotificationListener:

__slots__ = ["deferred"]

def __init__(self, deferred):
def __init__(self, deferred: "defer.Deferred"):
self.deferred = deferred


Expand Down Expand Up @@ -124,7 +127,7 @@ def notify(
stream_key: str,
stream_id: Union[int, RoomStreamToken],
time_now_ms: int,
):
) -> None:
"""Notify any listeners for this user of a new event from an
event source.
Args:
Expand Down Expand Up @@ -152,7 +155,7 @@ def notify(
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)

def remove(self, notifier: "Notifier"):
def remove(self, notifier: "Notifier") -> None:
"""Remove this listener from all the indexes in the Notifier
it knows about.
"""
Expand Down Expand Up @@ -188,7 +191,7 @@ class EventStreamResult:
start_token: StreamToken
end_token: StreamToken

def __bool__(self):
def __bool__(self) -> bool:
return bool(self.events)


Expand All @@ -212,7 +215,7 @@ class Notifier:

UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000

def __init__(self, hs: "synapse.server.HomeServer"):
def __init__(self, hs: "HomeServer"):
self.user_to_user_stream: Dict[str, _NotifierUserStream] = {}
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}

Expand Down Expand Up @@ -248,7 +251,7 @@ def __init__(self, hs: "synapse.server.HomeServer"):
# This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
def count_listeners():
def count_listeners() -> int:
all_user_streams: Set[_NotifierUserStream] = set()

for streams in list(self.room_to_user_streams.values()):
Expand All @@ -270,7 +273,7 @@ def count_listeners():
"synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
)

def add_replication_callback(self, cb: Callable[[], None]):
def add_replication_callback(self, cb: Callable[[], None]) -> None:
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
Expand All @@ -284,7 +287,7 @@ async def on_new_room_event(
event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken,
extra_users: Optional[Collection[UserID]] = None,
):
) -> None:
"""Unwraps event and calls `on_new_room_event_args`."""
await self.on_new_room_event_args(
event_pos=event_pos,
Expand All @@ -307,7 +310,7 @@ async def on_new_room_event_args(
event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken,
extra_users: Optional[Collection[UserID]] = None,
):
) -> None:
"""Used by handlers to inform the notifier something has happened
in the room, room event wise.

Expand Down Expand Up @@ -338,7 +341,9 @@ async def on_new_room_event_args(

self.notify_replication()

def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
def _notify_pending_new_room_events(
self, max_room_stream_token: RoomStreamToken
) -> None:
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
Expand Down Expand Up @@ -374,7 +379,7 @@ def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken
)
self._on_updated_room_token(max_room_stream_token)

def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken) -> None:
"""Poke services that might care that the room position has been
updated.
"""
Expand All @@ -386,13 +391,13 @@ def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
if self.federation_sender:
self.federation_sender.notify_new_events(max_room_stream_token)

def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
def _notify_app_services(self, max_room_stream_token: RoomStreamToken) -> None:
try:
self.appservice_handler.notify_interested_services(max_room_stream_token)
except Exception:
logger.exception("Error notifying application services of event")

def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken) -> None:
try:
self._pusher_pool.on_new_notifications(max_room_stream_token)
except Exception:
Expand Down Expand Up @@ -475,8 +480,8 @@ async def wait_for_events(
user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
room_ids=None,
from_token=StreamToken.START,
room_ids: Optional[Collection[str]] = None,
from_token: StreamToken = StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the
timeout fires.
Expand Down Expand Up @@ -700,14 +705,14 @@ def remove_expired_streams(self) -> None:
for expired_stream in expired_streams:
expired_stream.remove(self)

def _register_with_keys(self, user_stream: _NotifierUserStream):
def _register_with_keys(self, user_stream: _NotifierUserStream) -> None:
self.user_to_user_stream[user_stream.user_id] = user_stream

for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)

def _user_joined_room(self, user_id: str, room_id: str):
def _user_joined_room(self, user_id: str, room_id: str) -> None:
new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set())
Expand All @@ -719,7 +724,7 @@ def notify_replication(self) -> None:
for cb in self.replication_callbacks:
cb()

def notify_remote_server_up(self, server: str):
def notify_remote_server_up(self, server: str) -> None:
"""Notify any replication that a remote server has come back up"""
# We call federation_sender directly rather than registering as a
# callback as a) we already have a reference to it and b) it introduces
Expand Down
8 changes: 4 additions & 4 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ def __init__(
self,
hostname: str,
config: HomeServerConfig,
reactor=None,
version_string="Synapse",
reactor: Optional[ISynapseReactor] = None,
version_string: str = "Synapse",
):
"""
Args:
Expand All @@ -244,7 +244,7 @@ def __init__(
if not reactor:
from twisted.internet import reactor as _reactor

reactor = _reactor
reactor = cast(ISynapseReactor, _reactor)

self._reactor = reactor
self.hostname = hostname
Expand All @@ -264,7 +264,7 @@ def __init__(
self._module_web_resources: Dict[str, Resource] = {}
self._module_web_resources_consumed = False

def register_module_web_resource(self, path: str, resource: Resource):
def register_module_web_resource(self, path: str, resource: Resource) -> None:
"""Allows a module to register a web resource to be served at the given path.

If multiple modules register a resource for the same path, the module that
Expand Down
9 changes: 2 additions & 7 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def default_config(self):
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
self.http_client.user_agent = b"Synapse Test"

hs = self.setup_test_homeserver(proxied_http_client=self.http_client)

Expand Down Expand Up @@ -438,12 +438,9 @@ def test_callback(self):
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
request = _build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address
)
request = _build_callback_request(code, state, session, ip_address=ip_address)

self.get_success(self.handler.handle_oidc_callback(request))

Expand Down Expand Up @@ -1274,7 +1271,6 @@ def _build_callback_request(
code: str,
state: str,
session: str,
user_agent: str = "Browser",
ip_address: str = "10.0.0.1",
):
"""Builds a fake SynapseRequest to mock the browser callback
Expand All @@ -1289,7 +1285,6 @@ def _build_callback_request(
query param. Should be the same as was embedded in the session in
_build_oidc_session.
session: the "session" which would have been passed around in the cookie.
user_agent: the user-agent to present
ip_address: the IP address to pretend the request came from
"""
request = Mock(
Expand Down