From 38cf947179d7b2a064d422eeb18879b189110bf5 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Thu, 30 Apr 2020 23:29:36 +0200
Subject: [PATCH 1/8] params

---
 tests/base/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/base/utils.py b/tests/base/utils.py
index 1f0d582ed6e01..fc10d75bf868b 100644
--- a/tests/base/utils.py
+++ b/tests/base/utils.py
@@ -127,7 +127,7 @@ def get_default_model(lbfgs=False):
     hparams = get_default_hparams()
     if lbfgs:
         setattr(hparams, 'optimizer_name', 'lbfgs')
-        setattr(hparams, 'learning_rate', 0.002)
+        setattr(hparams, 'learning_rate', 0.005)
 
     model = LightningTestModel(hparams)
 

From c501e26ebccaf063b43bb906c0acb45b2e0470b4 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Thu, 30 Apr 2020 23:37:13 +0200
Subject: [PATCH 2/8] drop acc

---
 tests/models/test_cpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py
index eb3b28769e206..9c37bf36963fc 100644
--- a/tests/models/test_cpu.py
+++ b/tests/models/test_cpu.py
@@ -81,7 +81,7 @@ def test_lbfgs_cpu_model(tmpdir):
     )
 
     model, hparams = tutils.get_default_model(lbfgs=True)
-    tutils.run_model_test_without_loggers(trainer_options, model, min_acc=0.5)
+    tutils.run_model_test_without_loggers(trainer_options, model, min_acc=0.2)
 
 
 def test_default_logger_callbacks_cpu_model(tmpdir):

From b092416be291bbe703d4ba8e4ba192cb2db37ba5 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 29 Apr 2020 14:31:04 -0700
Subject: [PATCH 3/8] Fix Horovod distributed backend to set the root_gpu

---
 pytorch_lightning/trainer/distrib_parts.py       | 5 +++--
 tests/models/data/horovod/train_default_model.py | 8 +++++++-
 tests/models/test_horovod.py                     | 6 ++++--
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index db4e132c0b445..dd3ffa45df272 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -575,8 +575,9 @@ def horovod_train(self, model):
 
         if torch.cuda.is_available() and self.on_gpu:
             # Horovod: pin GPU to local rank
-            torch.cuda.set_device(hvd.local_rank())
-            model.cuda(hvd.local_rank())
+            self.root_gpu = hvd.local_rank()
+            torch.cuda.set_device(self.root_gpu)
+            model.cuda(self.root_gpu)
 
         # Only show progress bar from the first worker
         self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if hvd.rank() == 0 else 0
diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index 6c11e2ca5e755..dc887029cbf2a 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -27,12 +27,14 @@
 PATH_ROOT = os.path.join(PATH_HERE, '..', '..', '..', '..')
 sys.path.insert(0, os.path.abspath(PATH_ROOT))
 
+from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint  # noqa: E402
 import tests.base.utils as tutils  # noqa: E402
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--trainer-options', required=True)
+parser.add_argument('--on-gpu', action='store_true', default=False)
 
 
 def run_test_from_config(trainer_options):
@@ -44,11 +46,15 @@ def run_test_from_config(trainer_options):
     trainer_options['checkpoint_callback'] = ModelCheckpoint(ckpt_path)
 
     model, hparams = tutils.get_default_model()
-    tutils.run_model_test(trainer_options, model, version=0, with_hpc=False)
+    tutils.run_model_test(trainer_options, model, on_gpu=args.on_gpu, version=0, with_hpc=False)
 
     # Horovod should be initialized following training. If not, this will raise an exception.
     assert hvd.size() == 2
 
+    if args.on_gpu:
+        # Test the root_gpu property
+        assert Trainer(gpus=1, distributed_backend='horovod').root_gpu == hvd.local_rank()
+
 
 if __name__ == "__main__":
     args = parser.parse_args()
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index c4bcb4b81b995..19e212f5243d1 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -38,10 +38,12 @@ def _nccl_available():
         return False
 
 
-def _run_horovod(trainer_options):
+def _run_horovod(trainer_options, on_gpu=False):
    """Execute the training script across multiple workers in parallel."""
     cmdline = ['horovodrun', '-np', '2', sys.executable, TEST_SCRIPT,
                '--trainer-options', shlex.quote(json.dumps(trainer_options))]
+    if on_gpu:
+        cmdline += ['--on-gpu']
     exit_code = subprocess.call(' '.join(cmdline), shell=True, env=os.environ.copy())
     assert exit_code == 0
 
@@ -93,7 +95,7 @@ def test_horovod_multi_gpu(tmpdir):
         gpus=1,
         distributed_backend='horovod'
     )
-    _run_horovod(trainer_options)
+    _run_horovod(trainer_options, on_gpu=True)
 
 
 @pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8")
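
For context on the pattern patch 3 encodes: Horovod runs one training process per GPU,
and each process learns which device it owns from `hvd.local_rank()`. Below is a minimal
standalone sketch of that pinning convention, outside Lightning; it assumes `horovod` and
`torch` are installed, and the linear model is only a stand-in for a real network.

    import horovod.torch as hvd
    import torch

    hvd.init()  # must be called before any rank/size queries

    # One process per GPU: local_rank() indexes this process's GPU on its host,
    # which is exactly the value the Trainer records as root_gpu.
    root_gpu = hvd.local_rank()

    model = torch.nn.Linear(10, 1)  # stand-in for the real model
    if torch.cuda.is_available():
        torch.cuda.set_device(root_gpu)  # pin all CUDA work to the owned device
        model.cuda(root_gpu)

    # Only the first worker should emit progress output.
    if hvd.rank() == 0:
        print('world size:', hvd.size(), 'root_gpu:', root_gpu)
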
From 6deaacdf97729d5f9d815c7505391abeaf4188d2 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 29 Apr 2020 14:40:55 -0700
Subject: [PATCH 4/8] Fixed test

---
 tests/models/data/horovod/train_default_model.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index dc887029cbf2a..768e31eb8e9a5 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -53,7 +53,9 @@ def run_test_from_config(trainer_options):
 
     if args.on_gpu:
         # Test the root_gpu property
-        assert Trainer(gpus=1, distributed_backend='horovod').root_gpu == hvd.local_rank()
+        trainer = Trainer(gpus=1, distributed_backend='horovod')
+        trainer.fit(model)
+        assert trainer.root_gpu == hvd.local_rank()
 
 
 if __name__ == "__main__":

From f4976326440c7fdb3a47183f5d0f19cf5ac9d3b2 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 29 Apr 2020 14:49:43 -0700
Subject: [PATCH 5/8] Fixed tests

---
 tests/models/data/horovod/train_default_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index 768e31eb8e9a5..83103cd062290 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -53,7 +53,8 @@ def run_test_from_config(trainer_options):
 
     if args.on_gpu:
         # Test the root_gpu property
-        trainer = Trainer(gpus=1, distributed_backend='horovod')
+        model, hparams = tutils.get_default_model()
+        trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)
         trainer.fit(model)
         assert trainer.root_gpu == hvd.local_rank()
 
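
The test plumbing that patches 3-5 iterate on depends on the parent pytest process
launching two copies of the training script. Below is a rough sketch of that launch,
mirroring `_run_horovod` from tests/models/test_horovod.py; the trainer options shown
are illustrative values, not the ones used by any particular test.

    import json
    import os
    import shlex
    import subprocess
    import sys

    TEST_SCRIPT = 'tests/models/data/horovod/train_default_model.py'
    trainer_options = {'max_epochs': 1, 'gpus': 1}  # illustrative values

    # horovodrun spawns 2 workers; each one re-executes the script under test.
    cmdline = ['horovodrun', '-np', '2', sys.executable, TEST_SCRIPT,
               '--trainer-options', shlex.quote(json.dumps(trainer_options)),
               '--on-gpu']

    # shell=True so the shlex-quoted JSON survives the shell; a non-zero exit
    # code from any worker fails the test.
    exit_code = subprocess.call(' '.join(cmdline), shell=True, env=os.environ.copy())
    assert exit_code == 0
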
From dcfd7d7459da20ea6bc74b59cd9ff70819948e06 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 29 Apr 2020 16:28:07 -0700
Subject: [PATCH 6/8] Fixed lint

---
 tests/models/data/horovod/train_default_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index 83103cd062290..6d1f152f7a560 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -27,7 +27,7 @@
 PATH_ROOT = os.path.join(PATH_HERE, '..', '..', '..', '..')
 sys.path.insert(0, os.path.abspath(PATH_ROOT))
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer  # noqa: E402
 from pytorch_lightning.callbacks import ModelCheckpoint  # noqa: E402
 import tests.base.utils as tutils  # noqa: E402
 

From f24f0d5127c1d5f0651589951c1ee4fbdc39ef0d Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Thu, 30 Apr 2020 13:11:17 -0700
Subject: [PATCH 7/8] Set root_gpu during initialization

---
 .../trainer/distrib_data_parallel.py             | 16 ++++++++++++----
 pytorch_lightning/trainer/distrib_parts.py       |  5 +----
 tests/models/data/horovod/train_default_model.py |  5 +----
 tests/models/test_horovod.py                     |  1 +
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 56c7bae8ec6a7..8651dd5c1b5a0 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -194,8 +194,7 @@ def set_distributed_mode(self, distributed_backend):
 
         if distributed_backend is None:
             if self.has_horovodrun():
-                self.check_horovod()
-                self.use_horovod = True
+                self._set_horovod_backend()
             elif self.num_gpus == 0:
                 if self.num_nodes > 1 or self.num_processes > 1:
                     self.use_ddp = True  # ddp_cpu
@@ -235,8 +234,7 @@ def set_distributed_mode(self, distributed_backend):
                 self.data_parallel_device_ids = None
                 self.on_gpu = False
             elif distributed_backend == 'horovod':
-                self.check_horovod()
-                self.use_horovod = True
+                self._set_horovod_backend()
 
         # throw error to force user ddp or ddp2 choice
         if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
@@ -421,6 +419,16 @@ def resolve_root_node_address(self, root_node):
 
         return root_node
 
+    def _set_horovod_backend(self):
+        self.check_horovod()
+        self.use_horovod = True
+
+        # Initialize Horovod to get rank / size info
+        hvd.init()
+        if self.on_gpu:
+            # Horovod assigns one local GPU per process
+            self.root_gpu = hvd.local_rank()
+
     def check_horovod(self):
         """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
         if not HOROVOD_AVAILABLE:
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index dd3ffa45df272..a9f4b6114522e 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -570,12 +570,9 @@ def dp_train(self, model):
         model.forward = model_autocast_original_forward
 
     def horovod_train(self, model):
-        # Horovod: initialize library
-        hvd.init()
-
         if torch.cuda.is_available() and self.on_gpu:
             # Horovod: pin GPU to local rank
-            self.root_gpu = hvd.local_rank()
+            assert self.root_gpu == hvd.local_rank()
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)
 
diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index 6d1f152f7a560..3410cdc1d5051 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -53,10 +53,7 @@ def run_test_from_config(trainer_options):
 
     if args.on_gpu:
         # Test the root_gpu property
-        model, hparams = tutils.get_default_model()
-        trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)
-        trainer.fit(model)
-        assert trainer.root_gpu == hvd.local_rank()
+        assert Trainer(gpus=1, distributed_backend='horovod', max_epochs=1).root_gpu == hvd.local_rank()
 
 
 if __name__ == "__main__":
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 19e212f5243d1..21a90c191579b 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -161,5 +161,6 @@ def get_model_params(model):
     def get_optimizer_params(optimizer):
         return set([p for group in optimizer.param_groups for p in group.get('params', [])])
 
+    assert get_model_params(model.generator) != get_model_params(model.discriminator)
     assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0])
     assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])
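
The point of patch 7 is ordering: root_gpu must be correct as soon as the Trainer is
constructed, not only once horovod_train() runs, so hvd.init() moves into backend
selection. Under horovodrun, the property can then be checked without ever calling
fit(), which is what the simplified test asserts. A minimal sketch of that check,
assuming the patched pytorch-lightning from this series is installed and the snippet
is run via horovodrun on a GPU machine:

    import horovod.torch as hvd
    from pytorch_lightning import Trainer

    # With patch 7 applied, constructing the Trainer initializes Horovod and
    # pins root_gpu to this worker's local rank; no training is required.
    trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)
    assert trainer.root_gpu == hvd.local_rank()
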
From e8be340be0ad44f45d983b6f5cfeccac21824dfe Mon Sep 17 00:00:00 2001
From: Jirka
Date: Thu, 30 Apr 2020 22:34:59 +0200
Subject: [PATCH 8/8] chlog

---
 CHANGELOG.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f67e85a452ff9..94675a20111c0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,10 +18,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-- Fixed broken link in PR template ([#1675](https://github.com/PyTorchLightning/pytorch-lightning/pull/1675))
 - Fixed ModelCheckpoint not None checking filepath ([1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654))
+
 - Trainer now calls `on_load_checkpoint()` when resuming from a checkpoint ([1666](https://github.com/PyTorchLightning/pytorch-lightning/pull/1666))
 
+- Fixed Horovod distributed backend to set the `root_gpu` property ([#1669](https://github.com/PyTorchLightning/pytorch-lightning/pull/1669))
+
 ## [0.7.5] - 2020-04-27