Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom feed generators #47

Merged
merged 7 commits into from
May 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions atproto/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
OUTPUT_MODEL = 'Response'


def format_code(filepath: Path) -> None:
def format_code(filepath: Path, quiet: bool = True) -> None:
if not isinstance(filepath, Path):
return

subprocess.run(['ruff', '--quiet', '--fix', filepath]) # noqa: S603, S607
subprocess.run(['black', '--quiet', filepath]) # noqa: S603, S607
quiet_option = '--quiet'
if not quiet:
quiet_option = ''

# FIXME(MarshalX): doesn't work well with not-project dir
subprocess.run(['ruff', quiet_option, '--fix', filepath]) # noqa: S603, S607
subprocess.run(['black', quiet_option, filepath]) # noqa: S603, S607


def append_code(filepath: Path, code: str) -> None:
Expand Down
13 changes: 11 additions & 2 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def _generate_import_aliases(root_package_path: Path) -> None:
# from xrpc_client.models.app.bsky.actor import defs as AppBskyActorDefs # noqa: ERA001

import_lines = []
ids_db = ['class _Ids:']
for root, __, files in os.walk(root_package_path):
root = Path(root)

Expand All @@ -543,11 +544,19 @@ def _generate_import_aliases(root_package_path: Path) -> None:
import_parts = root.parts[root.parts.index(_MODELS_OUTPUT_DIR.parent.name) :]
from_import = '.'.join(import_parts)

nsid_parts = list(root.parts[root.parts.index('models') + 1 :]) + file[:-3].split('_')
alias_name = ''.join([p.capitalize() for p in nsid_parts])
nsid_parts = list(root.parts[root.parts.index('models') + 1 :])
method_name_parts = file[:-3].split('_')
alias_name = ''.join([p.capitalize() for p in [*nsid_parts, *method_name_parts]])

camel_case_method_name = method_name_parts[0] + ''.join(ele.title() for ele in method_name_parts[1:])
method_path = f"{'.'.join(nsid_parts)}.{camel_case_method_name}"
ids_db.append(f"{_(1)}{alias_name}: str = '{method_path}'")

import_lines.append(f'from atproto.{from_import} import {file[:-3]} as {alias_name}')

ids_db.append('ids = _Ids()')
import_lines.extend(ids_db)

write_code(_MODELS_OUTPUT_DIR.joinpath('__init__.py'), join_code(import_lines))


Expand Down
4 changes: 2 additions & 2 deletions atproto/codegen/namespaces/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _get_namespace_imports() -> str:
'import typing as t',
'',
'from atproto.xrpc_client import models',
'from atproto.xrpc_client.models.utils import get_or_create_model, get_response_model',
'from atproto.xrpc_client.models.utils import get_or_create, get_response_model',
'from atproto.xrpc_client.namespaces.base import DefaultNamespace, NamespaceBase',
]

Expand Down Expand Up @@ -136,7 +136,7 @@ def _get_namespace_method_body(method_info: MethodInfo, *, sync: bool) -> str:

def _override_arg_line(name: str, model_name: str) -> str:
model_path = f'models.{get_import_path(method_info.nsid)}.{model_name}'
return f'{_(2)}{name} = get_or_create_model({name}, {model_path})'
return f'{_(2)}{name} = get_or_create({name}, {model_path})'

invoke_args = [f"'{method_info.nsid}'"]

Expand Down
23 changes: 8 additions & 15 deletions atproto/firehose/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import typing as t

from atproto import CAR
from atproto.firehose.client import AsyncFirehoseClient, FirehoseClient
from atproto.xrpc_client import models
from atproto.xrpc_client.models.utils import get_or_create_model
from atproto.xrpc_client.models.utils import get_or_create

if t.TYPE_CHECKING:
from atproto.firehose.models import MessageFrame
Expand Down Expand Up @@ -37,26 +36,20 @@
]


def parse_subscribe_repos_message(message: 'MessageFrame', *, decode_inner_cbor: bool = True) -> SubscribeReposMessage:
def parse_subscribe_repos_message(message: 'MessageFrame') -> SubscribeReposMessage:
"""Parse Firehose repositories message to the corresponding model.

Note:
Use `decode_inner_cbor` only when required to increase performance.

Args:
message: Message frame.
decode_inner_cbor: Decode DAG-CBOR inside models.

Returns:
:obj:`SubscribeReposMessage`: Corresponding message model.
"""
model_class = _SUBSCRIBE_REPOS_MESSAGE_TYPE_TO_MODEL[message.type]
model_instance = get_or_create_model(message.body, model_class)

