Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add missing types to opentracing. (#13345)
Browse files Browse the repository at this point in the history
After this change `synapse.logging` is fully typed.
  • Loading branch information
clokep committed Jul 21, 2022
1 parent 190f49d commit 5012275
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 45 deletions.
2 changes: 1 addition & 1 deletion changelog.d/13328.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Add type hints to `trace` decorator.
Add missing type hints to open tracing module.
1 change: 1 addition & 0 deletions changelog.d/13345.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to open tracing module.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ disallow_untyped_defs = False
[mypy-synapse.http.matrixfederationclient]
disallow_untyped_defs = False

[mypy-synapse.logging.opentracing]
disallow_untyped_defs = False

[mypy-synapse.metrics._reactor_metrics]
disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
Expand Down
2 changes: 1 addition & 1 deletion synapse/federation/transport/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ async def new_func(
raise

# update the active opentracing span with the authenticated entity
set_tag("authenticated_entity", origin)
set_tag("authenticated_entity", str(origin))

# if the origin is authenticated and whitelisted, use its span context
# as the parent.
Expand Down
8 changes: 4 additions & 4 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)

set_tag("device", device)
set_tag("ips", ips)
set_tag("device", str(device))
set_tag("ips", str(ips))

return device

Expand Down Expand Up @@ -170,7 +170,7 @@ async def get_user_ids_changed(
"""

set_tag("user_id", user_id)
set_tag("from_token", from_token)
set_tag("from_token", str(from_token))
now_room_key = self.store.get_room_max_token()

room_ids = await self.store.get_rooms_for_user(user_id)
Expand Down Expand Up @@ -795,7 +795,7 @@ async def incoming_device_list_update(
"""

set_tag("origin", origin)
set_tag("edu_content", edu_content)
set_tag("edu_content", str(edu_content))
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
Expand Down
16 changes: 8 additions & 8 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ async def query_devices(
else:
remote_queries[user_id] = device_ids

set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))

# First get local devices.
# A map of destination -> failure response.
Expand Down Expand Up @@ -343,7 +343,7 @@ async def _query_devices_for_destination(
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
set_tag("reason", str(failure))

return

Expand Down Expand Up @@ -405,7 +405,7 @@ async def query_local_devices(
Returns:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
set_tag("local_query", str(query))
local_query: List[Tuple[str, Optional[str]]] = []

result_dict: Dict[str, Dict[str, dict]] = {}
Expand Down Expand Up @@ -477,8 +477,8 @@ async def claim_one_time_keys(
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys

set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))

results = await self.store.claim_e2e_one_time_keys(local_query)

Expand Down Expand Up @@ -508,7 +508,7 @@ async def claim_client_keys(destination: str) -> None:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
set_tag("reason", str(failure))

await make_deferred_yieldable(
defer.gatherResults(
Expand Down Expand Up @@ -611,7 +611,7 @@ async def upload_keys_for_user(

result = await self.store.count_e2e_one_time_keys(user_id, device_id)

set_tag("one_time_key_counts", result)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}

async def _upload_one_time_keys_for_user(
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional, cast

from typing_extensions import Literal

Expand Down Expand Up @@ -97,7 +97,7 @@ async def get_room_keys(
user_id, version, room_id, session_id
)

log_kv(results)
log_kv(cast(JsonDict, results))
return results

@trace
Expand Down
44 changes: 35 additions & 9 deletions synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
Type,
TypeVar,
Union,
cast,
overload,
)

import attr
Expand Down Expand Up @@ -328,6 +330,7 @@ class _Sentinel(enum.Enum):

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
Expand All @@ -343,22 +346,43 @@ def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
return _only_if_tracing_inner


def ensure_active_span(message: str, ret=None):
@overload
def ensure_active_span(
message: str,
) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
...


@overload
def ensure_active_span(
message: str, ret: T
) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
...


def ensure_active_span(
message: str, ret: Optional[T] = None
) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
"""Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level.
Args:
message: Message which fills in "There was no active span when trying to %s"
in the error log if there is no active span and opentracing is enabled.
ret (object): return value if opentracing is None or there is no active span.
ret: return value if opentracing is None or there is no active span.
Returns (object): The result of the func or ret if opentracing is disabled or there
Returns:
The result of the func, falling back to ret if opentracing is disabled or there
was no active span.
"""

def ensure_active_span_inner_1(func):
def ensure_active_span_inner_1(
func: Callable[P, R]
) -> Callable[P, Union[Optional[T], R]]:
@wraps(func)
def ensure_active_span_inner_2(*args, **kwargs):
def ensure_active_span_inner_2(
*args: P.args, **kwargs: P.kwargs
) -> Union[Optional[T], R]:
if not opentracing:
return ret

Expand Down Expand Up @@ -464,7 +488,7 @@ def start_active_span(
finish_on_close: bool = True,
*,
tracer: Optional["opentracing.Tracer"] = None,
):
) -> "opentracing.Scope":
"""Starts an active opentracing span.
Records the start time for the span, and sets it as the "active span" in the
Expand Down Expand Up @@ -502,7 +526,7 @@ def start_active_span_follows_from(
*,
inherit_force_tracing: bool = False,
tracer: Optional["opentracing.Tracer"] = None,
):
) -> "opentracing.Scope":
"""Starts an active opentracing span, with additional references to previous spans
Args:
Expand Down Expand Up @@ -717,7 +741,9 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")


@ensure_active_span("get the active span context as a dict", ret={})
@ensure_active_span(
"get the active span context as a dict", ret=cast(Dict[str, str], {})
)
def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
Expand Down Expand Up @@ -886,7 +912,7 @@ def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i]) # type: ignore[index]
set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
set_tag("kwargs", kwargs)
set_tag("kwargs", str(kwargs))
return func(*args, **kwargs)

return _tag_args_inner
Expand Down
2 changes: 1 addition & 1 deletion synapse/metrics/background_process_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ async def run() -> Optional[R]:
f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
)
else:
ctx = nullcontext()
ctx = nullcontext() # type: ignore[assignment]
with ctx:
return await func(*args, **kwargs)
except Exception:
Expand Down
4 changes: 3 additions & 1 deletion synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
set_tag("to", parse_string(request, "to"))
#
# XXX This does not enforce that "to" is passed.
set_tag("to", str(parse_string(request, "to")))

