Skip to content

Commit

Permalink
Fix parsing of custom (extended) records; add dot notation for dictio…
Browse files Browse the repository at this point in the history
…naries (#106)
  • Loading branch information
MarshalX committed Jul 21, 2023
1 parent 50f67c4 commit 356865a
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
1 change: 1 addition & 0 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _get_ref_union_typehint(nsid: NSID, field_type_def, *, optional: bool) -> st
# maybe it's for the records that have custom fields... idk
# ref: https://github.com/bluesky-social/atproto/blob/b01e47b61730d05a780f7a42667b91ccaa192e8e/packages/lex-cli/src/codegen/lex-gen.ts#L325
# grep by "{$type: string; [k: string]: unknown}" string
# TODO(MarshalX): use 'base.UnknownDict' and convert to DotDict
def_names.append('t.Dict[str, t.Any]')

def_names = ', '.join([f"'{name}'" for name in def_names])
Expand Down
12 changes: 7 additions & 5 deletions atproto/xrpc_client/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
def _unknown_type_hook(data: dict) -> t.Union[UnknownRecordType, DotDict]:
if '$type' in data:
# $type used for inner Record types
return get_or_create_model(data, RECORD_TYPE_TO_MODEL_CLASS[data.pop('$type')])
return get_or_create(data, strict=False)
# any another unknown (not described by lexicon) type
return DotDict(data)

Expand All @@ -51,7 +51,7 @@ def _decode_cid_hook(ref: t.Union[CID, str]) -> CID:

def get_or_create(
model_data: ModelData, model: t.Optional[t.Type[M]] = None, *, strict: bool = True
) -> t.Optional[t.Union[M, dict]]:
) -> t.Optional[t.Union[M, UnknownRecordType, DotDict]]:
"""Get model instance from raw data.
Note:
Expand Down Expand Up @@ -80,17 +80,19 @@ def get_or_create(
if strict:
raise e

return model_data
return DotDict(model_data)


def _get_or_create(model_data: ModelData, model: t.Type[M], *, strict: bool) -> t.Optional[t.Union[M, dict]]:
def _get_or_create(
model_data: ModelData, model: t.Type[M], *, strict: bool
) -> t.Optional[t.Union[M, UnknownRecordType, DotDict]]:
if model_data is None:
return None

if model is None:
# resolve model by $type and try to parse
# resolves only Records
record_type = model_data.pop('$type')
record_type = model_data.pop('$type', None)
if not record_type or record_type not in RECORD_TYPE_TO_MODEL_CLASS:
return None

Expand Down
4 changes: 2 additions & 2 deletions atproto/xrpc_client/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self) -> None:
self._client = httpx.Client()

def _send_request(self, method: str, url: str, **kwargs) -> httpx.Response:
headers = self.get_headers(kwargs.pop('headers'))
headers = self.get_headers(kwargs.pop('headers', None))

try:
response = self._client.request(method=method, url=url, headers=headers, **kwargs)
Expand All @@ -120,7 +120,7 @@ def __init__(self) -> None:
self._client = httpx.AsyncClient()

async def _send_request(self, method: str, url: str, **kwargs) -> httpx.Response:
headers = self.get_headers(kwargs.pop('headers'))
headers = self.get_headers(kwargs.pop('headers', None))

try:
response = await self._client.request(method=method, url=url, headers=headers, **kwargs)
Expand Down
18 changes: 13 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,23 @@ def sync_main():
client = Client()
client.login(os.environ['USERNAME'], os.environ['PASSWORD'])

# r = client.send_post('regress test')
# print(r)
lexicon_correct_record = client.com.atproto.repo.get_record(
{'collection': 'app.bsky.feed.post', 'repo': 'test.marshal.dev', 'rkey': '3k2yihcrp6f2c'}
)
print(lexicon_correct_record.value.text)
print(type(lexicon_correct_record.value))
extended_record = client.com.atproto.repo.get_record(
{'collection': 'app.bsky.feed.post', 'repo': 'test.marshal.dev', 'rkey': '3k2yinh52ne2x'}
)
print(extended_record.value.text)
print(extended_record.value.lol) # custom (out of lexicon) attribute
print(type(extended_record.value))

did_doc = client.com.atproto.repo.describe_repo({'repo': 'did:plc:ze3uieyyns7prike7itbdjiy'}).didDoc
print(did_doc.service)
print(did_doc['service'])
print(did_doc['@context'])
print(type(did_doc))

atproto_feed = client.com.atproto.repo.get_record(
{'collection': ids.AppBskyFeedGenerator, 'repo': 'marshal.dev', 'rkey': 'atproto'}
Expand Down Expand Up @@ -116,7 +126,7 @@ async def main():
# with open('cat2.png', 'rb') as f:
# cat_data = f.read()

# await async_client.send_image('Cat looking for a Async Python', cat_data, 'async cat alt')
# await async_client.send_image('Cat looking for an Async Python', cat_data, 'async cat alt')

# resolve = await async_client.com.atproto.identity.resolve_handle(
# models.ComAtprotoIdentityResolveHandle.Params(profile.handle)
Expand Down Expand Up @@ -171,11 +181,9 @@ def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict


def _custom_feed_firehose():
client = FirehoseSubscribeReposClient({'cursor': 93278360})
client = FirehoseSubscribeReposClient()

def on_message_handler(message: 'MessageFrame') -> None:
# return
commit = parse_subscribe_repos_message(message)
if not isinstance(commit, models.ComAtprotoSyncSubscribeRepos.Commit):
return
Expand Down

0 comments on commit 356865a

Please sign in to comment.