Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Recursively fetch the thread for receipts & notifications. (#13824)
Browse files Browse the repository at this point in the history
Consider an event to be part of a thread if you can follow a
chain of relations up to a thread root.

Part of MSC3773 & MSC3771.
  • Loading branch information
clokep committed Oct 4, 2022
1 parent 3e74ad2 commit 2b6d41e
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 2 deletions.
1 change: 1 addition & 0 deletions changelog.d/13824.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
5 changes: 5 additions & 0 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,13 @@ async def action_for_event_by_user(
relation.parent_id,
itertools.chain(*(r.rules() for r in rules_by_user.values())),
)
# Recursively attempt to find the thread this event relates to.
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
else:
# Since the event has not yet been persisted we check whether
# the parent is part of a thread.
thread_id = await self.store.get_thread_id(relation.parent_id) or "main"

evaluator = PushRuleEvaluator(
_flatten_dict(event),
Expand Down
22 changes: 20 additions & 2 deletions synapse/rest/client/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import ReceiptTypes
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
Expand All @@ -43,6 +43,7 @@ def __init__(self, hs: "HomeServer"):
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
self._main_store = hs.get_datastores().main

self._known_receipt_types = {
ReceiptTypes.READ,
Expand Down Expand Up @@ -71,7 +72,24 @@ async def on_POST(
thread_id = body.get("thread_id")
if not thread_id or not isinstance(thread_id, str):
raise SynapseError(
400, "thread_id field must be a non-empty string"
400,
"thread_id field must be a non-empty string",
Codes.INVALID_PARAM,
)

if receipt_type == ReceiptTypes.FULLY_READ:
raise SynapseError(
400,
f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
Codes.INVALID_PARAM,
)

# Ensure the event ID roughly correlates to the thread ID.
if thread_id != await self._main_store.get_thread_id(event_id):
raise SynapseError(
400,
f"event_id {event_id} is not related to thread {thread_id}",
Codes.INVALID_PARAM,
)

await self.presence_handler.bump_presence_active_time(requester.user)
Expand Down
36 changes: 36 additions & 0 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,42 @@ def _get_event_relations(
"get_event_relations", _get_event_relations
)

@cached()
async def get_thread_id(self, event_id: str) -> Optional[str]:
"""
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
Args:
event_id: The event ID to fetch the thread ID for.
Returns:
The event ID of the root event in the thread, if this event is part
of a thread. None, otherwise.
"""
# Since event relations form a tree, we should only ever find 0 or 1
# results from the below query.
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
"""

def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id,))
# TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
return None

return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)


class RelationsStore(RelationsWorkerStore):
pass
100 changes: 100 additions & 0 deletions tests/storage/test_event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,106 @@ def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
_rotate()
_assert_counts(0, 0, 0, 0)

def test_recursive_thread(self) -> None:
"""
Events related to events in a thread should still be considered part of
that thread.
"""

# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")

# And another users to send events.
other_id = self.register_user("other", "pass")
other_token = self.login("other", "pass")

# Create a room and put both users in it.
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)

# Update the user's push rules to care about reaction events.
self.get_success(
self.store.add_push_rule(
user_id,
"related_events",
priority_class=5,
conditions=[
{"kind": "event_match", "key": "type", "pattern": "m.reaction"}
],
actions=["notify"],
)
)

def _create_event(type: str, content: JsonDict) -> str:
result = self.helper.send_event(
room_id, type=type, content=content, tok=other_token
)
return result["event_id"]

def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
self.store._get_unread_counts_by_receipt_txn,
room_id,
user_id,
)
)
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:
self.assertEqual(
counts.threads,
{
thread_id: NotifCounts(
notify_count=thread_notif_count,
unread_count=0,
highlight_count=0,
),
},
)
else:
self.assertEqual(counts.threads, {})

# Create a root event.
thread_id = _create_event(
"m.room.message", {"msgtype": "m.text", "body": "msg"}
)
_assert_counts(1, 0)

# Reply, creating a thread.
reply_id = _create_event(
"m.room.message",
{
"msgtype": "m.text",
"body": "msg",
"m.relates_to": {
"rel_type": "m.thread",
"event_id": thread_id,
},
},
)
_assert_counts(1, 1)

# Create an event related to a thread event, this should still appear in
# the thread.
_create_event(
type="m.reaction",
content={
"m.relates_to": {
"rel_type": "m.annotation",
"event_id": reply_id,
"key": "A",
}
},
)
_assert_counts(1, 2)

def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
self.get_success(
Expand Down

0 comments on commit 2b6d41e

Please sign in to comment.