if decode_inner_cbor and isinstance(model_instance, models.ComAtprotoSyncSubscribeRepos.Commit):
model_instance.blocks = CAR.from_bytes(model_instance.blocks)

return model_instance
return get_or_create(message.body, model_class)


def parse_subscribe_labels_message(message: 'MessageFrame') -> SubscribeLabelsMessage:
Expand All @@ -69,18 +62,18 @@ def parse_subscribe_labels_message(message: 'MessageFrame') -> SubscribeLabelsMe
:obj:`SubscribeLabelsMessage`: Corresponding message model.
"""
model_class = _SUBSCRIBE_LABELS_MESSAGE_TYPE_TO_MODEL[message.type]
return get_or_create_model(message.body, model_class)
return get_or_create(message.body, model_class)


class FirehoseSubscribeReposClient(FirehoseClient):
def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoSyncSubscribeRepos.Params']] = None) -> None:
params = get_or_create_model(params, models.ComAtprotoSyncSubscribeRepos.Params)
params = get_or_create(params, models.ComAtprotoSyncSubscribeRepos.Params)
super().__init__(method='com.atproto.sync.subscribeRepos', params=params)


class AsyncFirehoseSubscribeReposClient(AsyncFirehoseClient):
def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoSyncSubscribeRepos.Params']] = None) -> None:
params = get_or_create_model(params, models.ComAtprotoSyncSubscribeRepos.Params)
params = get_or_create(params, models.ComAtprotoSyncSubscribeRepos.Params)
super().__init__(method='com.atproto.sync.subscribeRepos', params=params)


Expand All @@ -89,11 +82,11 @@ def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoSyncSubscr

class FirehoseSubscribeLabelsClient(FirehoseClient):
def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoLabelSubscribeLabels']] = None) -> None:
params = get_or_create_model(params, models.ComAtprotoLabelSubscribeLabels.Params)
params = get_or_create(params, models.ComAtprotoLabelSubscribeLabels.Params)
super().__init__(method='com.atproto.label.subscribeLabels', params=params)


class AsyncFirehoseSubscribeLabelsClient(AsyncFirehoseClient):
def __init__(self, params: t.Optional[t.Union[dict, 'models.ComAtprotoLabelSubscribeLabels']] = None) -> None:
params = get_or_create_model(params, models.ComAtprotoLabelSubscribeLabels.Params)
params = get_or_create(params, models.ComAtprotoLabelSubscribeLabels.Params)
super().__init__(method='com.atproto.label.subscribeLabels', params=params)
8 changes: 4 additions & 4 deletions atproto/firehose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from atproto.cbor import decode_dag_multi
from atproto.exceptions import AtProtocolError, FirehoseError
from atproto.xrpc_client.models.utils import get_or_create_model
from atproto.xrpc_client.models.utils import get_or_create


class FrameType(Enum):
Expand Down Expand Up @@ -53,16 +53,16 @@ def parse_frame_header(raw_header: dict) -> FrameHeader:

frame_type = FrameType(header_op)
if frame_type is FrameType.MESSAGE:
return get_or_create_model(raw_header, MessageFrameHeader)
return get_or_create_model(raw_header, ErrorFrameHeader)
return get_or_create(raw_header, MessageFrameHeader)
return get_or_create(raw_header, ErrorFrameHeader)
except (ValueError, AtProtocolError) as e:
raise FirehoseError('Invalid frame header') from e


def parse_frame(header: FrameHeader, raw_body: dict) -> Union['ErrorFrame', 'MessageFrame']:
try:
if isinstance(header, ErrorFrameHeader):
body = get_or_create_model(raw_body, ErrorFrameBody)
body = get_or_create(raw_body, ErrorFrameBody)
return ErrorFrame(header, body)
if isinstance(header, MessageFrameHeader):
return MessageFrame(header, raw_body)
Expand Down
11 changes: 6 additions & 5 deletions atproto/xrpc_client/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from atproto.xrpc_client import models
from atproto.xrpc_client.client.async_raw import AsyncClientRaw
from atproto.xrpc_client.client.methods_mixin import SessionMethodsMixin
from atproto.xrpc_client.models import ids

if t.TYPE_CHECKING:
from atproto.xrpc_client.client.auth import JwtPayload
Expand Down Expand Up @@ -119,7 +120,7 @@ async def send_post(
return await self.com.atproto.repo.create_record(
models.ComAtprotoRepoCreateRecord.Data(
repo=repo,
collection='app.bsky.feed.post',
collection=ids.AppBskyFeedPost,
record=models.AppBskyFeedPost.Main(
createdAt=datetime.now().isoformat(), text=text, reply=reply_to, embed=embed
),
Expand Down Expand Up @@ -178,7 +179,7 @@ async def like(self, subject: models.ComAtprotoRepoStrongRef.Main) -> models.Com
return await self.com.atproto.repo.create_record(
models.ComAtprotoRepoCreateRecord.Data(
repo=self.me.did,
collection='app.bsky.feed.like',
collection=ids.AppBskyFeedLike,
record=models.AppBskyFeedLike.Main(createdAt=datetime.now().isoformat(), subject=subject),
)
)
Expand All @@ -203,7 +204,7 @@ async def unlike(self, record_key: str, profile_identify: t.Optional[str] = None

return await self.com.atproto.repo.delete_record(
models.ComAtprotoRepoDeleteRecord.Data(
collection='app.bsky.feed.like',
collection=ids.AppBskyFeedLike,
repo=repo,
rkey=record_key,
)
Expand Down Expand Up @@ -237,7 +238,7 @@ async def repost(
return await self.com.atproto.repo.create_record(
models.ComAtprotoRepoCreateRecord.Data(
repo=repo,
collection='app.bsky.feed.repost',
collection=ids.AppBskyFeedRepost,
record=models.AppBskyFeedRepost.Main(
createdAt=datetime.now().isoformat(),
subject=subject,
Expand Down Expand Up @@ -265,7 +266,7 @@ async def delete_post(self, post_rkey: str, profile_identify: t.Optional[str] =

return await self.com.atproto.repo.delete_record(
models.ComAtprotoRepoDeleteRecord.Data(
collection='app.bsky.feed.post',
collection=ids.AppBskyFeedPost,
repo=repo,
rkey=post_rkey,
)
Expand Down
11 changes: 6 additions & 5 deletions atproto/xrpc_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from atproto.xrpc_client import models
from atproto.xrpc_client.client.methods_mixin import SessionMethodsMixin
from atproto.xrpc_client.client.raw import ClientRaw
from atproto.xrpc_client.models import ids

if t.TYPE_CHECKING:
from atproto.xrpc_client.client.auth import JwtPayload
Expand Down Expand Up @@ -111,7 +112,7 @@ def send_post(
return self.com.atproto.repo.create_record(
models.ComAtprotoRepoCreateRecord.Data(
repo=repo,
collection='app.bsky.feed.post',
collection=ids.AppBskyFeedPost,
record=models.AppBskyFeedPost.Main(
createdAt=datetime.now().isoformat(), text=text, reply=reply_to, embed=embed
),
Expand Down Expand Up @@ -170,7 +171,7 @@ def like(self, subject: models.ComAtprotoRepoStrongRef.Main) -> models.ComAtprot
return self.com.atproto.repo.create_record(
models.ComAtprotoRepoCreateRecord.Data(
repo=self.me.did,
collection='app.bsky.feed.like',
collection=ids.AppBskyFeedLike,
record=models.AppBskyFeedLike.Main(createdAt=datetime.now().isoformat(), subject=subject),
)
)
Expand All @@ -195,7 +196,7 @@ def unlike(self, record_key: str, profile_identify: t.Optional[str] = None) -> b

return self.com.atproto.repo.delete_record(
models.ComAtprotoRepoDeleteRecord.Data(
collection='app.bsky.feed.like',
collection=ids.AppBskyFeedLike,
repo=repo,
rkey=record_key,
)
Expand Down Expand Up @@ -229,7 +230,7 @@ def repost(
return self.com.atproto.repo.create_record(
models.ComAtprotoRepoCreateRecord.Data(
repo=repo,
collection='app.bsky.feed.repost',
collection=ids.AppBskyFeedRepost,
record=models.AppBskyFeedRepost.Main(
createdAt=datetime.now().isoformat(),
subject=subject,
Expand Down Expand Up @@ -257,7 +258,7 @@ def delete_post(self, post_rkey: str, profile_identify: t.Optional[str] = None)

return self.com.atproto.repo.delete_record(
models.ComAtprotoRepoDeleteRecord.Data(
collection='app.bsky.feed.post',
collection=ids.AppBskyFeedPost,
repo=repo,
rkey=post_rkey,
)
Expand Down
Loading