From 111c292b83f0e7e8b0899b1db49167b7bd23a6fb Mon Sep 17 00:00:00 2001 From: Ilya Siamionau Date: Thu, 25 May 2023 00:13:45 +0200 Subject: [PATCH] Fix request error handling (#41) --- atproto/exceptions.py | 22 ++++++++++++++++------ atproto/xrpc_client/request.py | 2 +- test.py | 20 ++++++++++++-------- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/atproto/exceptions.py b/atproto/exceptions.py index 66ba418f..f0478633 100644 --- a/atproto/exceptions.py +++ b/atproto/exceptions.py @@ -1,3 +1,9 @@ +import typing as t + +if t.TYPE_CHECKING: + from atproto.xrpc_client.request import Response + + class AtProtocolError(Exception): """Base exception""" @@ -42,7 +48,12 @@ class ModelFieldNotFoundError(ModelError): ... -class NetworkError(AtProtocolError): +class RequestErrorBase(AtProtocolError): + def __init__(self, response: t.Optional['Response'] = None): + self.response: 'Response' = response + + +class NetworkError(RequestErrorBase): ... @@ -50,16 +61,15 @@ class InvokeTimeoutError(NetworkError): ... -class UnauthorizedError(AtProtocolError): - def __init__(self, response): - self.response = response +class UnauthorizedError(RequestErrorBase): + ... -class RequestException(AtProtocolError): +class RequestException(RequestErrorBase): ... -class BadRequestError(RequestException): +class BadRequestError(RequestErrorBase): ... diff --git a/atproto/xrpc_client/request.py b/atproto/xrpc_client/request.py index 95a80a1f..88122642 100644 --- a/atproto/xrpc_client/request.py +++ b/atproto/xrpc_client/request.py @@ -56,7 +56,7 @@ def _handle_response(response: httpx.Response) -> httpx.Response: if response.status_code in {401, 403}: raise exceptions.UnauthorizedError(error_response) - elif response.status_code == 404: + elif response.status_code == 400: raise exceptions.BadRequestError(error_response) elif response.status_code in {409, 413, 502}: raise exceptions.NetworkError(error_response) diff --git a/test.py b/test.py index 5c49241c..149ef4f1 100644 --- a/test.py +++ b/test.py @@ -3,7 +3,7 @@ import os import threading -from atproto import CAR, AsyncClient, AtUri, Client, models +from atproto import CAR, AsyncClient, AtUri, Client, exceptions, models from atproto.firehose import ( AsyncFirehoseSubscribeLabelsClient, AsyncFirehoseSubscribeReposClient, @@ -40,7 +40,7 @@ def sync_main(): client.login(os.environ['USERNAME'], os.environ['PASSWORD']) # repo = client.com.atproto.sync.get_repo({'did': client.me.did}) - did = client.com.atproto.identity.resolve_handle({'handle': 'bsky.app'}).did + # did = client.com.atproto.identity.resolve_handle({'handle': 'bsky.app'}).did # repo = client.com.atproto.sync.get_repo({'did': did}) # car_file = CAR.from_bytes(repo) # print(car_file.root) @@ -62,10 +62,14 @@ def sync_main(): # client.com.atproto.repo.get_record({'collection': 'app.bsky.feed.post', 'repo': 'arta.bsky.social'}) - # with open('cat2.jpg', 'rb') as f: - # cat_data = f.read() - # - # client.send_image('Cat looking for a Python', cat_data, 'cat alt') + with open('cat_big.png', 'rb') as f: + cat_data = f.read() + try: + client.send_image('Cat looking for a Python', cat_data, 'cat alt') + except exceptions.BadRequestError as e: + print('Status code:', e.response.status_code) + print('Error code:', e.response.content.error) + print('Error message:', e.response.content.message) # # resolve = client.com.atproto.identity.resolve_handle(models.ComAtprotoIdentityResolveHandle.Params(profile.handle)) # assert resolve.did == profile.did @@ -180,8 +184,8 @@ async def _stop_after_n_sec(): if __name__ == '__main__': # test_strange_embed_images_type() - # sync_main() + sync_main() # asyncio.get_event_loop().run_until_complete(main()) # _main_firehose_test() - asyncio.get_event_loop().run_until_complete(_main_async_firehose_test()) + # asyncio.get_event_loop().run_until_complete(_main_async_firehose_test())