Skip to content

Commit

Permalink
Add labels firehose data stream (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX committed Mar 14, 2024
1 parent ad8234b commit 809866d
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 23 deletions.
43 changes: 40 additions & 3 deletions docs/source/atproto_firehose/index.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Firehose (data streaming)
=========================

You can use the clients below to get real-time updates from the whole network. If you subscribe to retrieve messages from repositories you will get information about each created, deleted, liked, reposted post, etc.
You can use the clients below to get real-time updates from the whole network. If you subscribe to retrieve messages from repositories you will get information about each created, deleted, liked, reposted post, etc. If you subscribe to labels you will get information about added or updated labels on posts from moderation tools.

All clients present in two variants: sync and async. As a developer, you should create your own callback on a new message to handle incoming data. Here is how to do it:
All clients present in two variants: sync and async. As a developer, you should create your own callback on a new message to handle incoming data. Here is how to do it for repositories events:

.. code-block:: python
Expand All @@ -18,12 +18,27 @@ All clients present in two variants: sync and async. As a developer, you should
client.start(on_message_handler)
For labeling events:

.. code-block:: python
from atproto import FirehoseSubscribeLabelsClient, parse_subscribe_labels_message
client = FirehoseSubscribeLabelsClient()
def on_message_handler(message) -> None:
print(message.header, parse_subscribe_labels_message(message))
client.start(on_message_handler)
More code examples: https://github.com/MarshalX/atproto/tree/main/examples/firehose

.. note::
To achieve more performance you could parse only required messages using `message.header` to filter.

By default :obj:`parse_subscribe_repos_message` and :obj:`parse_subscribe_labels_message` doesn't decode inner DAG-CBOR. Probably you want to decode it. To do so use :obj:`atproto.CAR`. Example of message handler with decoding of CAR files (commit blocks):
By default :obj:`parse_subscribe_repos_message` doesn't decode inner DAG-CBOR. Probably you want to decode it. To do so use :obj:`atproto.CAR`. Example of message handler with decoding of CAR files (commit blocks):

.. code-block:: python
Expand All @@ -37,6 +52,28 @@ By default :obj:`parse_subscribe_repos_message` and :obj:`parse_subscribe_labels
car = CAR.from_bytes(commit.blocks)
Here is how you can process labeling events:

.. code-block:: python
from atproto import FirehoseSubscribeLabelsClient, firehose_models, models, parse_subscribe_labels_message
client = FirehoseSubscribeLabelsClient()
def on_message_handler(message: firehose_models.MessageFrame) -> None:
labels_message = parse_subscribe_labels_message(message)
if not isinstance(labels_message, models.ComAtprotoLabelSubscribeLabels.Labels):
return
for label in labels_message.labels:
neg = '(NEG)' if label.neg else ''
print(f'[{label.cts}] ({label.src}) {label.uri} => {label.val} {neg}')
client.start(on_message_handler)
.. automodule:: atproto_firehose
:members:
:undoc-members:
Expand Down
16 changes: 16 additions & 0 deletions examples/firehose/sub_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from atproto import FirehoseSubscribeLabelsClient, firehose_models, models, parse_subscribe_labels_message

client = FirehoseSubscribeLabelsClient()


def on_message_handler(message: firehose_models.MessageFrame) -> None:
labels_message = parse_subscribe_labels_message(message)
if not isinstance(labels_message, models.ComAtprotoLabelSubscribeLabels.Labels):
return

for label in labels_message.labels:
neg = '(NEG)' if label.neg else ''
print(f'[{label.cts}] ({label.src}) {label.uri} => {label.val} {neg}')


client.start(on_message_handler)
18 changes: 4 additions & 14 deletions packages/atproto_firehose/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from atproto_firehose.exceptions import FirehoseDecodingError, FirehoseError
from atproto_firehose.models import ErrorFrame, Frame, MessageFrame

_BASE_WEBSOCKET_URI = 'wss://bsky.network/xrpc'
_MAX_MESSAGE_SIZE_BYTES = 1024 * 1024 * 5 # 5MB

OnMessageCallback = t.Callable[['MessageFrame'], None]
Expand All @@ -40,12 +39,7 @@
from websockets.legacy.client import Connect as AsyncConnect


def _build_websocket_uri(
method: str, base_uri: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None
) -> str:
if base_uri is None:
base_uri = _BASE_WEBSOCKET_URI

def _build_websocket_uri(method: str, base_uri: str, params: t.Optional[t.Dict[str, t.Any]] = None) -> str:
query_string = ''
if params:
query_string = f'?{urlencode(params)}'
Expand Down Expand Up @@ -88,7 +82,7 @@ class _WebsocketClientBase:
def __init__(
self,
method: str,
base_uri: t.Optional[str] = None,
base_uri: str,
params: t.Optional[t.Dict[str, t.Any]] = None,
) -> None:
self._method = method
Expand Down Expand Up @@ -132,9 +126,7 @@ def _get_reconnection_delay(self) -> int:


class _WebsocketClient(_WebsocketClientBase):
def __init__(
self, method: str, base_uri: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None
) -> None:
def __init__(self, method: str, base_uri: str, params: t.Optional[t.Dict[str, t.Any]] = None) -> None:
super().__init__(method, base_uri, params)

# TODO(DXsmiley): Not sure if this should be a Lock or not, the async is using an Event now
Expand Down Expand Up @@ -213,9 +205,7 @@ def stop(self) -> None:


class _AsyncWebsocketClient(_WebsocketClientBase):
def __init__(
self, method: str, base_uri: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None
) -> None:
def __init__(self, method: str, base_uri: str, params: t.Optional[t.Dict[str, t.Any]] = None) -> None:
super().__init__(method, base_uri, params)

self._stop_event = asyncio.Event()
Expand Down
14 changes: 8 additions & 6 deletions packages/atproto_firehose/firehose.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
if t.TYPE_CHECKING:
from atproto_firehose.models import MessageFrame

# TODO(MarshalX): Everything here could be autogenerated from the lexicon.
_REPOS_BASE_WEBSOCKET_URI = 'wss://bsky.network/xrpc'
_LABELS_BASE_WEBSOCKET_URI = 'wss://mod.bsky.app/xrpc'

# TODO(MarshalX): Everything here could be autogenerated from the lexicon.
_SUBSCRIBE_REPOS_MESSAGE_TYPE_TO_MODEL = {
'#commit': models.ComAtprotoSyncSubscribeRepos.Commit,
'#handle': models.ComAtprotoSyncSubscribeRepos.Handle,
Expand All @@ -19,7 +21,7 @@
'#identity': models.ComAtprotoSyncSubscribeRepos.Identity,
}
_SUBSCRIBE_LABELS_MESSAGE_TYPE_TO_MODEL = {
'#label': models.ComAtprotoLabelSubscribeLabels.Labels,
'#labels': models.ComAtprotoLabelSubscribeLabels.Labels,
'#info': models.ComAtprotoLabelSubscribeLabels.Info,
}

Expand Down Expand Up @@ -79,7 +81,7 @@ class FirehoseSubscribeReposClient(FirehoseClient):
def __init__(
self,
params: t.Optional[t.Union[dict, 'models.ComAtprotoSyncSubscribeRepos.Params']] = None,
base_uri: t.Optional[str] = None,
base_uri: t.Optional[str] = _REPOS_BASE_WEBSOCKET_URI,
) -> None:
params_model = get_or_create(params, models.ComAtprotoSyncSubscribeRepos.Params)

Expand All @@ -101,7 +103,7 @@ class AsyncFirehoseSubscribeReposClient(AsyncFirehoseClient):
def __init__(
self,
params: t.Optional[t.Union[dict, 'models.ComAtprotoSyncSubscribeRepos.Params']] = None,
base_uri: t.Optional[str] = None,
base_uri: t.Optional[str] = _REPOS_BASE_WEBSOCKET_URI,
) -> None:
params_model = get_or_create(params, models.ComAtprotoSyncSubscribeRepos.Params)

Expand All @@ -126,7 +128,7 @@ class FirehoseSubscribeLabelsClient(FirehoseClient):
def __init__(
self,
params: t.Optional[t.Union[dict, 'models.ComAtprotoLabelSubscribeLabels.Params']] = None,
base_uri: t.Optional[str] = None,
base_uri: t.Optional[str] = _LABELS_BASE_WEBSOCKET_URI,
) -> None:
params_model = get_or_create(params, models.ComAtprotoLabelSubscribeLabels.Params)

Expand All @@ -148,7 +150,7 @@ class AsyncFirehoseSubscribeLabelsClient(AsyncFirehoseClient):
def __init__(
self,
params: t.Optional[t.Union[dict, 'models.ComAtprotoLabelSubscribeLabels.Params']] = None,
base_uri: t.Optional[str] = None,
base_uri: t.Optional[str] = _LABELS_BASE_WEBSOCKET_URI,
) -> None:
params_model = get_or_create(params, models.ComAtprotoLabelSubscribeLabels.Params)

Expand Down

0 comments on commit 809866d

Please sign in to comment.