diff --git a/CHANGELOG.md b/CHANGELOG.md index d1c347c00a3f1f..7732e8ec1564f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -137,6 +137,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380)) +- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) + + ## [1.2.2] - 2021-03-02 ### Added diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index ffd60f9ed71af4..e3c5a78124137d 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -15,10 +15,10 @@ _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') _TORCHVISION_AVAILABLE = _module_available("torchvision") -_TORCHVISION_MNIST_AVAILABLE = True +_TORCHVISION_MNIST_AVAILABLE = _TORCHVISION_AVAILABLE _DALI_AVAILABLE = _module_available("nvidia.dali") -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_MNIST_AVAILABLE: try: from torchvision.datasets.mnist import MNIST MNIST(_DATASETS_PATH, download=True) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 41a13d6c678a0d..7bb6f51b1195ee 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities""" +import importlib import operator import platform import sys @@ -19,7 +20,7 @@ from importlib.util import find_spec import torch -from pkg_resources import DistributionNotFound, get_distribution +from pkg_resources import DistributionNotFound def _module_available(module_path: str) -> bool: @@ -42,8 +43,17 @@ def _module_available(module_path: str) -> bool: def _compare_version(package: str, op, version) -> bool: + """Compare package version with some requirements + + >>> _compare_version("torch", operator.ge, "0.1") + True + """ + if not _module_available(package): + return False try: - pkg_version = LooseVersion(get_distribution(package).version) + pkg = importlib.import_module(package) + assert hasattr(pkg, '__version__') + pkg_version = pkg.__version__ return op(pkg_version, LooseVersion(version)) except DistributionNotFound: return False