Skip to content

Commit

Permalink
Switch from dataclasses to attrs
Browse files Browse the repository at this point in the history
In order to work around <python/mypy#5374>
  • Loading branch information
jwodder committed Feb 27, 2022
1 parent 80dfd88 commit 254a8aa
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 32 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ include_package_data = True
python_requires = ~=3.8
install_requires =
anyio ~= 3.5
attrs >= 20.1.0
click >= 7.0
click-loglevel ~= 0.3
flatbencode >= 0.2
Expand Down
13 changes: 7 additions & 6 deletions src/demagnetize/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from functools import partial
from random import randint
from time import time
from typing import Awaitable, Callable, List, Union
import attr
from torf import Magnet, Torrent
from yarl import URL
from .consts import CLIENT
Expand All @@ -22,11 +22,11 @@
)


@dataclass
@attr.define
class Demagnetizer:
key: Key = field(default_factory=Key.generate)
peer_id: bytes = field(default_factory=make_peer_id)
peer_port: int = field(default_factory=lambda: randint(1025, 65535))
key: Key = attr.Factory(Key.generate)
peer_id: bytes = attr.Factory(make_peer_id)
peer_port: int = attr.Factory(lambda: randint(1025, 65535))

def __post_init__(self) -> None:
log.log(TRACE, "Using key = %s", self.key)
Expand Down Expand Up @@ -88,7 +88,8 @@ def get_tracker(self, url: str) -> Tracker:
return HTTPTracker(app=self, url=u)
elif u.scheme == "udp":
try:
return UDPTracker(app=self, url=u)
# <https://github.com/python/mypy/issues/12259>
return UDPTracker(app=self, url=u) # type: ignore[call-arg]
except ValueError as e:
raise TrackerError(f"Invalid tracker URL: {e}")
else:
Expand Down
6 changes: 3 additions & 3 deletions src/demagnetize/peers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Optional
from anyio import connect_tcp
from anyio.abc import AsyncResource, SocketStream
import attr
from .util import InfoHash, log


@dataclass
@attr.define
class Peer:
host: str
port: int
Expand Down Expand Up @@ -42,7 +42,7 @@ async def get_info(self, info_hash: InfoHash) -> dict:
return await connpeer.get_info(info_hash)


@dataclass
@attr.define
class ConnectedPeer(AsyncResource):
conn: SocketStream

Expand Down
6 changes: 3 additions & 3 deletions src/demagnetize/session.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from anyio import EndOfStream, create_memory_object_stream, create_task_group
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import attr
from torf import Magnet
from .errors import DemagnetizeFailure, Error
from .peers import Peer
Expand All @@ -13,11 +13,11 @@
from .core import Demagnetizer


@dataclass
@attr.define
class TorrentSession:
app: Demagnetizer
magnet: Magnet
info_hash: InfoHash = field(init=False)
info_hash: InfoHash = attr.field(init=False)

def __post_init__(self) -> None:
# torf only accepts magnet URLs with valid info hashes, so this
Expand Down
4 changes: 2 additions & 2 deletions src/demagnetize/trackers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from ipaddress import AddressValueError, IPv4Address, IPv6Address
import struct
from typing import TYPE_CHECKING, List
import attr
from yarl import URL
from ..peers import Peer
from ..util import InfoHash
Expand All @@ -12,7 +12,7 @@
from ..core import Demagnetizer


@dataclass
@attr.define
class Tracker(ABC):
app: Demagnetizer
url: URL
Expand Down
7 changes: 3 additions & 4 deletions src/demagnetize/trackers/http.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, List, Optional, Type, TypeVar, cast
from urllib.parse import quote
import attr
from flatbencode import decode
from httpx import AsyncClient, HTTPError
from .base import Tracker, unpack_peers, unpack_peers6
Expand All @@ -13,7 +13,6 @@
T = TypeVar("T")


