From 06e8910f066a3198337db3db0467ca2f96deba19 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 31 Jul 2020 11:18:32 +0200 Subject: [PATCH 1/8] pytorch 1.6 (#2745) * pt 1.6 * don't use the new zipfile serialization for now * quick flake8 fixes * remove unnecessary f * coalesce strings * remove comma * remove extra commas * Apply suggestions from code review Co-authored-by: Peter Yu <2057325+yukw777@users.noreply.github.com> * set _use_new_zipfile_serialization to False only for pytorch 1.6.0 * remove unnecessary comments * flake8 fixes * use pkg_resources instead of packaging * readme * format * version * chlog Co-authored-by: Peter Yu Co-authored-by: Peter Yu <2057325+yukw777@users.noreply.github.com> --- .codecov.yml | 4 +- .github/workflows/ci-testing.yml | 8 +-- .github/workflows/docker-builds.yml | 2 +- .github/workflows/pt-conda.yml | 2 +- CHANGELOG.md | 2 + README.md | 16 ++--- .../trainer/distrib_data_parallel.py | 26 +++++--- pytorch_lightning/trainer/training_io.py | 23 ++++--- requirements/base.txt | 2 +- tests/base/datamodules.py | 3 +- tests/callbacks/test_model_checkpoint.py | 24 +++---- tests/models/test_restore.py | 2 +- tests/models/test_tpu.py | 62 ++++--------------- tests/trainer/test_trainer.py | 4 +- 14 files changed, 77 insertions(+), 103 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 6abf8b1a16a8d0..3bcfe7fb9f624e 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -9,7 +9,7 @@ codecov: strict_yaml_branch: "yaml-config" require_ci_to_pass: yes notify: - after_n_builds: 21 + after_n_builds: 22 wait_for_ci: yes # https://docs.codecov.io/docs/codecov-yaml#section-expired-reports max_report_age: off @@ -50,4 +50,4 @@ comment: layout: header, diff require_changes: false behavior: default # update if exists else create new - after_n_builds: 21 + after_n_builds: 22 diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index a3b9f42ef952af..df2e0c1aa66d83 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -57,10 +57,10 @@ jobs: python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)" # TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved - - name: Setup Windows on Latest - if: runner.os == 'windows' && matrix.requires == 'latest' - run: | - python -c "fname = 'requirements/base.txt' ; req = open(fname).read().replace('torch>=1.3', 'torch<1.5') ; open(fname, 'w').write(req)" + #- name: Setup Windows on Latest + # if: runner.os == 'windows' && matrix.requires == 'latest' + # run: | + # python -c "fname = 'requirements/base.txt' ; req = open(fname).read().replace('torch>=1.3', 'torch<1.5') ; open(fname, 'w').write(req)" # versions <= 1.3 may have issues on mac with some BLAS ops due to missing mkl (https://github.com/pytorch/pytorch/issues/18996) - name: Setup MacOS Minimal diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 714d15f191a373..d596ea764389ee 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: python_version: [3.6, 3.7, 3.8] - pytorch_version: [1.3, 1.4, 1.5] + pytorch_version: [1.3, 1.4, 1.5, 1.6] exclude: # excludes PT 1.3 as it is missing on pypi - python_version: 3.8 diff --git a/.github/workflows/pt-conda.yml b/.github/workflows/pt-conda.yml index 98c9a038ed23fa..342d2bfb5cc1a8 100644 --- 
a/.github/workflows/pt-conda.yml +++ b/.github/workflows/pt-conda.yml @@ -23,7 +23,7 @@ jobs: os: [ubuntu-20.04] python-version: [3.7] # todo: add nightly versions - pytorch-version: [1.3, 1.4, 1.5] # , 1.6, 1.7 + pytorch-version: [1.3, 1.4, 1.5, 1.6] # , 1.7 # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 20 diff --git a/CHANGELOG.md b/CHANGELOG.md index f1382f621d2eba..a4cd39873d84a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246)) +- Added support for PyTorch 1.6 ([#2745](https://github.com/PyTorchLightning/pytorch-lightning/pull/2745)) + ### Changed - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594)) diff --git a/README.md b/README.md index 8e4de149192860..f0f022e600f4f5 100644 --- a/README.md +++ b/README.md @@ -38,14 +38,14 @@ ## Continuous Integration
-| System / PyTorch ver. | 1.3 (min. req.) | 1.4 | 1.5 (latest) | -| :---: | :---: | :---: | :---: | -| Conda py3.7 [linux] | ![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg) | ![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg) | ![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg) | -| Linux py3.7 [GPU] | - | - | [![Build Status](http://35.192.60.23/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://35.192.60.23/PyTorchLightning/pytorch-lightning) | -| Linux py3.7 [TPU] | - | - | ![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg) | -| Linux py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | -| OSX py3.6 / py3.7 | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | -| Windows py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) |[![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | +| System / PyTorch ver. | 1.3 (min. req.) 
[w/o py3.8] | 1.4 | 1.5 | 1.6 (latest) | +| :---: | :---: | :---: | :---: | :---: | +| Conda py3.7 [linux] | ![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg) | ![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg) | ![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg) | ![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg) | +| Linux py3.7 [GPU] | - | - | - | [![Build Status](http://35.192.60.23/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://35.192.60.23/PyTorchLightning/pytorch-lightning) | +| Linux py3.7 [TPU] | - | - | - | ![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg) | +| Linux py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | +| OSX py3.6 / py3.7 | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | +| Windows py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22)
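The two hunks below apply the same version-gated save in two places (`transfer_distrib_spawn_state_on_fit_end` and `_atomic_save`). Pulled out of the diff, the shared pattern looks roughly like this sketch (the helper name here is ours, not part of the patch; only `torch` is assumed):

```python
from distutils.version import LooseVersion

import torch


def _save_compat(obj, filepath):
    # PyTorch 1.6.0 switched torch.save to a zipfile-based format, but its
    # torch.hub.load_state_dict_from_url() cannot read such files
    # (pytorch/pytorch#42239), so write the legacy format on 1.6.0 only.
    if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
        torch.save(obj, filepath, _use_new_zipfile_serialization=False)
    else:
        torch.save(obj, filepath)
```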
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 7b2c504488bd98..3ca5f6ffa68f3c 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -130,12 +130,14 @@ def train_fx(trial_hparams, cluster_manager, _): import os import re from abc import ABC, abstractmethod +from distutils.version import LooseVersion from typing import Union, List, Optional, Callable, Tuple import subprocess import sys from time import sleep import numpy as np from os.path import abspath +from pkg_resources import parse_version import torch from pytorch_lightning import _logger as log @@ -273,9 +275,11 @@ def set_distributed_mode(self, distributed_backend): elif self.num_gpus == 1: self.use_single_gpu = True elif self.num_gpus > 1: - rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.' - ' Trainer(distributed_backend=dp) (or ddp, ddp2).' - ' Setting distributed_backend=ddp_spawn for you.') + rank_zero_warn( + 'You requested multiple GPUs but did not specify a backend, e.g.' + ' Trainer(distributed_backend=dp) (or ddp, ddp2).' + ' Setting distributed_backend=ddp_spawn for you.' + ) self.distributed_backend = 'ddp_spawn' distributed_backend = 'ddp_spawn' @@ -304,8 +308,9 @@ def set_distributed_mode(self, distributed_backend): self.use_ddp2 = True elif distributed_backend == "ddp_cpu": if self.num_gpus > 0: - rank_zero_warn('You requested one or more GPUs, but set the backend to `ddp_cpu`.' - ' Training will not use GPUs.') + rank_zero_warn( + 'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.' + ) self.use_ddp = True self.data_parallel_device_ids = None self.on_gpu = False @@ -378,8 +383,7 @@ def determine_ddp_node_rank(self): if len(node_ids) == 0: return 0 if len(node_ids) > 1: - log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. " - f"Using the first one.") + log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. Using the first one.") k, rank = node_ids.pop() rank_zero_info(f"Using environment variable {k} for node rank ({rank}).") return int(rank) @@ -610,7 +614,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): last_path = None if not self.testing and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - torch.save(model.state_dict(), last_path) + # Can't use the new zipfile serialization for 1.6.0 because there's a bug in + # torch.hub.load_state_dict_from_url() that prevents it from loading the new files. 
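+ # The legacy format is still readable by torch.load in 1.6 and older releases, so this only opts out of the new default.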
+ # More details can be found here: https://github.com/pytorch/pytorch/issues/42239 + if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]: + torch.save(model.state_dict(), last_path, _use_new_zipfile_serialization=False) + else: + torch.save(model.state_dict(), last_path) mp_queue.put(last_path) def save_spawn_weights(self, model): diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 3234800a62c51c..381331fae2b096 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -87,7 +87,9 @@ import re import signal from abc import ABC +from distutils.version import LooseVersion from subprocess import call +from pkg_resources import parse_version import torch import torch.distributed as torch_distrib @@ -151,8 +153,7 @@ class TrainerIOMixin(ABC): scaler: ... def get_model(self): - is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, - LightningDataParallel)) + is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel)) model = self.model.module if is_dp_module else self.model return model @@ -261,7 +262,13 @@ def _atomic_save(self, checkpoint, filepath: str): This points to the file that the checkpoint will be stored in. """ tmp_path = str(filepath) + ".part" - torch.save(checkpoint, tmp_path) + # Can't use the new zipfile serialization for 1.6.0 because there's a bug in + # torch.hub.load_state_dict_from_url() that prevents it from loading the new files. + # More details can be found here: https://github.com/pytorch/pytorch/issues/42239 + if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]: + torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False) + else: + torch.save(checkpoint, tmp_path) os.replace(tmp_path, filepath) def save_checkpoint(self, filepath, weights_only: bool = False): @@ -274,8 +281,9 @@ def save_checkpoint(self, filepath, weights_only: bool = False): except AttributeError as err: if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] - rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.' - f' An attribute is not picklable {err}') + rank_zero_warn( + 'Warning, `module_arguments` dropped from checkpoint.' f' An attribute is not picklable {err}' + ) self._atomic_save(checkpoint, filepath) def restore(self, checkpoint_path: str, on_gpu: bool): @@ -493,8 +501,9 @@ def hpc_save(self, folderpath: str, logger): except AttributeError as err: if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] - rank_zero_warn('warning, `module_arguments` dropped from checkpoint.' - f' An attribute is not picklable {err}') + rank_zero_warn( + 'warning, `module_arguments` dropped from checkpoint.' 
f' An attribute is not picklable {err}' + ) self._atomic_save(checkpoint, filepath) return filepath diff --git a/requirements/base.txt b/requirements/base.txt index 8eff26906e81eb..4072df9466dc74 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,7 +1,7 @@ # the default package dependencies numpy>=1.16.4 -torch>=1.3, <1.6 # TODO: temporary freeze for Horovod incompatibility with 1.6 +torch>=1.3 tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index 23c07f93d46976..d863c85605af78 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -5,7 +5,6 @@ class TrialMNISTDataModule(LightningDataModule): - def __init__(self, data_dir: str = './'): super().__init__() self.data_dir = data_dir @@ -13,7 +12,7 @@ def __init__(self, data_dir: str = './'): def prepare_data(self): TrialMNIST(self.data_dir, train=True, download=True) TrialMNIST(self.data_dir, train=False, download=True) - + def setup(self): mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True) self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64]) diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 4cb52a54610e37..71b38ac6980adb 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -21,19 +21,13 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k) - trainer = Trainer( - default_root_dir=tmpdir, - checkpoint_callback=checkpoint, - overfit_batches=0.20, - max_epochs=2, - ) + trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2) trainer.fit(model) - assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints' + assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints' @pytest.mark.parametrize( - 'logger_version,expected', - [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')], + 'logger_version,expected', [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')], ) def test_model_checkpoint_path(tmpdir, logger_version, expected): """Test that "version_" prefix is only added when logger's version is an integer""" @@ -41,12 +35,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected): model = EvalModelTemplate() logger = TensorBoardLogger(str(tmpdir), version=logger_version) - trainer = Trainer( - default_root_dir=tmpdir, - overfit_batches=0.2, - max_epochs=2, - logger=logger, - ) + trainer = Trainer(default_root_dir=tmpdir, overfit_batches=0.2, max_epochs=2, logger=logger) trainer.fit(model) ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name @@ -83,8 +72,9 @@ def _save_model(self, filepath, trainer, pl_module): def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) # on rank 0 we expect the saved files and on all others no saves - assert (trainer.global_rank == 0 and self.count == self.expected_count) \ - or (trainer.global_rank > 0 and self.count == 0) + assert (trainer.global_rank == 0 and self.count == self.expected_count) or ( + trainer.global_rank > 0 and self.count == 0 + ) @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 
111ee869684c32..85c609b94d7655 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -309,7 +309,7 @@ def test_model_saving_loading(tmpdir): hparams_path = os.path.join(hparams_path, 'hparams.yaml') model_2 = EvalModelTemplate.load_from_checkpoint( checkpoint_path=new_weights_path, - hparams_file=hparams_path + hparams_file=hparams_path, ) model_2.eval() diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 8bb2a76a24313d..ccc68cb59b5688 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -13,6 +13,7 @@ try: import torch_xla import torch_xla.distributed.xla_multiprocessing as xmp + SERIAL_EXEC = xmp.MpSerialExecutor() # TODO: The tests are aborted if the following lines are uncommented. Must be resolved with XLA team # device = torch_xla.core.xla_model.xla_device() @@ -24,19 +25,12 @@ TPU_AVAILABLE = True -_LARGER_DATASET = TrialMNIST( - download=True, - num_samples=2000, - digits=(0, 1, 2, 5, 8), -) +_LARGER_DATASET = TrialMNIST(download=True, num_samples=2000, digits=(0, 1, 2, 5, 8)) # 8 cores needs a big dataset def _serial_train_loader(): - return DataLoader( - _LARGER_DATASET, - batch_size=32, - ) + return DataLoader(_LARGER_DATASET, batch_size=32) @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") @@ -200,28 +194,10 @@ def test_model_16bit_tpu_index_5(tmpdir): assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables" +@pytest.mark.parametrize('tpu_core', [1, 5]) @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") @pl_multi_process_test -def test_early_stop_checkpoints_on_tpu(tmpdir): - """Test if single TPU core training works""" - model = EvalModelTemplate() - trainer = Trainer( - early_stop_callback=True, - default_root_dir=tmpdir, - progress_bar_refresh_rate=0, - max_epochs=50, - limit_train_batches=10, - limit_val_batches=10, - distributed_backend='tpu', - tpu_cores=[1], - ) - trainer.fit(model) - assert torch_xla._XLAC._xla_get_default_device() == 'xla:1' - - -@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") -@pl_multi_process_test -def test_early_stop_checkpoints_on_tpu(tmpdir): +def test_train_on_single_tpu(tmpdir, tpu_core): """Test if single TPU core training works""" model = EvalModelTemplate() trainer = Trainer( @@ -232,7 +208,7 @@ def test_early_stop_checkpoints_on_tpu(tmpdir): limit_train_batches=10, limit_val_batches=10, distributed_backend='tpu', - tpu_cores=[5], + tpu_cores=[tpu_core], ) trainer.fit(model) - assert torch_xla._XLAC._xla_get_default_device() == 'xla:5' + assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}' @@ -264,26 +240,15 @@ def test_dataloaders_passed_to_fit(tmpdir): model = EvalModelTemplate() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - distributed_backend='tpu', - tpu_cores=8, - ) - result = trainer.fit( - model, - train_dataloader=model.train_dataloader(), - val_dataloaders=model.val_dataloader(), - ) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, distributed_backend='tpu', tpu_cores=8) + result = trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader()) assert result, "TPU doesn't work with dataloaders passed to fit()."
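The hunks that follow fold duplicated per-core tests into parametrized ones, and the reshaped `test_tpu_id_to_be_as_expected` pins down the `tpu_cores` contract: an int requests that many cores without fixing an id, while a one-element list pins a specific core. Restated as a standalone sketch of that test's own assertions:

```python
from pytorch_lightning import Trainer

# an int requests that many cores; no single core id is pinned
assert Trainer(tpu_cores=1).tpu_id is None
assert Trainer(tpu_cores=8).tpu_id is None

# a one-element list pins training to that specific core
assert Trainer(tpu_cores=[1]).tpu_id == 1
assert Trainer(tpu_cores=[8]).tpu_id == 8
```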
-@pytest.mark.parametrize(['tpu_cores', 'expected_tpu_id'], [ - pytest.param(1, None), - pytest.param(8, None), - pytest.param([1], 1), - pytest.param([8], 8), -]) +@pytest.mark.parametrize( + ['tpu_cores', 'expected_tpu_id'], + [pytest.param(1, None), pytest.param(8, None), pytest.param([1], 1), pytest.param([8], 8)], +) def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id): """Test if trainer.tpu_id is set as expected""" assert Trainer(tpu_cores=tpu_cores).tpu_id == expected_tpu_id @@ -293,8 +258,7 @@ def test_tpu_misconfiguration(): """Test if trainer.tpu_id is set as expected""" with pytest.raises(MisconfigurationException, match="`tpu_cores` can only be"): Trainer( - tpu_cores=[1, 8], - distributed_backend='tpu', + tpu_cores=[1, 8], distributed_backend='tpu', ) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 27fa6815509620..c7652ebecf3f9a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -61,7 +61,7 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path model_2 = EvalModelTemplate.load_from_checkpoint( checkpoint_path=ckpt_path, - hparams_file=hparams_path + hparams_file=hparams_path, ) model_2.eval() @@ -99,7 +99,7 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path model_2 = EvalModelTemplate.load_from_checkpoint( checkpoint_path=ckpt_path, - hparams_file=hparams_path + hparams_file=hparams_path, ) model_2.eval() From b7afac351b61a1f90e9b0611e267731058c8cda0 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 31 Jul 2020 15:57:57 +0530 Subject: [PATCH 2/8] Add onnx export (#2596) * export model to onnx * prepare data before exporting * support for dataloaders and tensors * added tests * use example_input_array add to changelog * updated docstring * added onnx inference tests * temp commit * removed schema valid test * add onnxruntime to environment.yml * moved onnxruntime to environment.yml pip * add example in doc * add lines between code block * added PR to changelog * is file check Co-authored-by: Jirka Borovec * remove * Co-authored-by: Jirka Borovec * infer example outputs * added doctest for onnx * fix windows tests * moved eval within condition block * self.forward to self * added docs * fixed docs error * added to toctree * Update CHANGELOG.md Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 + docs/source/index.rst | 1 + docs/source/production_inference.rst | 28 +++++++ environment.yml | 1 + pytorch_lightning/core/lightning.py | 39 +++++++++ requirements/extra.txt | 2 + tests/base/model_template.py | 4 +- tests/models/test_onnx_save.py | 114 +++++++++++++++++++++++++++ 8 files changed, 188 insertions(+), 3 deletions(-) create mode 100644 docs/source/production_inference.rst create mode 100644 tests/models/test_onnx_save.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a4cd39873d84a5..ed0880c345765f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671)) - Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535)) +- Added support to export a model to ONNX format ([#2596](https://github.com/PyTorchLightning/pytorch-lightning/pull/2596)) + - Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246)) - Added support for PyTorch 1.6 ([#2745](https://github.com/PyTorchLightning/pytorch-lightning/pull/2745)) diff --git a/docs/source/index.rst b/docs/source/index.rst index 4b1b7c697a6c80..3637892848b513 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -99,6 +99,7 @@ PyTorch Lightning Documentation transfer_learning tpu test_set + production_inference .. toctree:: :maxdepth: 1 diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst new file mode 100644 index 00000000000000..3159abe630b686 --- /dev/null +++ b/docs/source/production_inference.rst @@ -0,0 +1,28 @@ +Inference in Production +======================= +PyTorch Lightning eases the process of deploying models into production. + + +Exporting to ONNX +----------------- +PyTorch Lightning provides a handy function to quickly export your model to ONNX format, which allows the model to be independent of PyTorch and run on an ONNX Runtime. + +To export your model to ONNX format call the `to_onnx` function on your Lightning Module with the filepath and input_sample. + +.. code-block:: python + + filepath = 'model.onnx' + model = SimpleModel() + input_sample = torch.randn((1, 64)) + model.to_onnx(filepath, input_sample, export_params=True) + +You can also skip passing the input sample if the `example_input_array` property is specified in your LightningModule. + +Once you have the exported model, you can run it on your ONNX runtime in the following way: + +.. code-block:: python + + ort_session = onnxruntime.InferenceSession(filepath) + input_name = ort_session.get_inputs()[0].name + ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)} + ort_outs = ort_session.run(None, ort_inputs) diff --git a/environment.yml b/environment.yml index 9c48f6d7e2c398..07afe8055753f2 100644 --- a/environment.yml +++ b/environment.yml @@ -48,3 +48,4 @@ dependencies: - wandb>=0.8.21 - neptune-client>=0.4.109 - horovod>=0.19.1 + - onnxruntime>=1.3.0 diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index afb9fa0a9266a6..5ff64156e1b337 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -2,6 +2,7 @@ import inspect import os import re +import tempfile from abc import ABC, abstractmethod from argparse import Namespace from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -1723,6 +1724,44 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: else: self._hparams = hp + def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs): + """Saves the model in ONNX format + + Args: + file_path: The path of the file the model should be saved to. + input_sample: A sample of an input tensor for tracing. + **kwargs: Will be passed to torch.onnx.export function. + + Example: + >>> class SimpleModel(LightningModule): + ... def __init__(self): + ... super().__init__() + ... self.l1 = torch.nn.Linear(in_features=64, out_features=4) + ... + ... def forward(self, x): + ... 
return torch.relu(self.l1(x.view(x.size(0), -1))) + + >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile: + ... model = SimpleModel() + ... input_sample = torch.randn((1, 64)) + ... model.to_onnx(tmpfile.name, input_sample, export_params=True) + ... os.path.isfile(tmpfile.name) + True + """ + + if isinstance(input_sample, Tensor): + input_data = input_sample + elif self.example_input_array is not None: + input_data = self.example_input_array + else: + raise ValueError(f'input_sample and example_input_array tensors are both missing.') + + if 'example_outputs' not in kwargs: + self.eval() + kwargs['example_outputs'] = self(input_data) + + torch.onnx.export(self, input_data, file_path, **kwargs) + @property def hparams(self) -> Union[AttributeDict, str]: if not hasattr(self, '_hparams'): diff --git a/requirements/extra.txt b/requirements/extra.txt index 191d24125a21d6..31ea41c083d4b8 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -12,3 +12,5 @@ omegaconf>=2.0.0 # scipy>=0.13.3 scikit-learn>=0.20.0 torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility +onnx>=1.7.0 +onnxruntime>=1.3.0 \ No newline at end of file diff --git a/tests/base/model_template.py b/tests/base/model_template.py index f529ce5735b89a..19fcd42195b96b 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -73,9 +73,7 @@ def __init__( self.test_step_end_called = False self.test_epoch_end_called = False - # if you specify an example input, the summary will show input/output for each layer - # TODO: to be fixed in #1773 - # self.example_input_array = torch.rand(5, 28 * 28) + self.example_input_array = torch.rand(5, 28 * 28) # build model self.__build_model() diff --git a/tests/models/test_onnx_save.py b/tests/models/test_onnx_save.py new file mode 100644 index 00000000000000..f824f33c93bc14 --- /dev/null +++ b/tests/models/test_onnx_save.py @@ -0,0 +1,114 @@ +import os + +import onnxruntime +import pytest +import torch +import numpy as np +import tests.base.develop_pipelines as tpipes +import tests.base.develop_utils as tutils +from pytorch_lightning import Trainer +from tests.base import EvalModelTemplate + + +def test_model_saves_with_input_sample(tmpdir): + """Test that ONNX model saves with input sample and size is greater than 3 MB""" + model = EvalModelTemplate() + trainer = Trainer(max_epochs=1) + trainer.fit(model) + + file_path = os.path.join(tmpdir, "model.onxx") + input_sample = torch.randn((1, 28 * 28)) + model.to_onnx(file_path, input_sample) + assert os.path.isfile(file_path) + assert os.path.getsize(file_path) > 3e+06 + + +def test_model_saves_with_example_output(tmpdir): + """Test that ONNX model saves when provided with example output""" + model = EvalModelTemplate() + trainer = Trainer(max_epochs=1) + trainer.fit(model) + + file_path = os.path.join(tmpdir, "model.onxx") + input_sample = torch.randn((1, 28 * 28)) + model.eval() + example_outputs = model.forward(input_sample) + model.to_onnx(file_path, input_sample, example_outputs=example_outputs) + assert os.path.exists(file_path) is True + + +def test_model_saves_with_example_input_array(tmpdir): + """Test that ONNX model saves with_example_input_array and size is greater than 3 MB""" + model = EvalModelTemplate() + file_path = os.path.join(tmpdir, "model.onxx") + model.to_onnx(file_path) + assert os.path.exists(file_path) is True + assert os.path.getsize(file_path) > 3e+06 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def 
test_model_saves_on_multi_gpu(tmpdir): + """Test that ONNX model saves on a distributed backend""" + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + progress_bar_refresh_rate=0 + ) + + model = EvalModelTemplate() + + tpipes.run_model_test(trainer_options, model) + + file_path = os.path.join(tmpdir, "model.onxx") + model.to_onnx(file_path) + assert os.path.exists(file_path) is True + + +def test_verbose_param(tmpdir, capsys): + """Test that output is present when verbose parameter is set""" + model = EvalModelTemplate() + file_path = os.path.join(tmpdir, "model.onxx") + model.to_onnx(file_path, verbose=True) + captured = capsys.readouterr() + assert "graph(%" in captured.out + + +def test_error_if_no_input(tmpdir): + """Test that an exception is thrown when there is no input tensor""" + model = EvalModelTemplate() + model.example_input_array = None + file_path = os.path.join(tmpdir, "model.onxx") + with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'): + model.to_onnx(file_path) + + +def test_if_inference_output_is_valid(tmpdir): + """Test that the output inferred from ONNX model is same as from PyTorch""" + model = EvalModelTemplate() + trainer = Trainer(max_epochs=5) + trainer.fit(model) + + model.eval() + with torch.no_grad(): + torch_out = model(model.example_input_array) + + file_path = os.path.join(tmpdir, "model.onxx") + model.to_onnx(file_path, model.example_input_array, export_params=True) + + ort_session = onnxruntime.InferenceSession(file_path) + + def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + + # compute ONNX Runtime output prediction + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model.example_input_array)} + ort_outs = ort_session.run(None, ort_inputs) + + # compare ONNX Runtime and PyTorch results + assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) From fcfdb4df13628b689b64e6e71ceae6b57d4b20a4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 31 Jul 2020 12:31:23 +0200 Subject: [PATCH 3/8] conda speedup (#2546) * conda speedup * cache * add pip cache * suggestion * cache * cache * req --- .github/workflows/ci-test-base.yml | 4 ++-- .github/workflows/ci-testing.yml | 20 ++++++++++++-------- .github/workflows/code-formatting.yml | 6 +++--- .github/workflows/docs-checks.yml | 4 ++-- .github/workflows/pt-conda.yml | 22 ++++++++++++++++------ requirements/base.txt | 2 +- 6 files changed, 36 insertions(+), 22 deletions(-) diff --git a/.github/workflows/ci-test-base.yml b/.github/workflows/ci-test-base.yml index e1f64150544806..7def5ca4b1ee90 100644 --- a/.github/workflows/ci-test-base.yml +++ b/.github/workflows/ci-test-base.yml @@ -47,7 +47,7 @@ jobs: python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" - name: Cache pip - uses: actions/cache@v1 + uses: actions/cache@v2 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }} @@ -66,7 +66,7 @@ jobs: shell: bash - name: Cache datasets - uses: actions/cache@v1 + uses: actions/cache@v2 with: path: Datasets # This path is specific to Ubuntu # Look to see if there is a cache hit for the corresponding requirements file diff --git 
a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index df2e0c1aa66d83..932e70bcd2b48a 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -44,6 +44,10 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Update Pip + run: | + pip install -U -q "pip>=20.1" # needed for get pip cacher folder + # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646 - name: Setup macOS if: runner.os == 'macOS' @@ -77,23 +81,23 @@ jobs: # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow - - name: Get pip cache + - name: Get pip cache dir id: pip-cache run: | - python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" + echo "::set-output name=dir::$(pip cache dir)" - - name: Cache pip - uses: actions/cache@v1 + - name: pip cache + uses: actions/cache@v2 with: path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }} + key: ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }} restore-keys: | - ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip- + ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}-pip- - name: Install dependencies run: | - # python -m pip install --upgrade --user pip pip install --requirement requirements/base.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet --upgrade-strategy only-if-needed + # pip install -q "PyYAML>=5.3.1" # needed for installing dependencues HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install --requirement ./requirements/devel.txt --quiet --upgrade-strategy "only-if-needed" python --version pip --version @@ -112,7 +116,7 @@ jobs: shell: bash - name: Cache datasets - uses: actions/cache@v1 + uses: actions/cache@v2 with: path: Datasets # This path is specific to Ubuntu # Look to see if there is a cache hit for the corresponding requirements file diff --git a/.github/workflows/code-formatting.yml b/.github/workflows/code-formatting.yml index 9938be4d9ff866..8ae2f6829bb48e 100644 --- a/.github/workflows/code-formatting.yml +++ b/.github/workflows/code-formatting.yml @@ -42,12 +42,12 @@ jobs: # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow - name: Cache pip - uses: actions/cache@v1 + uses: actions/cache@v2 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }} + key: ${{ runner.os }}-pip-extras-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }} restore-keys: | - ${{ runner.os }}-pip- + ${{ runner.os }}-pip-extras- - name: Install dependencies run: | diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index c4d4893cd668cd..b9442216052996 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -29,7 +29,7 @@ jobs: # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow - name: Cache pip - uses: actions/cache@v1 + uses: 
actions/cache@v2 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('requirements/base.txt') }} @@ -67,7 +67,7 @@ jobs: # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow - name: Cache pip - uses: actions/cache@v1 + uses: actions/cache@v2 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('requirements/base.txt') }} diff --git a/.github/workflows/pt-conda.yml b/.github/workflows/pt-conda.yml index 342d2bfb5cc1a8..c60c21be98a078 100644 --- a/.github/workflows/pt-conda.yml +++ b/.github/workflows/pt-conda.yml @@ -38,12 +38,21 @@ jobs: # TODO: set source for nightly - name: Cache conda - uses: actions/cache@v1 - env: # Increase this value to reset cache if etc/example-environment.yml has not changed - CACHE_NUMBER: 0 + uses: actions/cache@v2 with: path: ~/conda_pkgs_dir - key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('environment.yml') }} + key: ${{ runner.os }}-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('environment.yml') }} + restore-keys: | + ${{ runner.os }}-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}- + + # Add another cache for Pip as not all packages lives in Conda env + - name: Cache pip + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('requirements/base.txt') }} + restore-keys: | + ${{ runner.os }}-pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}- # https://docs.conda.io/projects/conda/en/4.6.0/_downloads/52a95608c49671267e40c689e0bc00ca/conda-cheatsheet.pdf # https://gist.github.com/mwouts/9842452d020c08faf9e84a3bba38a66f @@ -52,7 +61,8 @@ jobs: with: # auto-update-conda: true auto-activate-base: false - miniconda-version: 4.7.12 + # miniconda-version: 4.7.12 # This downloads a new conda, use the conda-version + conda-version: 4.7.12 python-version: ${{ matrix.python-version }} environment-file: environment.yml activate-environment: pl-env @@ -70,7 +80,7 @@ jobs: shell: bash -l {0} - name: Cache datasets - uses: actions/cache@v1 + uses: actions/cache@v2 with: path: Datasets # This path is specific to Ubuntu # Look to see if there is a cache hit for the corresponding requirements file diff --git a/requirements/base.txt b/requirements/base.txt index 4072df9466dc74..4282f6a12d2eb2 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -5,5 +5,5 @@ torch>=1.3 tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 -PyYAML>=5.1 # OmegaConf requirement +PyYAML>=5.1 # OmegaConf requirement >=5.1 tqdm>=4.41.0 From 78a07e5f2d7393d5076f78e2bca4c62d547a6ea7 Mon Sep 17 00:00:00 2001 From: siahuat0727 Date: Fri, 31 Jul 2020 19:42:47 +0800 Subject: [PATCH 4/8] Fix doc typo (#2773) --- pytorch_lightning/trainer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 188e9cf8de0dd7..c4425b036f1e5f 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -118,7 +118,7 @@ def forward(self, x): --------------- To ensure full reproducibility from run to run you need to set seeds for pseudo-random generators, -and set ``deterministic``` flag in ``Trainer``. +and set ``deterministic`` flag in ``Trainer``. 
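In practice the pairing reads roughly like the sketch below (the kind of usage the elided ``Example::`` block illustrates, via the ``seed_everything`` helper the package exports):

```python
from pytorch_lightning import Trainer, seed_everything

seed_everything(42)  # seeds Python, NumPy and PyTorch pseudo-random generators

# deterministic=True additionally requests deterministic CUDA/cuDNN kernels where supported
trainer = Trainer(deterministic=True)
```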
Example:: From b88fc4387183c75bb98850db182e2a83d1ae14ab Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 31 Jul 2020 13:52:17 +0200 Subject: [PATCH 5/8] re-enable skipped tests (#2762) * re-enable skipped * timeout --- .github/workflows/ci-testing.yml | 4 ---- .github/workflows/pt-conda.yml | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 932e70bcd2b48a..d846950a375da0 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -30,10 +30,6 @@ jobs: # TODO: temporary fix till hanging jobs on macOS for py38 is resolved - python-version: 3.8 os: macOS-10.15 - # TODO: temporary fix till pyYaml can be installed, see: https://github.com/actions/setup-python/issues/114 - - python-version: 3.7 - os: ubuntu-18.04 - requires: 'minimal' # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 25 diff --git a/.github/workflows/pt-conda.yml b/.github/workflows/pt-conda.yml index c60c21be98a078..a5e41530d23e5c 100644 --- a/.github/workflows/pt-conda.yml +++ b/.github/workflows/pt-conda.yml @@ -26,7 +26,7 @@ jobs: pytorch-version: [1.3, 1.4, 1.5, 1.6] # , 1.7 # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 20 + timeout-minutes: 35 steps: - uses: actions/checkout@v2 From a6719f09f0a383034f4285d65cba880208a03ae4 Mon Sep 17 00:00:00 2001 From: Thomas Schaaf <42753790+thschaaf@users.noreply.github.com> Date: Fri, 31 Jul 2020 07:53:08 -0400 Subject: [PATCH 6/8] Bugfix/torchtext include lengths (#2689) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Test using torchtext.data.Field with include_lengths=True/False * Fix issue where Tensors in a Batch generated by torchtext with torchtext.data.Field configured with include_lengths=True were not moved to the target device * Add description for fix of issue #2688 * changes to accommodate CodeFactor issues * Another attempt to make the last CodeFactor issue pass (it's a false alarm) * temporarily disable test of test_grad_tracking to check if testing will pass * re-enable test in test_grad_norm * Update CHANGELOG.md Co-authored-by: Jirka Borovec * Renamed get_torchtext_data_iterator to _get_torchtext_data_iterator as suggested by @borda * Update pytorch_lightning/utilities/apply_func.py Co-authored-by: Adrian Wälchli * adding tests more specific to batch_move_data_to_device with torchtext Batch * added check that Tensors were moved to target device * removed tests using RNN models to be moved into a separate PR * fixing FLAKE8 errors that showed up after merge from master branch modified: tests/base/datamodules.py modified: tests/callbacks/test_model_checkpoint.py * parameterized test to reduce code duplication * Added check only if length tensor exists. Removed left-over comments. * rearranged device parameterization and added pytest.param * Try to figure out why only one device is tested on Linux machines * Testing on CPU and GPU devices (GPU test is skipped if no CUDA device is available).
* added test for TPU device (experimental) * Adding test parameterization for TPU test (experimental) * change import statement to limit what is imported for a TPU environment * made test work with TPU * Change to trigger CI * Change to trigger CI * uncommented TPU test to check CI * reenabling TPU test * small change to trigger CI build * small change to trigger CI build * small change to trigger CI build * adding tests/utilities/test_apply_func_torchtext.py to CI TPU test * try to make test not skipped on CI with TPU * remove testing on TPU * undo an accidental change to test_tpu.py (file should not have been touched) * small change to trigger CI build * small change to trigger CI build * Update tests/utilities/test_apply_func_torchtext.py * Revert to previous version * Apply suggestions from code review * Change to trigger CI Co-authored-by: Thomas Schaaf Co-authored-by: Jirka Borovec Co-authored-by: Adrian Wälchli Co-authored-by: Thomas Schaaf --- CHANGELOG.md | 2 + pytorch_lightning/utilities/apply_func.py | 9 ++-- tests/utilities/test_apply_func_torchtext.py | 52 ++++++++++++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) create mode 100644 tests/utilities/test_apply_func_torchtext.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ed0880c345765f..c44db54cac5a1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed test metrics not being logged with `LoggerCollection` ([#2723](https://github.com/PyTorchLightning/pytorch-lightning/pull/2723)) +- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689)) + ## [0.8.5] - 2020-07-09 ### Added diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 6f9b8e176ffe1e..75130b297ddccd 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -6,6 +6,7 @@ import torch import importlib + TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None if TORCHTEXT_AVAILABLE: from torchtext.data import Batch @@ -92,6 +93,7 @@ def move_data_to_device(batch: Any, device: torch.device): - :meth:`torch.Tensor.to` - :class:`torch.device` """ + def batch_to(data): # try to move torchtext data first if TORCHTEXT_AVAILABLE and isinstance(data, Batch): @@ -99,11 +101,10 @@ def batch_to(data): # Shallow copy because each Batch has a reference to Dataset which contains all examples device_data = copy(data) for field in data.fields: - # Batch contains output of Field.process(...) which is tensor hence .to(...) 
exists - device_field = getattr(data, field).to(device, non_blocking=True) + device_field = move_data_to_device(getattr(data, field), device) setattr(device_data, field, device_field) return device_data - else: - return data.to(device, non_blocking=True) + + return data.to(device, non_blocking=True) return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to) diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py new file mode 100644 index 00000000000000..9ea29420788d7e --- /dev/null +++ b/tests/utilities/test_apply_func_torchtext.py @@ -0,0 +1,52 @@ +import pytest +import torch +import torchtext +from torchtext.data.example import Example + +from pytorch_lightning.utilities.apply_func import move_data_to_device + + +def _get_torchtext_data_iterator(include_lengths=False): + text_field = torchtext.data.Field(sequential=True, pad_first=False, # nosec + init_token="<sos>", eos_token="<eos>", # nosec + include_lengths=include_lengths) # nosec + + example1 = Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)}) + example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)}) + example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)}) + + dataset = torchtext.data.Dataset( + [example1, example2, example3], + {"text": text_field}, + ) + text_field.build_vocab(dataset) + + iterator = torchtext.data.Iterator(dataset, batch_size=3, + sort_key=None, device=None, + batch_size_fn=None, + train=True, repeat=False, shuffle=None, + sort=None, sort_within_batch=None) + return iterator, text_field + + +@pytest.mark.parametrize('include_lengths', [False, True]) +@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0))]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test assumes GPU machine") +def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device): + data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths) + data_iter = iter(data_iterator) + batch = next(data_iter) + batch_on_device = move_data_to_device(batch, device) + + if include_lengths: + # tensor with data + assert (batch_on_device.text[0].device == device) + # tensor with length of data + assert (batch_on_device.text[1].device == device) + else: + assert (batch_on_device.text.device == device) + + +@pytest.mark.parametrize('include_lengths', [False, True]) +def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths): + test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu')) From bc7a08fbe00860dc19c37bbd2575d0beefe8f57c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 31 Jul 2020 14:23:13 +0200 Subject: [PATCH 7/8] test dockers & add AMP in pt-1.6 (#1584) * exist images * names * images * args * pt 1.6 dev * circleci * update * refactor * build * fix * MKL --- .drone.yml | 9 ++++--- .github/workflows/docker-builds.yml | 26 +++++++++++++++++++ dockers/README.md | 2 +- dockers/cuda-extras/Dockerfile | 40 +++++++++++++++++++++++++++++ tests/Dockerfile | 27 ------------------- tests/README.md | 2 +- 6 files changed, 73 insertions(+), 33 deletions(-) create mode 100644 dockers/cuda-extras/Dockerfile delete mode 100644 tests/Dockerfile diff --git a/.drone.yml b/.drone.yml index 71532d96ed5a48..67f0c38758dc93 100644 --- a/.drone.yml +++ b/.drone.yml @@ -6,12 +6,13 @@ name: torch-GPU steps: - name: testing - image:
pytorchlightning/pytorch_lightning:cuda-extras-py3.7-torch1.5 environment: SLURM_LOCALID: 0 CODECOV_TOKEN: from_secret: codecov_token + MKL_THREADING_LAYER: GNU HOROVOD_GPU_ALLREDUCE: NCCL HOROVOD_GPU_BROADCAST: NCCL HOROVOD_WITH_PYTORCH: 1 @@ -33,10 +34,10 @@ steps: - nvidia-smi #- bash ./tests/install_AMP.sh - apt-get update && apt-get install -y cmake - - pip install -r ./requirements/base.txt --user -q - - pip install -r ./requirements/devel.txt --user -q + - pip install -r ./requirements/base.txt --user -q --upgrade-strategy only-if-needed + - pip install -r ./requirements/devel.txt --user -q --upgrade-strategy only-if-needed #- pip install -r ./requirements/docs.txt --user -q - - pip install -r ./requirements/examples.txt --user -q + - pip install -r ./requirements/examples.txt --user -q --upgrade-strategy only-if-needed - pip list - python -c "import torch ; print(' & '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]) if torch.cuda.is_available() else 'only CPU')" - coverage run --source pytorch_lightning -m py.test pytorch_lightning tests -v --durations=25 # --flake8 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index d596ea764389ee..5cb1aec47abc4b 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -82,3 +82,29 @@ jobs: build_args: PYTHON_VERSION=${{ matrix.python_version }} tags: "XLA-extras-py${{ matrix.python_version }}" timeout-minutes: 25 + + build-cuda: + runs-on: ubuntu-20.04 + strategy: + matrix: + python_version: [3.7] + pytorch_version: [1.3, 1.4, 1.5, 1.6.0] + steps: + - name: Checkout + uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.7 + + - name: Publish Master to Docker + # publish master + uses: docker/build-push-action@v1.1.0 + if: github.event_name == 'push' + with: + repository: pytorchlightning/pytorch_lightning + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + dockerfile: dockers/cuda-extras/Dockerfile + build_args: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }} + tags: "cuda-extras-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}" + timeout-minutes: 40 diff --git a/dockers/README.md b/dockers/README.md index b03c3d7a578910..7b3063e00f79c4 100644 --- a/dockers/README.md +++ b/dockers/README.md @@ -12,7 +12,7 @@ or with specific arguments ```bash git clone docker image build \ - -t pytorch-lightning:py38 \ + -t pytorch-lightning:py3.8 \ -f dockers/conda/Dockerfile \ --build-arg PYTHON_VERSION=3.8 \ --build-arg PYTORCH_VERSION=1.4 \ diff --git a/dockers/cuda-extras/Dockerfile b/dockers/cuda-extras/Dockerfile new file mode 100644 index 00000000000000..c4bc5cfb641fd9 --- /dev/null +++ b/dockers/cuda-extras/Dockerfile @@ -0,0 +1,40 @@ +# Existing images: +# --build-arg TORCH_VERSION=1.6.0 --build-arg CUDA_VERSION=10.1 +# --build-arg TORCH_VERSION=1.5 --build-arg CUDA_VERSION=10.1 +# --build-arg TORCH_VERSION=1.4 --build-arg CUDA_VERSION=10.1 +# --build-arg TORCH_VERSION=1.3 --build-arg CUDA_VERSION=10.1 +# --build-arg TORCH_VERSION=1.2 --build-arg CUDA_VERSION=10.0 +# --build-arg TORCH_VERSION=1.1.0 --build-arg CUDA_VERSION=10.0 --build-arg CUDNN_VERSION=7.5 + +ARG TORCH_VERSION=1.6 +ARG CUDA_VERSION=10.1 +ARG CUDNN_VERSION=7 + +FROM pytorch/pytorch:${TORCH_VERSION}-cuda${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel + +ENV HOROVOD_GPU_ALLREDUCE=NCCL +ENV HOROVOD_GPU_BROADCAST=NCCL +ENV HOROVOD_WITH_PYTORCH=1 +ENV
HOROVOD_WITHOUT_TENSORFLOW=1 +ENV HOROVOD_WITHOUT_MXNET=1 +ENV HOROVOD_WITH_GLOO=1 +ENV HOROVOD_WITHOUT_MPI=1 +ENV PATH="$PATH:/root/.local/bin" +ENV MAKEFLAGS="-j$(nproc)" + +COPY ./tests/install_AMP.sh install_AMP.sh +COPY ./requirements/base.txt requirements.txt +COPY ./requirements/extra.txt requirements-extra.txt +COPY ./requirements/test.txt requirements-tests.txt +COPY ./requirements/examples.txt requirements-examples.txt + +RUN apt-get update && apt-get install -y cmake && \ + # Install AMP + bash install_AMP.sh && \ + pip install -r requirements.txt && \ + # HOROVOD_BUILD_ARCH_FLAGS="-mfma" && \ + pip install -r requirements-extra.txt && \ + pip install -r requirements-examples.txt && \ + pip install -r requirements-tests.txt && \ + rm requirements* && \ + pip list diff --git a/tests/Dockerfile b/tests/Dockerfile deleted file mode 100644 index 65c75c1ba34598..00000000000000 --- a/tests/Dockerfile +++ /dev/null @@ -1,27 +0,0 @@ -ARG TORCH_VERSION=1.4 -ARG CUDA_VERSION=10.1 - -FROM pytorch/pytorch:${TORCH_VERSION}-cuda${CUDA_VERSION}-cudnn7-devel - -ENV HOROVOD_GPU_ALLREDUCE: NCCL -ENV HOROVOD_GPU_BROADCAST: NCCL -ENV HOROVOD_WITH_PYTORCH: 1 -ENV HOROVOD_WITHOUT_TENSORFLOW: 1 -ENV HOROVOD_WITHOUT_MXNET: 1 -ENV HOROVOD_WITH_GLOO: 1 -ENV HOROVOD_WITHOUT_MPI: 1 -ENV PATH: "$PATH:/root/.local/bin" -ENV MAKEFLAGS: "-j$(nproc)" - -COPY ./tests/install_AMP.sh install_AMP.sh -COPY ./requirements/base.txt requirements.txt -COPY ./requirements/extra.txt requirements-extra.txt -COPY ./requirements/test.txt requirements-tests.txt - -# Install AMP -RUN apt-get update && apt-get install -y cmake && \ - bash install_AMP.sh && \ - pip install -r requirements.txt --user && \ - pip install -r requirements-extra.txt --user && \ - pip install -r requirements-tests.txt --user && \ - pip list diff --git a/tests/README.md b/tests/README.md index 6286e8b9e81278..ccd62301aa1e25 100644 --- a/tests/README.md +++ b/tests/README.md @@ -54,7 +54,7 @@ coverage xml You can build it on your own, note it takes lots of time, be prepared. ```bash git clone -docker image build -t pytorch_lightning:devel-torch1.4 -f tests/Dockerfile --build-arg TORCH_VERSION=1.4 . +docker image build -t pytorch_lightning:devel-torch1.4 -f dockers/cuda-extras/Dockerfile --build-arg TORCH_VERSION=1.4 . ``` To build other versions, select different Dockerfile. 
From 3772601cd6872cde006aab9284e103e857955457 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Fri, 31 Jul 2020 20:50:06 +0200
Subject: [PATCH 8/8] update CI testing with pip upgrade (#2380)

* try pt1.5
* cpu
* upgrade
* tpu
* user
* [blocked by #2380] freeze GPU PT 1.4 (#2780)
* freeze
* user
---
 .drone.yml                         |  2 +-
 .github/workflows/ci-test-base.yml |  2 +-
 .github/workflows/ci-testing.yml   | 13 ++-----
 README.md                          |  2 +
 tests/models/test_tpu.py           | 60 ++++++------------------------
 5 files changed, 20 insertions(+), 59 deletions(-)

diff --git a/.drone.yml b/.drone.yml
index 67f0c38758dc93..edb6f48bbb0e3d 100644
--- a/.drone.yml
+++ b/.drone.yml
@@ -6,7 +6,7 @@ name: torch-GPU

 steps:
 - name: testing
-  image: pytorchlightning/pytorch_lightning:cuda-extras-py3.7-torch1.5
+  image: pytorchlightning/pytorch_lightning:devel-pt1.4

   environment:
     SLURM_LOCALID: 0
diff --git a/.github/workflows/ci-test-base.yml b/.github/workflows/ci-test-base.yml
index 7def5ca4b1ee90..855a9831fd878c 100644
--- a/.github/workflows/ci-test-base.yml
+++ b/.github/workflows/ci-test-base.yml
@@ -57,7 +57,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade --user pip
-          pip install --requirement ./requirements/base.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade-strategy only-if-needed
+          pip install --requirement ./requirements/base.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade
           pip install --requirement ./requirements/test.txt --quiet --upgrade-strategy only-if-needed
           # pip install tox coverage
           python --version
diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index d846950a375da0..0ed2db475546f2 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -42,7 +42,7 @@ jobs:

       - name: Update Pip
         run: |
-          pip install -U -q "pip>=20.1" # needed for get pip cacher folder
+          pip install --quiet "pip>=20.1" --upgrade --user  # needed to get the pip cache folder

       # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646
       - name: Setup macOS
@@ -54,14 +54,9 @@ jobs:
       - name: Setup Windows
         if: runner.os == 'windows'
         run: |
+          # remove Horovod from requirements
           python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)"

-      # TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved
-      #- name: Setup Windows on Latest
-      #  if: runner.os == 'windows' && matrix.requires == 'latest'
-      #  run: |
-      #    python -c "fname = 'requirements/base.txt' ; req = open(fname).read().replace('torch>=1.3', 'torch<1.5') ; open(fname, 'w').write(req)"
-
       # versions <= 1.3 may have issues on mac with some BLAS ops due to missing mkl (https://github.com/pytorch/pytorch/issues/18996)
       - name: Setup MacOS Minimal
         if: runner.os == 'macOS' && matrix.requires == 'minimal'
@@ -92,8 +87,8 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install --requirement requirements/base.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet --upgrade-strategy only-if-needed
-          # pip install -q "PyYAML>=5.3.1" # needed for installing dependencues
+          # python -m pip install --upgrade --user pip
+          pip install --requirement requirements/base.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet --upgrade
           HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install --requirement ./requirements/devel.txt --quiet --upgrade-strategy "only-if-needed"
           python --version
           pip --version
diff --git a/README.md b/README.md
index f0f022e600f4f5..2b2132799965dd 100644
--- a/README.md
+++ b/README.md
@@ -437,6 +437,8 @@ You can also install any past release `0.X.Y` from this repository:
 pip install https://github.com/PytorchLightning/pytorch-lightning/archive/0.X.Y.zip --upgrade
 ```

+---
+
 ## Lightning team

 #### Leads
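Most of the flag churn in this patch is the difference between pip's two upgrade behaviours, so a rough sketch may help (semantics as of pip >= 20.1, which the workflow itself installs): a bare `--upgrade` is what finally pulls the base requirements forward to PyTorch 1.6, while `only-if-needed` keeps the rest of the pinned CI environment from being churned.

```bash
# Move every listed requirement (torch included) to the newest release
# that still satisfies requirements/base.txt:
pip install -r requirements/base.txt --upgrade

# Upgrade the listed packages, but touch their dependencies only when the
# installed version no longer satisfies a requirement:
pip install -r requirements/devel.txt --upgrade --upgrade-strategy only-if-needed
```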
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index ccc68cb59b5688..ecbeb821a3bfa1 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -51,42 +51,24 @@ def test_model_tpu_cores_1(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


+@pytest.mark.parametrize('tpu_core', [1, 5])
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
-def test_model_tpu_index_1(tmpdir):
-    """Make sure model trains on TPU."""
-    trainer_options = dict(
-        default_root_dir=tmpdir,
-        progress_bar_refresh_rate=0,
-        max_epochs=1,
-        distributed_backend='tpu',
-        tpu_cores=[1],
-        limit_train_batches=0.4,
-        limit_val_batches=0.4,
-    )
-
-    model = EvalModelTemplate()
-    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:1'
-
-
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-@pl_multi_process_test
-def test_model_tpu_index_5(tmpdir):
+def test_model_tpu_index(tmpdir, tpu_core):
     """Make sure model trains on TPU."""
     trainer_options = dict(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
         distributed_backend='tpu',
-        tpu_cores=[5],
+        tpu_cores=[tpu_core],
         limit_train_batches=0.4,
         limit_val_batches=0.4,
     )

     model = EvalModelTemplate()
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
+    assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'


 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@@ -131,24 +113,27 @@ def test_model_16bit_tpu_cores_1(tmpdir):
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"


+@pytest.mark.parametrize('tpu_core', [1, 5])
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
-def test_model_16bit_tpu_index_1(tmpdir):
+def test_model_16bit_tpu_index(tmpdir, tpu_core):
     """Make sure model trains on TPU."""
     trainer_options = dict(
         default_root_dir=tmpdir,
         precision=16,
         progress_bar_refresh_rate=0,
+        train_percent_check=0.4,
+        val_percent_check=0.2,
         max_epochs=1,
         distributed_backend='tpu',
-        tpu_cores=[1],
+        tpu_cores=[tpu_core],
         limit_train_batches=0.4,
         limit_val_batches=0.4,
     )

     model = EvalModelTemplate()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:1'
+    assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"


@@ -177,27 +162,7 @@ def test_model_16bit_tpu_cores_8(tmpdir):

 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
-def test_model_16bit_tpu_index_5(tmpdir):
-    """Test if distributed TPU core training works"""
-    model = EvalModelTemplate()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        precision=16,
-        max_epochs=1,
-        train_percent_check=0.4,
-        val_percent_check=0.2,
-        distributed_backend='tpu',
-        tpu_cores=[5],
-    )
-    trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
-    assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
-
-
-@pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-@pl_multi_process_test
-def test_train_on_single_tpu(tmpdir, tpu_core):
+def test_model_tpu_early_stop(tmpdir):
     """Test if single TPU core training works"""
     model = EvalModelTemplate()
     trainer = Trainer(
@@ -208,10 +173,9 @@ def test_train_on_single_tpu(tmpdir, tpu_core):
         limit_train_batches=10,
         limit_val_batches=10,
         distributed_backend='tpu',
-        tpu_cores=[tpu_core],
+        tpu_cores=1,
     )
     trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'


 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
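With the duplicated per-core tests folded into a single `@pytest.mark.parametrize('tpu_core', [1, 5])` function, each case can still be selected individually. A small sketch of running them, assuming a TPU host (everywhere else the `TPU_AVAILABLE` guard skips them):

```bash
# Run both parametrized cases; pytest derives the [1] and [5] ids
# from the parametrize values:
python -m pytest tests/models/test_tpu.py -k "test_model_tpu_index" -v

# Or target a single core index via the generated node id:
python -m pytest "tests/models/test_tpu.py::test_model_tpu_index[5]" -v
```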