From 50f67c455f5b4409ff31435fd1388a2f0aaf3253 Mon Sep 17 00:00:00 2001 From: Ilya Siamionau Date: Fri, 21 Jul 2023 10:14:52 +0200 Subject: [PATCH] Fix unknown type that could be plain dictionary (#105) --- atproto/codegen/models/generator.py | 24 +++++++++---- .../models/app/bsky/embed/record.py | 2 +- .../xrpc_client/models/app/bsky/feed/defs.py | 2 +- .../bsky/notification/list_notifications.py | 2 +- atproto/xrpc_client/models/base.py | 35 ++++++++++++++----- .../models/com/atproto/admin/defs.py | 8 ++--- .../models/com/atproto/repo/apply_writes.py | 4 +-- .../models/com/atproto/repo/create_record.py | 2 +- .../models/com/atproto/repo/describe_repo.py | 2 +- .../models/com/atproto/repo/get_record.py | 2 +- .../models/com/atproto/repo/list_records.py | 2 +- .../models/com/atproto/repo/put_record.py | 2 +- atproto/xrpc_client/models/unknown_type.py | 18 ++++++++++ atproto/xrpc_client/models/utils.py | 15 ++++---- test.py | 28 +++++++++++---- 15 files changed, 108 insertions(+), 40 deletions(-) create mode 100644 atproto/xrpc_client/models/unknown_type.py diff --git a/atproto/codegen/models/generator.py b/atproto/codegen/models/generator.py index 62ab5e8e..924b8d24 100644 --- a/atproto/codegen/models/generator.py +++ b/atproto/codegen/models/generator.py @@ -61,6 +61,7 @@ def _get_model_imports() -> str: 'import typing_extensions as te', 'from atproto.xrpc_client import models', 'from atproto.xrpc_client.models import base', + 'from atproto.xrpc_client.models import unknown_type', 'from atproto.xrpc_client.models.blob_ref import BlobRef', '', 'from atproto import CID', @@ -162,8 +163,8 @@ def _get_model_field_typehint(nsid: NSID, field_name: str, field_type_def, *, op field_type = type(field_type_def) if field_type == models.LexUnknown: - # TODO(MarshalX): some of "unknown" types are well known... - return _get_optional_typehint("'base.RecordModelBase'", optional=optional) + # unknown type is a generic response with records or any not described type in the lexicon. for example didDoc + return _get_optional_typehint("'base.UnknownDict'", optional=optional) type_hint = _LEXICON_TYPE_TO_PRIMITIVE_TYPEHINT.get(field_type) if type_hint: @@ -425,7 +426,15 @@ def _generate_record_models(lex_db: builder.BuiltRecordModels) -> None: def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None: - lines = ['from atproto.xrpc_client import models', 'RECORD_TYPE_TO_MODEL_CLASS = {'] + type_conversion_lines = ['from atproto.xrpc_client import models', 'RECORD_TYPE_TO_MODEL_CLASS = {'] + unknown_record_type_hint_lines = [ + 'import typing as t', + 'import typing_extensions as te', + 'if t.TYPE_CHECKING:', + f'{_(4)}from atproto.xrpc_client import models', + '', + 'UnknownRecordType: te.TypeAlias = t.Union[', + ] for nsid, defs in lex_db.items(): _save_code_import_if_not_exist(nsid) @@ -439,11 +448,14 @@ def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None: path_to_class = f'models.{get_import_path(nsid)}.{class_name}' - lines.append(f"'{record_type}': {path_to_class},") + type_conversion_lines.append(f"'{record_type}': {path_to_class},") + unknown_record_type_hint_lines.append(f"{_(4)}'{path_to_class}',") - lines.append('}') + type_conversion_lines.append('}') + unknown_record_type_hint_lines.append(']') - write_code(_MODELS_OUTPUT_DIR.joinpath('type_conversion.py'), join_code(lines)) + write_code(_MODELS_OUTPUT_DIR.joinpath('type_conversion.py'), join_code(type_conversion_lines)) + write_code(_MODELS_OUTPUT_DIR.joinpath('unknown_type.py'), join_code(unknown_record_type_hint_lines)) def _generate_ref_models(lex_db: builder.BuiltRefsModels) -> None: diff --git a/atproto/xrpc_client/models/app/bsky/embed/record.py b/atproto/xrpc_client/models/app/bsky/embed/record.py index b756f6af..0db75f33 100644 --- a/atproto/xrpc_client/models/app/bsky/embed/record.py +++ b/atproto/xrpc_client/models/app/bsky/embed/record.py @@ -48,7 +48,7 @@ class ViewRecord(base.ModelBase): cid: str #: Cid. indexedAt: str #: Indexed at. uri: str #: Uri. - value: 'base.RecordModelBase' #: Value. + value: 'base.UnknownDict' #: Value. embeds: t.Optional[ t.List[ t.Union[ diff --git a/atproto/xrpc_client/models/app/bsky/feed/defs.py b/atproto/xrpc_client/models/app/bsky/feed/defs.py index 15f766ef..9651a2af 100644 --- a/atproto/xrpc_client/models/app/bsky/feed/defs.py +++ b/atproto/xrpc_client/models/app/bsky/feed/defs.py @@ -20,7 +20,7 @@ class PostView(base.ModelBase): author: 'models.AppBskyActorDefs.ProfileViewBasic' #: Author. cid: str #: Cid. indexedAt: str #: Indexed at. - record: 'base.RecordModelBase' #: Record. + record: 'base.UnknownDict' #: Record. uri: str #: Uri. embed: t.Optional[ t.Union[ diff --git a/atproto/xrpc_client/models/app/bsky/notification/list_notifications.py b/atproto/xrpc_client/models/app/bsky/notification/list_notifications.py index 01daf423..6a914bcb 100644 --- a/atproto/xrpc_client/models/app/bsky/notification/list_notifications.py +++ b/atproto/xrpc_client/models/app/bsky/notification/list_notifications.py @@ -41,7 +41,7 @@ class Notification(base.ModelBase): indexedAt: str #: Indexed at. isRead: bool #: Is read. reason: str #: Expected values are 'like', 'repost', 'follow', 'mention', 'reply', and 'quote'. - record: 'base.RecordModelBase' #: Record. + record: 'base.UnknownDict' #: Record. uri: str #: Uri. labels: t.Optional[t.List['models.ComAtprotoLabelDefs.Label']] = None #: Labels. reasonSubject: t.Optional[str] = None #: Reason subject. diff --git a/atproto/xrpc_client/models/base.py b/atproto/xrpc_client/models/base.py index e7a0ae7c..41e0934c 100644 --- a/atproto/xrpc_client/models/base.py +++ b/atproto/xrpc_client/models/base.py @@ -1,9 +1,6 @@ -from dataclasses import dataclass - from atproto.exceptions import ModelFieldNotFoundError -@dataclass class ModelBase: def __getitem__(self, item: str): if hasattr(self, item): @@ -12,21 +9,43 @@ def __getitem__(self, item: str): raise ModelFieldNotFoundError(f"Can't find field '{item}' in the object of type {type(self)}.") -@dataclass class ParamsModelBase(ModelBase): pass -@dataclass class DataModelBase(ModelBase): pass -@dataclass class ResponseModelBase(ModelBase): pass -@dataclass -class RecordModelBase(ModelBase): +class UnknownDict(ModelBase): + pass + + +class DotDict(UnknownDict): + def __init__(self, data: dict): + self._data = data + + def __getattr__(self, item): + return self._data.get(item) + + def __setattr__(self, key, value): + if key == '_data': + super().__setattr__(key, value) + return + + self._data[key] = value + + def __delattr__(self, item): + del self._data[item] + + +class UnknownRecord(UnknownDict): + pass + + +class RecordModelBase(UnknownRecord): pass diff --git a/atproto/xrpc_client/models/com/atproto/admin/defs.py b/atproto/xrpc_client/models/com/atproto/admin/defs.py index 28f65c5c..9633cd7b 100644 --- a/atproto/xrpc_client/models/com/atproto/admin/defs.py +++ b/atproto/xrpc_client/models/com/atproto/admin/defs.py @@ -146,7 +146,7 @@ class RepoView(base.ModelBase): handle: str #: Handle. indexedAt: str #: Indexed at. moderation: 'models.ComAtprotoAdminDefs.Moderation' #: Moderation. - relatedRecords: t.List['base.RecordModelBase'] #: Related records. + relatedRecords: t.List['base.UnknownDict'] #: Related records. email: t.Optional[str] = None #: Email. invitedBy: t.Optional['models.ComAtprotoServerDefs.InviteCode'] = None #: Invited by. invitesDisabled: t.Optional[bool] = None #: Invites disabled. @@ -163,7 +163,7 @@ class RepoViewDetail(base.ModelBase): handle: str #: Handle. indexedAt: str #: Indexed at. moderation: 'models.ComAtprotoAdminDefs.ModerationDetail' #: Moderation. - relatedRecords: t.List['base.RecordModelBase'] #: Related records. + relatedRecords: t.List['base.UnknownDict'] #: Related records. email: t.Optional[str] = None #: Email. invitedBy: t.Optional['models.ComAtprotoServerDefs.InviteCode'] = None #: Invited by. invites: t.Optional[t.List['models.ComAtprotoServerDefs.InviteCode']] = None #: Invites. @@ -204,7 +204,7 @@ class RecordView(base.ModelBase): moderation: 'models.ComAtprotoAdminDefs.Moderation' #: Moderation. repo: 'models.ComAtprotoAdminDefs.RepoView' #: Repo. uri: str #: Uri. - value: 'base.RecordModelBase' #: Value. + value: 'base.UnknownDict' #: Value. _type: str = 'com.atproto.admin.defs#recordView' @@ -220,7 +220,7 @@ class RecordViewDetail(base.ModelBase): moderation: 'models.ComAtprotoAdminDefs.ModerationDetail' #: Moderation. repo: 'models.ComAtprotoAdminDefs.RepoView' #: Repo. uri: str #: Uri. - value: 'base.RecordModelBase' #: Value. + value: 'base.UnknownDict' #: Value. labels: t.Optional[t.List['models.ComAtprotoLabelDefs.Label']] = None #: Labels. _type: str = 'com.atproto.admin.defs#recordViewDetail' diff --git a/atproto/xrpc_client/models/com/atproto/repo/apply_writes.py b/atproto/xrpc_client/models/com/atproto/repo/apply_writes.py index 404bbc5a..fef9e727 100644 --- a/atproto/xrpc_client/models/com/atproto/repo/apply_writes.py +++ b/atproto/xrpc_client/models/com/atproto/repo/apply_writes.py @@ -36,7 +36,7 @@ class Create(base.ModelBase): """Definition model for :obj:`com.atproto.repo.applyWrites`. Create a new record.""" collection: str #: Collection. - value: 'base.RecordModelBase' #: Value. + value: 'base.UnknownDict' #: Value. rkey: t.Optional[str] = None #: Rkey. _type: str = 'com.atproto.repo.applyWrites#create' @@ -49,7 +49,7 @@ class Update(base.ModelBase): collection: str #: Collection. rkey: str #: Rkey. - value: 'base.RecordModelBase' #: Value. + value: 'base.UnknownDict' #: Value. _type: str = 'com.atproto.repo.applyWrites#update' diff --git a/atproto/xrpc_client/models/com/atproto/repo/create_record.py b/atproto/xrpc_client/models/com/atproto/repo/create_record.py index e69f6ee4..6c1bcd75 100644 --- a/atproto/xrpc_client/models/com/atproto/repo/create_record.py +++ b/atproto/xrpc_client/models/com/atproto/repo/create_record.py @@ -17,7 +17,7 @@ class Data(base.DataModelBase): """Input data model for :obj:`com.atproto.repo.createRecord`.""" collection: str #: The NSID of the record collection. - record: 'base.RecordModelBase' #: The record to create. + record: 'base.UnknownDict' #: The record to create. repo: str #: The handle or DID of the repo. rkey: t.Optional[str] = None #: The key of the record. swapCommit: t.Optional[str] = None #: Compare and swap with the previous commit by cid. diff --git a/atproto/xrpc_client/models/com/atproto/repo/describe_repo.py b/atproto/xrpc_client/models/com/atproto/repo/describe_repo.py index 7e413222..c1e10e97 100644 --- a/atproto/xrpc_client/models/com/atproto/repo/describe_repo.py +++ b/atproto/xrpc_client/models/com/atproto/repo/describe_repo.py @@ -26,6 +26,6 @@ class Response(base.ResponseModelBase): collections: t.List[str] #: Collections. did: str #: Did. - didDoc: 'base.RecordModelBase' #: Did doc. + didDoc: 'base.UnknownDict' #: Did doc. handle: str #: Handle. handleIsCorrect: bool #: Handle is correct. diff --git a/atproto/xrpc_client/models/com/atproto/repo/get_record.py b/atproto/xrpc_client/models/com/atproto/repo/get_record.py index 4f66edef..94e2f483 100644 --- a/atproto/xrpc_client/models/com/atproto/repo/get_record.py +++ b/atproto/xrpc_client/models/com/atproto/repo/get_record.py @@ -30,5 +30,5 @@ class Response(base.ResponseModelBase): """Output data model for :obj:`com.atproto.repo.getRecord`.""" uri: str #: Uri. - value: 'base.RecordModelBase' #: Value. + value: 'base.UnknownDict' #: Value. cid: t.Optional[str] = None #: Cid. diff --git a/atproto/xrpc_client/models/com/atproto/repo/list_records.py b/atproto/xrpc_client/models/com/atproto/repo/list_records.py index cf836a04..91f3261f 100644 --- a/atproto/xrpc_client/models/com/atproto/repo/list_records.py +++ b/atproto/xrpc_client/models/com/atproto/repo/list_records.py @@ -42,6 +42,6 @@ class Record(base.ModelBase): cid: str #: Cid. uri: str #: Uri. - value: 'base.RecordModelBase' #: Value. + value: 'base.UnknownDict' #: Value. _type: str = 'com.atproto.repo.listRecords#record' diff --git a/atproto/xrpc_client/models/com/atproto/repo/put_record.py b/atproto/xrpc_client/models/com/atproto/repo/put_record.py index 095d6b70..c5407e1d 100644 --- a/atproto/xrpc_client/models/com/atproto/repo/put_record.py +++ b/atproto/xrpc_client/models/com/atproto/repo/put_record.py @@ -17,7 +17,7 @@ class Data(base.DataModelBase): """Input data model for :obj:`com.atproto.repo.putRecord`.""" collection: str #: The NSID of the record collection. - record: 'base.RecordModelBase' #: The record to write. + record: 'base.UnknownDict' #: The record to write. repo: str #: The handle or DID of the repo. rkey: str #: The key of the record. swapCommit: t.Optional[str] = None #: Compare and swap with the previous commit by cid. diff --git a/atproto/xrpc_client/models/unknown_type.py b/atproto/xrpc_client/models/unknown_type.py new file mode 100644 index 00000000..4042caa6 --- /dev/null +++ b/atproto/xrpc_client/models/unknown_type.py @@ -0,0 +1,18 @@ +import typing as t + +import typing_extensions as te + +if t.TYPE_CHECKING: + from atproto.xrpc_client import models + +UnknownRecordType: te.TypeAlias = t.Union[ + 'models.AppBskyFeedGenerator.Main', + 'models.AppBskyActorProfile.Main', + 'models.AppBskyFeedRepost.Main', + 'models.AppBskyGraphListitem.Main', + 'models.AppBskyFeedLike.Main', + 'models.AppBskyGraphFollow.Main', + 'models.AppBskyGraphList.Main', + 'models.AppBskyGraphBlock.Main', + 'models.AppBskyFeedPost.Main', +] diff --git a/atproto/xrpc_client/models/utils.py b/atproto/xrpc_client/models/utils.py index 54d49ace..82926356 100644 --- a/atproto/xrpc_client/models/utils.py +++ b/atproto/xrpc_client/models/utils.py @@ -14,9 +14,10 @@ UnexpectedFieldError, WrongTypeError, ) -from atproto.xrpc_client.models.base import ModelBase, RecordModelBase +from atproto.xrpc_client.models.base import DotDict, ModelBase, UnknownDict from atproto.xrpc_client.models.blob_ref import BlobRef from atproto.xrpc_client.models.type_conversion import RECORD_TYPE_TO_MODEL_CLASS +from atproto.xrpc_client.models.unknown_type import UnknownRecordType if t.TYPE_CHECKING: from atproto.xrpc_client.request import Response @@ -25,10 +26,12 @@ ModelData: te.TypeAlias = t.Union[M, dict, None] -def _record_model_type_hook(data: dict) -> RecordModelBase: - # used for inner Record types - record_type = data.pop('$type') - return get_or_create_model(data, RECORD_TYPE_TO_MODEL_CLASS[record_type]) +def _unknown_type_hook(data: dict) -> t.Union[UnknownRecordType, DotDict]: + if '$type' in data: + # $type used for inner Record types + return get_or_create_model(data, RECORD_TYPE_TO_MODEL_CLASS[data.pop('$type')]) + # any another unknown (not described by lexicon) type + return DotDict(data) def _decode_cid_hook(ref: t.Union[CID, str]) -> CID: @@ -41,7 +44,7 @@ def _decode_cid_hook(ref: t.Union[CID, str]) -> CID: _TYPE_HOOKS = { BlobRef: lambda ref: BlobRef.from_dict(ref), CID: _decode_cid_hook, - RecordModelBase: _record_model_type_hook, + UnknownDict: _unknown_type_hook, } _DACITE_CONFIG = Config(cast=[Enum], type_hooks=_TYPE_HOOKS) diff --git a/test.py b/test.py index 9509e672..b29dfe1d 100644 --- a/test.py +++ b/test.py @@ -3,6 +3,7 @@ import os import threading import typing as t +from datetime import datetime from atproto import CAR, AsyncClient, AtUri, Client, exceptions, models from atproto.firehose import ( @@ -45,10 +46,23 @@ def sync_main(): client = Client() client.login(os.environ['USERNAME'], os.environ['PASSWORD']) - valid = client.com.atproto.repo.get_record( - {'collection': ids.AppBskyFeedGenerator, 'repo': 'test.marshal.dev', 'rkey': 'whats-alf'} + # r = client.send_post('regress test') + # print(r) + + did_doc = client.com.atproto.repo.describe_repo({'repo': 'did:plc:ze3uieyyns7prike7itbdjiy'}).didDoc + print(did_doc.service) + print(did_doc['service']) + print(did_doc['@context']) + + atproto_feed = client.com.atproto.repo.get_record( + {'collection': ids.AppBskyFeedGenerator, 'repo': 'marshal.dev', 'rkey': 'atproto'} ).value - print(valid) + print(atproto_feed) + print(atproto_feed.createdAt) + print(atproto_feed['createdAt']) + print(type(atproto_feed)) + + exit(0) # client.com.atproto.admin.get_moderation_actions() @@ -158,8 +172,10 @@ def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict def _custom_feed_firehose(): client = FirehoseSubscribeReposClient({'cursor': 93278360}) + client = FirehoseSubscribeReposClient() def on_message_handler(message: 'MessageFrame') -> None: + # return commit = parse_subscribe_repos_message(message) if not isinstance(commit, models.ComAtprotoSyncSubscribeRepos.Commit): return @@ -221,9 +237,9 @@ async def _stop_after_n_sec(): if __name__ == '__main__': - # sync_main() + sync_main() # asyncio.get_event_loop().run_until_complete(main()) # _custom_feed_firehose() - _main_firehose_test() - asyncio.get_event_loop().run_until_complete(_main_async_firehose_test()) + # _main_firehose_test() + # asyncio.get_event_loop().run_until_complete(_main_async_firehose_test())