@dataclass
class HTTPTracker(Tracker):
async def get_peers(self, info_hash: InfoHash) -> List[Peer]:
log.info("Requesting peers for %s from %s", info_hash, self)
Expand Down Expand Up @@ -63,7 +62,7 @@ async def get_peers(self, info_hash: InfoHash) -> List[Peer]:
return response.peers


@dataclass
@attr.define
class Response:
failure_reason: Optional[str] = None
warning_message: Optional[str] = None
Expand All @@ -72,7 +71,7 @@ class Response:
tracker_id: Optional[bytes] = None
complete: Optional[int] = None
incomplete: Optional[int] = None
peers: List[Peer] = field(default_factory=list)
peers: List[Peer] = attr.Factory(list)

@classmethod
def parse(cls, content: bytes) -> Response:
Expand Down
18 changes: 9 additions & 9 deletions src/demagnetize/trackers/udp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# <https://www.bittorrent.org/beps/bep_0015.html>
from __future__ import annotations
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import partial
from random import randint
from socket import AF_INET6
Expand All @@ -10,6 +9,7 @@
from typing import Any, Callable, ContextManager, List, Optional, TypeVar
from anyio import create_connected_udp_socket, fail_after
from anyio.abc import AsyncResource, ConnectedUDPSocket, SocketAttribute
import attr
from .base import Tracker, unpack_peers, unpack_peers6
from ..consts import LEFT, NUMWANT
from ..errors import TrackerError
Expand All @@ -21,12 +21,12 @@
PROTOCOL_ID = 0x41727101980


@dataclass
@attr.define
class UDPTracker(Tracker):
host: str = field(init=False)
port: int = field(init=False)
host: str = attr.field(init=False)
port: int = attr.field(init=False)

def __post_init__(self) -> None:
def __attrs_post_init__(self) -> None:
if self.url.scheme != "udp":
raise ValueError("URL scheme must be 'udp'")
if self.url.host is None:
Expand All @@ -47,7 +47,7 @@ async def get_peers(self, info_hash: InfoHash) -> List[Peer]:
)


@dataclass
@attr.define
class Communicator(AsyncResource):
tracker: UDPTracker
conn: ConnectedUDPSocket
Expand Down Expand Up @@ -134,11 +134,11 @@ async def connect(self) -> Connection:
return Connection(communicator=self, id=conn_id)


@dataclass
@attr.define
class Connection:
communicator: Communicator
id: int
expiration: float = field(init=False)
expiration: float = attr.field(init=False)

def __post_init__(self) -> None:
self.expiration = time() + 60
Expand All @@ -164,7 +164,7 @@ async def announce(self, info_hash: InfoHash) -> AnnounceResponse:
)


@dataclass
@attr.define
class AnnounceResponse:
interval: int
leechers: int
Expand Down
10 changes: 5 additions & 5 deletions src/demagnetize/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from base64 import b32decode
from binascii import unhexlify
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
import logging
from random import choices, randrange
import re
Expand All @@ -20,6 +19,7 @@
)
from anyio import create_memory_object_stream, create_task_group
from anyio.streams.memory import MemoryObjectSendStream
import attr
from torf import Magnet, Torrent
from .consts import PEER_ID_PREFIX

Expand All @@ -30,7 +30,7 @@
T = TypeVar("T")


@dataclass
@attr.define
class InfoHash:
as_str: str
as_bytes: bytes
Expand All @@ -52,7 +52,7 @@ def __bytes__(self) -> bytes:
return self.as_bytes


@dataclass
@attr.define
class Key:
value: int

Expand All @@ -70,11 +70,11 @@ def __bytes__(self) -> bytes:
return self.value.to_bytes(4, "big")


@dataclass
@attr.define
class Report:
#: Collection of magnet URLs and the files their torrents were saved to
#: (None if the demagnetization failed)
downloads: List[Tuple[Magnet, Optional[str]]] = field(default_factory=list)
downloads: List[Tuple[Magnet, Optional[str]]] = attr.Factory(list)

@classmethod
def for_success(cls, magnet: Magnet, filename: str) -> Report:
Expand Down

0 comments on commit 254a8aa

Please sign in to comment.