diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 7cecfbf7..f19b10ac 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -1,6 +1,7 @@ import json import urllib.request from functools import lru_cache +from ssl import SSLContext from typing import Any, Dict, List, Optional from urllib.error import URLError @@ -20,6 +21,7 @@ def __init__( lifespan: int = 300, headers: Optional[Dict[str, Any]] = None, timeout: int = 30, + ssl_context: Optional[SSLContext] = None, ): if headers is None: headers = {} @@ -27,6 +29,7 @@ def __init__( self.jwk_set_cache: Optional[JWKSetCache] = None self.headers = headers self.timeout = timeout + self.ssl_context = ssl_context if cache_jwk_set: # Init jwt set cache with default or given lifespan. @@ -48,7 +51,9 @@ def fetch_data(self) -> Any: jwk_set: Any = None try: r = urllib.request.Request(url=self.uri, headers=self.headers) - with urllib.request.urlopen(r, timeout=self.timeout) as response: + with urllib.request.urlopen( + r, timeout=self.timeout, context=self.ssl_context + ) as response: jwk_set = json.load(response) except (URLError, TimeoutError) as e: raise PyJWKClientConnectionError( diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 1122af86..c3951eaa 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -1,5 +1,6 @@ import contextlib import json +import ssl import time from unittest import mock from urllib.error import URLError @@ -335,3 +336,22 @@ def test_get_jwt_set_timeout(self): jwks_client.get_jwk_set() assert 'Fail to fetch data from the url, err: "timed out"' in str(exc.value) + + def test_get_jwt_set_sslcontext_default(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url, ssl_context=ssl.create_default_context()) + + jwk_set = jwks_client.get_jwk_set() + + assert jwk_set is not None + + def test_get_jwt_set_sslcontext_no_ca(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient( + url, ssl_context=ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ) + + with pytest.raises(PyJWKClientError): + jwks_client.get_jwk_set() + + assert "Failed to get an expected error"