diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index f08c277b71064..45540ca12fa5f 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -143,10 +143,9 @@ jobs: # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # todo: put this back just when TorchVision can download datasets - #- name: Examples - # run: | - # python -m pytest pl_examples -v --durations=10 + - name: Examples + run: | + python -m pytest pl_examples -v --durations=10 - name: Upload pytest test results uses: actions/upload-artifact@v2 diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 6d67afc31f2e4..48db9ede12400 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -100,10 +100,12 @@ jobs: python -m pytest benchmarks -v --maxfail=2 --durations=0 displayName: 'Testing: benchmarks' - # todo: put this back just when TorchVision can download datasets - #- bash: | - # python -m pytest pl_examples -v --maxfail=2 --durations=0 - # python setup.py install --user --quiet - # bash pl_examples/run_ddp-example.sh - # pip uninstall -y pytorch-lightning - # displayName: 'Examples' + - bash: | + python -m pytest pl_examples -v --maxfail=2 --durations=0 + python setup.py install --user --quiet + bash pl_examples/run_ddp-example.sh + cd pl_examples/basic_examples + bash submit_ddp_job.sh + bash submit_ddp2_job.sh + pip uninstall -y pytorch-lightning + displayName: 'Examples' diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index 6ad0a4dfc0624..ffd60f9ed71af 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -1,14 +1,30 @@ import os +from urllib.error import HTTPError + +from six.moves import urllib from pytorch_lightning.utilities import _module_available +# TorchVision hotfix https://github.com/pytorch/vision/issues/1938 +opener = urllib.request.build_opener() +opener.addheaders = [('User-agent', 'Mozilla/5.0')] +urllib.request.install_opener(opener) + _EXAMPLES_ROOT = os.path.dirname(__file__) _PACKAGE_ROOT = os.path.dirname(_EXAMPLES_ROOT) _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') _TORCHVISION_AVAILABLE = _module_available("torchvision") +_TORCHVISION_MNIST_AVAILABLE = True _DALI_AVAILABLE = _module_available("nvidia.dali") +if _TORCHVISION_AVAILABLE: + try: + from torchvision.datasets.mnist import MNIST + MNIST(_DATASETS_PATH, download=True) + except HTTPError: + _TORCHVISION_MNIST_AVAILABLE = False + LIGHTNING_LOGO = """ #### ########### diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index c60c4faec4acd..b3188a21b7f04 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -20,9 +20,9 @@ from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: from torchvision import transforms from torchvision.datasets.mnist import MNIST else: diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index ad50da18ff3fd..01a5dca0de3c7 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -19,9 +19,9 @@ from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: from torchvision import transforms from torchvision.datasets.mnist import MNIST else: diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index d6e64d2b3de14..b4bf1407a9b26 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -23,9 +23,15 @@ from torch.utils.data import random_split import pytorch_lightning as pl -from pl_examples import _DALI_AVAILABLE, _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo - -if _TORCHVISION_AVAILABLE: +from pl_examples import ( + _DALI_AVAILABLE, + _DATASETS_PATH, + _TORCHVISION_AVAILABLE, + _TORCHVISION_MNIST_AVAILABLE, + cli_lightning_logo, +) + +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: from torchvision import transforms from torchvision.datasets.mnist import MNIST else: diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index 46acc5a3a2a14..a50f67cdab301 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -17,10 +17,10 @@ from torch.utils.data import DataLoader, random_split -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE from pytorch_lightning import LightningDataModule -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST else: diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 2aa2a9f73db8b..285fba8b93f1b 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -26,15 +26,19 @@ import torch import torch.nn as nn import torch.nn.functional as F # noqa -import torchvision -import torchvision.transforms as transforms from torch.utils.data import DataLoader -from torchvision.datasets import MNIST -from pl_examples import cli_lightning_logo +from pl_examples import _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo from pytorch_lightning.core import LightningDataModule, LightningModule from pytorch_lightning.trainer import Trainer +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: + import torchvision + import torchvision.transforms as transforms + from torchvision.datasets import MNIST +else: + from tests.helpers.datasets import MNIST + class Generator(nn.Module): """ diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index a10ada06af109..3f82ab3565403 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -19,7 +19,7 @@ import logging from copy import deepcopy from functools import partial -from typing import Any, Callable, List, Optional, Tuple, Union, Dict +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn.utils.prune as pytorch_prune @@ -27,7 +27,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug +from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException log = logging.getLogger(__name__) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 5817987520475..de799b394fe69 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union, Dict +from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union import torch from torch.nn import Module diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index da33ffa168e79..391b0d9c97f0d 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -11,25 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytest import inspect +import pytest + from pytorch_lightning.utilities.parsing import ( AttributeDict, clean_namespace, + collect_init_args, flatten_dict, + get_init_args, is_picklable, lightning_getattr, lightning_hasattr, lightning_setattr, parse_class_init_keys, - get_init_args, - collect_init_args, str_to_bool, str_to_bool_or_str, ) - unpicklable_function = lambda: None