Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add as_dict option to Algorithm.to_jwk #881

Merged
merged 16 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 109 additions & 30 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import hashlib
import hmac
import json
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload

from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
Expand All @@ -20,6 +21,12 @@
to_base64url_uint,
)

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


try:
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -184,9 +191,21 @@ def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
for the specified message and key values.
"""

@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj) -> str:
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
"""
Serializes a given key into a JWK
"""
Expand Down Expand Up @@ -221,7 +240,7 @@ def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
return False

@staticmethod
def to_jwk(key_obj: Any) -> NoReturn:
def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -253,14 +272,27 @@ def prepare_key(self, key: str | bytes) -> bytes:

return key_bytes

@overload
@staticmethod
def to_jwk(key_obj: str | bytes) -> str:
return json.dumps(
{
"k": base64url_encode(force_bytes(key_obj)).decode(),
"kty": "oct",
}
)
def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
jwk = {
"k": base64url_encode(force_bytes(key_obj)).decode(),
"kty": "oct",
}

if as_dict:
return jwk
else:
return json.dumps(jwk)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> bytes:
Expand Down Expand Up @@ -320,8 +352,20 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
except ValueError:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))

@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys) -> str:
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
def to_jwk(
key_obj: AllowedRSAKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
obj: dict[str, Any] | None = None

if hasattr(key_obj, "private_numbers"):
Expand Down Expand Up @@ -354,7 +398,10 @@ def to_jwk(key_obj: AllowedRSAKeys) -> str:
else:
raise InvalidKeyError("Not a public or private key")

return json.dumps(obj)
if as_dict:
return obj
else:
return json.dumps(obj)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
Expand Down Expand Up @@ -503,8 +550,20 @@ def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
except InvalidSignature:
return False

@overload
@staticmethod
def to_jwk(key_obj: AllowedECKeys) -> str:
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
def to_jwk(
key_obj: AllowedECKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
Expand Down Expand Up @@ -535,7 +594,10 @@ def to_jwk(key_obj: AllowedECKeys) -> str:
key_obj.private_numbers().private_value
).decode()

return json.dumps(obj)
if as_dict:
return obj
else:
return json.dumps(obj)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
Expand Down Expand Up @@ -707,21 +769,35 @@ def verify(
except InvalidSignature:
return False

@overload
@staticmethod
def to_jwk(key: AllowedOKPKeys) -> str:
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes(
encoding=Encoding.Raw,
format=PublicFormat.Raw,
)
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
return json.dumps(
{
"x": base64url_encode(force_bytes(x)).decode(),
"kty": "OKP",
"crv": crv,
}
)

obj = {
"x": base64url_encode(force_bytes(x)).decode(),
"kty": "OKP",
"crv": crv,
}

if as_dict:
return obj
else:
return json.dumps(obj)

if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes(
Expand All @@ -736,14 +812,17 @@ def to_jwk(key: AllowedOKPKeys) -> str:
)

crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
return json.dumps(
{
"x": base64url_encode(force_bytes(x)).decode(),
"d": base64url_encode(force_bytes(d)).decode(),
"kty": "OKP",
"crv": crv,
}
)
obj = {
"x": base64url_encode(force_bytes(x)).decode(),
"d": base64url_encode(force_bytes(d)).decode(),
"kty": "OKP",
"crv": crv,
}

if as_dict:
return obj
else:
return json.dumps(obj)

raise InvalidKeyError("Not a public or private key")

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ source = ["jwt", ".tox/*/site-packages"]

[tool.coverage.report]
show_missing = true
exclude_lines = ["if TYPE_CHECKING:"]
exclude_lines = ["if TYPE_CHECKING:", "pragma: no cover"]

[tool.isort]
profile = "black"
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ zip_safe = false
include_package_data = true
python_requires = >=3.7
packages = find:
install_requires =
typing_extensions; python_version<="3.7"

[options.package_data]
* = py.typed
Expand Down
Loading