from_token = await StreamToken.from_string(self.store, from_token_string)

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ async def delete_messages_for_device(
(user_id, device_id), None
)

set_tag("last_deleted_stream_id", last_deleted_stream_id)
set_tag("last_deleted_stream_id", str(last_deleted_stream_id))

if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,8 +706,8 @@ async def get_user_devices_from_cache(
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)

set_tag("in_cache", results)
set_tag("not_in_cache", user_ids_not_in_cache)
set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))

return user_ids_not_in_cache, results

Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def get_e2e_device_keys_for_cs_api(
key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
set_tag("query_list", query_list)
set_tag("query_list", str(query_list))
if not query_list:
return {}

Expand Down Expand Up @@ -418,7 +418,7 @@ async def add_e2e_one_time_keys(
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
Expand Down Expand Up @@ -1161,7 +1161,7 @@ def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
set_tag("device_keys", str(device_keys))

old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
Expand Down
30 changes: 21 additions & 9 deletions tests/logging/test_opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock

Expand Down Expand Up @@ -40,6 +42,15 @@


class LogContextScopeManagerTestCase(TestCase):
"""
Test logging contexts and active opentracing spans.
There's casts throughout this from generic opentracing objects (e.g.
opentracing.Span) to the ones specific to Jaeger since they have additional
properties that these tests depend on. This is safe since the only supported
opentracing backend is Jaeger.
"""

if LogContextScopeManager is None:
skip = "Requires opentracing" # type: ignore[unreachable]
if jaeger_client is None:
Expand Down Expand Up @@ -69,7 +80,7 @@ def test_start_active_span(self) -> None:

# start_active_span should start and activate a span.
scope = start_active_span("span", tracer=self._tracer)
span = scope.span
span = cast(jaeger_client.Span, scope.span)
self.assertEqual(self._tracer.active_span, span)
self.assertIsNotNone(span.start_time)

Expand All @@ -91,6 +102,7 @@ def test_nested_spans(self) -> None:
with LoggingContext("root context"):
with start_active_span("root span", tracer=self._tracer) as root_scope:
self.assertEqual(self._tracer.active_span, root_scope.span)
root_context = cast(jaeger_client.SpanContext, root_scope.span.context)

scope1 = start_active_span(
"child1",
Expand All @@ -99,27 +111,27 @@ def test_nested_spans(self) -> None:
self.assertEqual(
self._tracer.active_span, scope1.span, "child1 was not activated"
)
self.assertEqual(
scope1.span.context.parent_id, root_scope.span.context.span_id
)
context1 = cast(jaeger_client.SpanContext, scope1.span.context)
self.assertEqual(context1.parent_id, root_context.span_id)

scope2 = start_active_span_follows_from(
"child2",
contexts=(scope1,),
tracer=self._tracer,
)
self.assertEqual(self._tracer.active_span, scope2.span)
self.assertEqual(
scope2.span.context.parent_id, scope1.span.context.span_id
)
context2 = cast(jaeger_client.SpanContext, scope2.span.context)
self.assertEqual(context2.parent_id, context1.span_id)

with scope1, scope2:
pass

# the root scope should be restored
self.assertEqual(self._tracer.active_span, root_scope.span)
self.assertIsNotNone(scope2.span.end_time)
self.assertIsNotNone(scope1.span.end_time)
span2 = cast(jaeger_client.Span, scope2.span)
span1 = cast(jaeger_client.Span, scope1.span)
self.assertIsNotNone(span2.end_time)
self.assertIsNotNone(span1.end_time)

self.assertIsNone(self._tracer.active_span)

Expand Down

0 comments on commit 5012275

Please sign in to comment.