Skip to content

Commit

Permalink
Add more type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
and-semakin authored and sybrenstuvel committed Jun 3, 2020
1 parent 1473cb8 commit ae1a906
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 26 deletions.
2 changes: 1 addition & 1 deletion rsa/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from struct import pack


def byte(num: int):
def byte(num: int) -> bytes:
"""
Converts a number between 0 and 255 (both inclusive) to a base-256 (byte)
representation.
Expand Down
10 changes: 5 additions & 5 deletions rsa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self) -> None:

@abc.abstractmethod
def perform_operation(self, indata: bytes, key: rsa.key.AbstractKey,
cli_args: Indexable):
cli_args: Indexable) -> typing.Any:
"""Performs the program's operation.
Implement in a subclass.
Expand Down Expand Up @@ -201,7 +201,7 @@ class EncryptOperation(CryptoOperation):
operation_progressive = 'encrypting'

def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey,
cli_args: Indexable = ()):
cli_args: Indexable = ()) -> bytes:
"""Encrypts files."""
assert isinstance(pub_key, rsa.key.PublicKey)
return rsa.encrypt(indata, pub_key)
Expand All @@ -219,7 +219,7 @@ class DecryptOperation(CryptoOperation):
key_class = rsa.PrivateKey

def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey,
cli_args: Indexable = ()):
cli_args: Indexable = ()) -> bytes:
"""Decrypts files."""
assert isinstance(priv_key, rsa.key.PrivateKey)
return rsa.decrypt(indata, priv_key)
Expand All @@ -242,7 +242,7 @@ class SignOperation(CryptoOperation):
'to stdout if this option is not present.')

def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey,
cli_args: Indexable):
cli_args: Indexable) -> bytes:
"""Signs files."""
assert isinstance(priv_key, rsa.key.PrivateKey)

Expand All @@ -269,7 +269,7 @@ class VerifyOperation(CryptoOperation):
has_output = False

def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey,
cli_args: Indexable):
cli_args: Indexable) -> None:
"""Verifies files."""
assert isinstance(pub_key, rsa.key.PublicKey)

