Skip to content

Commit

Permalink
Fix unknown type that could be plain dictionary (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX committed Jul 21, 2023
1 parent 7489721 commit 50f67c4
Show file tree
Hide file tree
Showing 15 changed files with 108 additions and 40 deletions.
24 changes: 18 additions & 6 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _get_model_imports() -> str:
'import typing_extensions as te',
'from atproto.xrpc_client import models',
'from atproto.xrpc_client.models import base',
'from atproto.xrpc_client.models import unknown_type',
'from atproto.xrpc_client.models.blob_ref import BlobRef',
'',
'from atproto import CID',
Expand Down Expand Up @@ -162,8 +163,8 @@ def _get_model_field_typehint(nsid: NSID, field_name: str, field_type_def, *, op
field_type = type(field_type_def)

if field_type == models.LexUnknown:
# TODO(MarshalX): some of "unknown" types are well known...
return _get_optional_typehint("'base.RecordModelBase'", optional=optional)
# unknown type is a generic response with records or any not described type in the lexicon. for example didDoc
return _get_optional_typehint("'base.UnknownDict'", optional=optional)

type_hint = _LEXICON_TYPE_TO_PRIMITIVE_TYPEHINT.get(field_type)
if type_hint:
Expand Down Expand Up @@ -425,7 +426,15 @@ def _generate_record_models(lex_db: builder.BuiltRecordModels) -> None:


def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None:
lines = ['from atproto.xrpc_client import models', 'RECORD_TYPE_TO_MODEL_CLASS = {']
type_conversion_lines = ['from atproto.xrpc_client import models', 'RECORD_TYPE_TO_MODEL_CLASS = {']
unknown_record_type_hint_lines = [
'import typing as t',
'import typing_extensions as te',
'if t.TYPE_CHECKING:',
f'{_(4)}from atproto.xrpc_client import models',
'',
'UnknownRecordType: te.TypeAlias = t.Union[',
]

for nsid, defs in lex_db.items():
_save_code_import_if_not_exist(nsid)
Expand All @@ -439,11 +448,14 @@ def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None:

path_to_class = f'models.{get_import_path(nsid)}.{class_name}'

lines.append(f"'{record_type}': {path_to_class},")
type_conversion_lines.append(f"'{record_type}': {path_to_class},")
unknown_record_type_hint_lines.append(f"{_(4)}'{path_to_class}',")

lines.append('}')
type_conversion_lines.append('}')
unknown_record_type_hint_lines.append(']')

write_code(_MODELS_OUTPUT_DIR.joinpath('type_conversion.py'), join_code(lines))
write_code(_MODELS_OUTPUT_DIR.joinpath('type_conversion.py'), join_code(type_conversion_lines))
write_code(_MODELS_OUTPUT_DIR.joinpath('unknown_type.py'), join_code(unknown_record_type_hint_lines))


def _generate_ref_models(lex_db: builder.BuiltRefsModels) -> None:
Expand Down
2 changes: 1 addition & 1 deletion atproto/xrpc_client/models/app/bsky/embed/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ViewRecord(base.ModelBase):
cid: str #: Cid.
indexedAt: str #: Indexed at.
uri: str #: Uri.
value: 'base.RecordModelBase' #: Value.
value: 'base.UnknownDict' #: Value.
embeds: t.Optional[
t.List[
t.Union[
Expand Down
2 changes: 1 addition & 1 deletion atproto/xrpc_client/models/app/bsky/feed/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class PostView(base.ModelBase):
author: 'models.AppBskyActorDefs.ProfileViewBasic' #: Author.
cid: str #: Cid.
indexedAt: str #: Indexed at.
record: 'base.RecordModelBase' #: Record.
record: 'base.UnknownDict' #: Record.
uri: str #: Uri.
embed: t.Optional[
t.Union[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Notification(base.ModelBase):
indexedAt: str #: Indexed at.
isRead: bool #: Is read.
reason: str #: Expected values are 'like', 'repost', 'follow', 'mention', 'reply', and 'quote'.
record: 'base.RecordModelBase' #: Record.
record: 'base.UnknownDict' #: Record.
uri: str #: Uri.
labels: t.Optional[t.List['models.ComAtprotoLabelDefs.Label']] = None #: Labels.
reasonSubject: t.Optional[str] = None #: Reason subject.
Expand Down
35 changes: 27 additions & 8 deletions atproto/xrpc_client/models/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from dataclasses import dataclass

from atproto.exceptions import ModelFieldNotFoundError


@dataclass
class ModelBase:
def __getitem__(self, item: str):
if hasattr(self, item):
Expand All @@ -12,21 +9,43 @@ def __getitem__(self, item: str):
raise ModelFieldNotFoundError(f"Can't find field '{item}' in the object of type {type(self)}.")


@dataclass
class ParamsModelBase(ModelBase):
pass


@dataclass
class DataModelBase(ModelBase):
pass


@dataclass
class ResponseModelBase(ModelBase):
pass


@dataclass
class RecordModelBase(ModelBase):
class UnknownDict(ModelBase):
pass


class DotDict(UnknownDict):
def __init__(self, data: dict):
self._data = data

def __getattr__(self, item):
return self._data.get(item)

def __setattr__(self, key, value):
if key == '_data':
super().__setattr__(key, value)
return

self._data[key] = value

def __delattr__(self, item):
del self._data[item]


class UnknownRecord(UnknownDict):
pass


class RecordModelBase(UnknownRecord):
pass
8 changes: 4 additions & 4 deletions atproto/xrpc_client/models/com/atproto/admin/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class RepoView(base.ModelBase):
handle: str #: Handle.
indexedAt: str #: Indexed at.
moderation: 'models.ComAtprotoAdminDefs.Moderation' #: Moderation.
relatedRecords: t.List['base.RecordModelBase'] #: Related records.
relatedRecords: t.List['base.UnknownDict'] #: Related records.
email: t.Optional[str] = None #: Email.
invitedBy: t.Optional['models.ComAtprotoServerDefs.InviteCode'] = None #: Invited by.
invitesDisabled: t.Optional[bool] = None #: Invites disabled.
Expand All @@ -163,7 +163,7 @@ class RepoViewDetail(base.ModelBase):
handle: str #: Handle.
indexedAt: str #: Indexed at.
moderation: 'models.ComAtprotoAdminDefs.ModerationDetail' #: Moderation.
relatedRecords: t.List['base.RecordModelBase'] #: Related records.
relatedRecords: t.List['base.UnknownDict'] #: Related records.
email: t.Optional[str] = None #: Email.
invitedBy: t.Optional['models.ComAtprotoServerDefs.InviteCode'] = None #: Invited by.
invites: t.Optional[t.List['models.ComAtprotoServerDefs.InviteCode']] = None #: Invites.
Expand Down Expand Up @@ -204,7 +204,7 @@ class RecordView(base.ModelBase):
moderation: 'models.ComAtprotoAdminDefs.Moderation' #: Moderation.
repo: 'models.ComAtprotoAdminDefs.RepoView' #: Repo.
uri: str #: Uri.
value: 'base.RecordModelBase' #: Value.
value: 'base.UnknownDict' #: Value.

_type: str = 'com.atproto.admin.defs#recordView'

Expand All @@ -220,7 +220,7 @@ class RecordViewDetail(base.ModelBase):
moderation: 'models.ComAtprotoAdminDefs.ModerationDetail' #: Moderation.
repo: 'models.ComAtprotoAdminDefs.RepoView' #: Repo.
uri: str #: Uri.
value: 'base.RecordModelBase' #: Value.
value: 'base.UnknownDict' #: Value.
labels: t.Optional[t.List['models.ComAtprotoLabelDefs.Label']] = None #: Labels.

_type: str = 'com.atproto.admin.defs#recordViewDetail'
Expand Down
4 changes: 2 additions & 2 deletions atproto/xrpc_client/models/com/atproto/repo/apply_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Create(base.ModelBase):
"""Definition model for :obj:`com.atproto.repo.applyWrites`. Create a new record."""

collection: str #: Collection.
value: 'base.RecordModelBase' #: Value.
value: 'base.UnknownDict' #: Value.
rkey: t.Optional[str] = None #: Rkey.

_type: str = 'com.atproto.repo.applyWrites#create'
Expand All @@ -49,7 +49,7 @@ class Update(base.ModelBase):

collection: str #: Collection.
rkey: str #: Rkey.
value: 'base.RecordModelBase' #: Value.
value: 'base.UnknownDict' #: Value.

_type: str = 'com.atproto.repo.applyWrites#update'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Data(base.DataModelBase):
"""Input data model for :obj:`com.atproto.repo.createRecord`."""

collection: str #: The NSID of the record collection.
record: 'base.RecordModelBase' #: The record to create.
record: 'base.UnknownDict' #: The record to create.
repo: str #: The handle or DID of the repo.
rkey: t.Optional[str] = None #: The key of the record.
swapCommit: t.Optional[str] = None #: Compare and swap with the previous commit by cid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ class Response(base.ResponseModelBase):

collections: t.List[str] #: Collections.
did: str #: Did.
didDoc: 'base.RecordModelBase' #: Did doc.
didDoc: 'base.UnknownDict' #: Did doc.
handle: str #: Handle.
handleIsCorrect: bool #: Handle is correct.
2 changes: 1 addition & 1 deletion atproto/xrpc_client/models/com/atproto/repo/get_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ class Response(base.ResponseModelBase):
"""Output data model for :obj:`com.atproto.repo.getRecord`."""

uri: str #: Uri.
value: 'base.RecordModelBase' #: Value.
value: 'base.UnknownDict' #: Value.
cid: t.Optional[str] = None #: Cid.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ class Record(base.ModelBase):

cid: str #: Cid.
uri: str #: Uri.
value: 'base.RecordModelBase' #: Value.
value: 'base.UnknownDict' #: Value.

_type: str = 'com.atproto.repo.listRecords#record'
2 changes: 1 addition & 1 deletion atproto/xrpc_client/models/com/atproto/repo/put_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Data(base.DataModelBase):
"""Input data model for :obj:`com.atproto.repo.putRecord`."""

collection: str #: The NSID of the record collection.
record: 'base.RecordModelBase' #: The record to write.
record: 'base.UnknownDict' #: The record to write.
repo: str #: The handle or DID of the repo.
rkey: str #: The key of the record.
swapCommit: t.Optional[str] = None #: Compare and swap with the previous commit by cid.
Expand Down
18 changes: 18 additions & 0 deletions atproto/xrpc_client/models/unknown_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import typing as t

import typing_extensions as te

if t.TYPE_CHECKING:
from atproto.xrpc_client import models

UnknownRecordType: te.TypeAlias = t.Union[
'models.AppBskyFeedGenerator.Main',
'models.AppBskyActorProfile.Main',
'models.AppBskyFeedRepost.Main',
'models.AppBskyGraphListitem.Main',
'models.AppBskyFeedLike.Main',
'models.AppBskyGraphFollow.Main',
'models.AppBskyGraphList.Main',
'models.AppBskyGraphBlock.Main',
'models.AppBskyFeedPost.Main',
]
15 changes: 9 additions & 6 deletions atproto/xrpc_client/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
UnexpectedFieldError,
WrongTypeError,
)
from atproto.xrpc_client.models.base import ModelBase, RecordModelBase
from atproto.xrpc_client.models.base import DotDict, ModelBase, UnknownDict
from atproto.xrpc_client.models.blob_ref import BlobRef
from atproto.xrpc_client.models.type_conversion import RECORD_TYPE_TO_MODEL_CLASS
from atproto.xrpc_client.models.unknown_type import UnknownRecordType

if t.TYPE_CHECKING:
from atproto.xrpc_client.request import Response
Expand All @@ -25,10 +26,12 @@
ModelData: te.TypeAlias = t.Union[M, dict, None]


def _record_model_type_hook(data: dict) -> RecordModelBase:
# used for inner Record types
record_type = data.pop('$type')
return get_or_create_model(data, RECORD_TYPE_TO_MODEL_CLASS[record_type])
def _unknown_type_hook(data: dict) -> t.Union[UnknownRecordType, DotDict]:
if '$type' in data:
# $type used for inner Record types
return get_or_create_model(data, RECORD_TYPE_TO_MODEL_CLASS[data.pop('$type')])
# any another unknown (not described by lexicon) type
return DotDict(data)


def _decode_cid_hook(ref: t.Union[CID, str]) -> CID:
Expand All @@ -41,7 +44,7 @@ def _decode_cid_hook(ref: t.Union[CID, str]) -> CID:
_TYPE_HOOKS = {
BlobRef: lambda ref: BlobRef.from_dict(ref),
CID: _decode_cid_hook,
RecordModelBase: _record_model_type_hook,
UnknownDict: _unknown_type_hook,
}
_DACITE_CONFIG = Config(cast=[Enum], type_hooks=_TYPE_HOOKS)

Expand Down
28 changes: 22 additions & 6 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import threading
import typing as t
from datetime import datetime

from atproto import CAR, AsyncClient, AtUri, Client, exceptions, models
from atproto.firehose import (
Expand Down Expand Up @@ -45,10 +46,23 @@ def sync_main():
client = Client()
client.login(os.environ['USERNAME'], os.environ['PASSWORD'])

valid = client.com.atproto.repo.get_record(
{'collection': ids.AppBskyFeedGenerator, 'repo': 'test.marshal.dev', 'rkey': 'whats-alf'}
# r = client.send_post('regress test')
# print(r)

did_doc = client.com.atproto.repo.describe_repo({'repo': 'did:plc:ze3uieyyns7prike7itbdjiy'}).didDoc
print(did_doc.service)
print(did_doc['service'])
print(did_doc['@context'])

atproto_feed = client.com.atproto.repo.get_record(
{'collection': ids.AppBskyFeedGenerator, 'repo': 'marshal.dev', 'rkey': 'atproto'}
).value
print(valid)
print(atproto_feed)
print(atproto_feed.createdAt)
print(atproto_feed['createdAt'])
print(type(atproto_feed))

exit(0)

# client.com.atproto.admin.get_moderation_actions()

Expand Down Expand Up @@ -158,8 +172,10 @@ def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict

def _custom_feed_firehose():
client = FirehoseSubscribeReposClient({'cursor': 93278360})
client = FirehoseSubscribeReposClient()

def on_message_handler(message: 'MessageFrame') -> None:
# return
commit = parse_subscribe_repos_message(message)
if not isinstance(commit, models.ComAtprotoSyncSubscribeRepos.Commit):
return
Expand Down Expand Up @@ -221,9 +237,9 @@ async def _stop_after_n_sec():


if __name__ == '__main__':
# sync_main()
sync_main()
# asyncio.get_event_loop().run_until_complete(main())

# _custom_feed_firehose()
_main_firehose_test()
asyncio.get_event_loop().run_until_complete(_main_async_firehose_test())
# _main_firehose_test()
# asyncio.get_event_loop().run_until_complete(_main_async_firehose_test())

0 comments on commit 50f67c4

Please sign in to comment.