Skip to content

Commit

Permalink
Merge 8be4ef3 into 55dd3a4
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Mar 9, 2021
2 parents 55dd3a4 + 8be4ef3 commit 4eeeedc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))


- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))


## [1.2.2] - 2021-03-02

### Added
Expand Down
4 changes: 2 additions & 2 deletions pl_examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
_DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets')

_TORCHVISION_AVAILABLE = _module_available("torchvision")
_TORCHVISION_MNIST_AVAILABLE = True
_TORCHVISION_MNIST_AVAILABLE = _TORCHVISION_AVAILABLE
_DALI_AVAILABLE = _module_available("nvidia.dali")

if _TORCHVISION_AVAILABLE:
if _TORCHVISION_MNIST_AVAILABLE:
try:
from torchvision.datasets.mnist import MNIST
MNIST(_DATASETS_PATH, download=True)
Expand Down
14 changes: 12 additions & 2 deletions pytorch_lightning/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""
import importlib
import operator
import platform
import sys
from distutils.version import LooseVersion
from importlib.util import find_spec

import torch
from pkg_resources import DistributionNotFound, get_distribution
from pkg_resources import DistributionNotFound


def _module_available(module_path: str) -> bool:
Expand All @@ -42,8 +43,17 @@ def _module_available(module_path: str) -> bool:


def _compare_version(package: str, op, version) -> bool:
"""Compare package version with some requirements
>>> _compare_version("torch", operator.ge, "0.1")
True
"""
if not _module_available(package):
return False
try:
pkg_version = LooseVersion(get_distribution(package).version)
pkg = importlib.import_module(package)
assert hasattr(pkg, '__version__')
pkg_version = pkg.__version__
return op(pkg_version, LooseVersion(version))
except DistributionNotFound:
return False
Expand Down

0 comments on commit 4eeeedc

Please sign in to comment.