Expand Down
2 changes: 1 addition & 1 deletion rsa/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class NotRelativePrimeError(ValueError):
def __init__(self, a, b, d, msg=''):
def __init__(self, a: int, b: int, d: int, msg: str = '') -> None:
super().__init__(msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
self.a = a
self.b = b
Expand Down
2 changes: 1 addition & 1 deletion rsa/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""


def assert_int(var: int, name: str):
def assert_int(var: int, name: str) -> None:
if isinstance(var, int):
return

Expand Down
28 changes: 16 additions & 12 deletions rsa/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _save_pkcs1_der(self) -> bytes:
"""

@classmethod
def load_pkcs1(cls, keyfile: bytes, format='PEM') -> 'AbstractKey':
def load_pkcs1(cls, keyfile: bytes, format: str = 'PEM') -> 'AbstractKey':
"""Loads a key in PKCS#1 DER or PEM format.
:param keyfile: contents of a DER- or PEM-encoded file that contains
Expand Down Expand Up @@ -128,7 +128,7 @@ def _assert_format_exists(file_format: str, methods: typing.Mapping[str, typing.
raise ValueError('Unsupported format: %r, try one of %s' % (file_format,
formats))

def save_pkcs1(self, format='PEM') -> bytes:
def save_pkcs1(self, format: str = 'PEM') -> bytes:
"""Saves the key in PKCS#1 DER or PEM format.
:param format: the format to save; 'PEM' or 'DER'
Expand Down Expand Up @@ -203,7 +203,7 @@ class PublicKey(AbstractKey):

__slots__ = ('n', 'e')

def __getitem__(self, key):
def __getitem__(self, key: str) -> int:
return getattr(self, key)

def __repr__(self) -> str:
Expand Down Expand Up @@ -378,7 +378,7 @@ def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None:
self.exp2 = int(d % (q - 1))
self.coef = rsa.common.inverse(q, p)

def __getitem__(self, key):
def __getitem__(self, key: str) -> int:
return getattr(self, key)

def __repr__(self) -> str:
Expand All @@ -388,7 +388,7 @@ def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]:
"""Returns the key as tuple for pickling."""
return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef

def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]):
def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]) -> None:
"""Sets the key from tuple."""
self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state

Expand Down Expand Up @@ -574,7 +574,9 @@ def _save_pkcs1_pem(self) -> bytes:
return rsa.pem.save_pem(der, b'RSA PRIVATE KEY')


def find_p_q(nbits: int, getprime_func=rsa.prime.getprime, accurate=True) -> typing.Tuple[int, int]:
def find_p_q(nbits: int,
getprime_func: typing.Callable[[int], int] = rsa.prime.getprime,
accurate: bool = True) -> typing.Tuple[int, int]:
"""Returns a tuple of two different primes of nbits bits each.
The resulting p * q has exacty 2 * nbits bits, and the returned p and q
Expand Down Expand Up @@ -619,7 +621,7 @@ def find_p_q(nbits: int, getprime_func=rsa.prime.getprime, accurate=True) -> typ
log.debug('find_p_q(%i): Finding q', nbits)
q = getprime_func(qbits)

def is_acceptable(p, q):
def is_acceptable(p: int, q: int) -> bool:
"""Returns True iff p and q are acceptable:
- p and q differ
Expand Down Expand Up @@ -697,8 +699,8 @@ def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]:

def gen_keys(nbits: int,
getprime_func: typing.Callable[[int], int],
accurate=True,
exponent=DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]:
accurate: bool = True,
exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]:
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
Note: this can take a long time, depending on the key size.
Expand Down Expand Up @@ -726,8 +728,10 @@ def gen_keys(nbits: int,
return p, q, e, d


def newkeys(nbits: int, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT) \
-> typing.Tuple[PublicKey, PrivateKey]:
def newkeys(nbits: int,
accurate: bool = True,
poolsize: int = 1,
exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[PublicKey, PrivateKey]:
"""Generates public and private keys, and returns them as (pub, priv).
The public key is also known as the 'encryption key', and is a
Expand Down Expand Up @@ -763,7 +767,7 @@ def newkeys(nbits: int, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT) \
if poolsize > 1:
from rsa import parallel

def getprime_func(nbits):
def getprime_func(nbits: int) -> int:
return parallel.getprime(nbits, poolsize=poolsize)
else:
getprime_func = rsa.prime.getprime
Expand Down
3 changes: 2 additions & 1 deletion rsa/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
"""

import multiprocessing as mp
from multiprocessing.connection import Connection

import rsa.prime
import rsa.randnum


def _find_prime(nbits: int, pipe) -> None:
def _find_prime(nbits: int, pipe: Connection) -> None:
while True:
integer = rsa.randnum.read_random_odd_int(nbits)

Expand Down
6 changes: 3 additions & 3 deletions rsa/pem.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _pem_lines(contents: bytes, pem_start: bytes, pem_end: bytes) -> typing.Iter
# Handle start marker
if line == pem_start:
if in_pem_part:
raise ValueError('Seen start marker "%s" twice' % pem_start)
raise ValueError('Seen start marker "%r" twice' % pem_start)

in_pem_part = True
seen_pem_start = True
Expand All @@ -72,10 +72,10 @@ def _pem_lines(contents: bytes, pem_start: bytes, pem_end: bytes) -> typing.Iter

# Do some sanity checks
if not seen_pem_start:
raise ValueError('No PEM start marker "%s" found' % pem_start)
raise ValueError('No PEM start marker "%r" found' % pem_start)

if in_pem_part:
raise ValueError('No PEM end marker "%s" found' % pem_end)
raise ValueError('No PEM end marker "%r" found' % pem_end)


def load_pem(contents: FlexiText, pem_marker: FlexiText) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion rsa/pkcs1.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _pad_for_signing(message: bytes, target_length: int) -> bytes:
message])


def encrypt(message: bytes, pub_key: key.PublicKey):
def encrypt(message: bytes, pub_key: key.PublicKey) -> bytes:
"""Encrypts the given message using PKCS#1 v1.5
:param message: the message to encrypt. Must be a byte string no longer than
Expand Down
2 changes: 1 addition & 1 deletion rsa/pkcs1_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)


def mgf1(seed: bytes, length: int, hasher='SHA-1') -> bytes:
def mgf1(seed: bytes, length: int, hasher: str = 'SHA-1') -> bytes:
"""
MGF1 is a Mask Generation Function based on a hash function.
Expand Down

0 comments on commit ae1a906

Please sign in to comment.