Fix Horovod distributed backend to set the root_gpu property #1669

Merged 9 commits on May 1, 2020

CHANGELOG.md (4 changes: 3 additions & 1 deletion)

@@ -18,10 +18,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed broken link in PR template ([#1675](https://github.com/PyTorchLightning/pytorch-lightning/pull/1675))
- Fixed `ModelCheckpoint` not checking `filepath` for `None` ([#1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654))

- Trainer now calls `on_load_checkpoint()` when resuming from a checkpoint ([#1666](https://github.com/PyTorchLightning/pytorch-lightning/pull/1666))

+- Fixed Horovod distributed backend to set the `root_gpu` property ([#1669](https://github.com/PyTorchLightning/pytorch-lightning/pull/1669))


## [0.7.5] - 2020-04-27


pytorch_lightning/trainer/distrib_data_parallel.py (16 changes: 12 additions & 4 deletions)

@@ -194,8 +194,7 @@ def set_distributed_mode(self, distributed_backend):

        if distributed_backend is None:
            if self.has_horovodrun():
-               self.check_horovod()
-               self.use_horovod = True
+               self._set_horovod_backend()
            elif self.num_gpus == 0:
                if self.num_nodes > 1 or self.num_processes > 1:
                    self.use_ddp = True  # ddp_cpu
@@ -235,8 +234,7 @@ def set_distributed_mode(self, distributed_backend):
                self.data_parallel_device_ids = None
                self.on_gpu = False
            elif distributed_backend == 'horovod':
-               self.check_horovod()
-               self.use_horovod = True
+               self._set_horovod_backend()

        # throw error to force user ddp or ddp2 choice
        if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
@@ -421,6 +419,16 @@ def resolve_root_node_address(self, root_node):

        return root_node

+   def _set_horovod_backend(self):
+       self.check_horovod()
+       self.use_horovod = True
+
+       # Initialize Horovod to get rank / size info
+       hvd.init()
+       if self.on_gpu:
+           # Horovod assigns one local GPU per process
+           self.root_gpu = hvd.local_rank()
+
    def check_horovod(self):
        """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
        if not HOROVOD_AVAILABLE:
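
For context, a minimal standalone sketch (not part of this PR; it assumes `horovod.torch` is installed and the script is launched with `horovodrun`) of the Horovod calls that the new `_set_horovod_backend()` relies on:

    import horovod.torch as hvd

    hvd.init()                     # must run before any rank / size queries
    world_size = hvd.size()        # total number of workers
    global_rank = hvd.rank()       # rank of this worker across all nodes
    root_gpu = hvd.local_rank()    # one local GPU per process; stored as the Trainer's root_gpu on GPU runs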

pytorch_lightning/trainer/distrib_parts.py (8 changes: 3 additions & 5 deletions)

@@ -570,13 +570,11 @@ def dp_train(self, model):
        model.forward = model_autocast_original_forward

    def horovod_train(self, model):
-       # Horovod: initialize library
-       hvd.init()
-
        if torch.cuda.is_available() and self.on_gpu:
            # Horovod: pin GPU to local rank
-           torch.cuda.set_device(hvd.local_rank())
-           model.cuda(hvd.local_rank())
+           assert self.root_gpu == hvd.local_rank()
+           torch.cuda.set_device(self.root_gpu)
+           model.cuda(self.root_gpu)

        # Only show progress bar from the first worker
        self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if hvd.rank() == 0 else 0
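
As an illustration only (assuming a CUDA-capable machine with `horovod.torch` installed), the pin-then-move pattern the updated `horovod_train()` follows looks roughly like this outside of Lightning:

    import horovod.torch as hvd
    import torch

    hvd.init()
    root_gpu = hvd.local_rank()                      # the GPU this worker should own
    torch.cuda.set_device(root_gpu)                  # pin first, so later .cuda() calls default to this device
    model = torch.nn.Linear(4, 2).cuda(root_gpu)     # then move the model onto the pinned device
    assert torch.cuda.current_device() == root_gpu   # the process is now bound to its assigned GPU

Pinning before moving the model mirrors the assertion added above: `root_gpu` must already equal `hvd.local_rank()` by the time training starts.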

tests/models/data/horovod/train_default_model.py (8 changes: 7 additions & 1 deletion)

@@ -27,12 +27,14 @@
PATH_ROOT = os.path.join(PATH_HERE, '..', '..', '..', '..')
sys.path.insert(0, os.path.abspath(PATH_ROOT))

+from pytorch_lightning import Trainer  # noqa: E402
from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402
import tests.base.utils as tutils # noqa: E402


parser = argparse.ArgumentParser()
parser.add_argument('--trainer-options', required=True)
+parser.add_argument('--on-gpu', action='store_true', default=False)


def run_test_from_config(trainer_options):
@@ -44,11 +46,15 @@ def run_test_from_config(trainer_options):
    trainer_options['checkpoint_callback'] = ModelCheckpoint(ckpt_path)

    model, hparams = tutils.get_default_model()
-   tutils.run_model_test(trainer_options, model, version=0, with_hpc=False)
+   tutils.run_model_test(trainer_options, model, on_gpu=args.on_gpu, version=0, with_hpc=False)

    # Horovod should be initialized following training. If not, this will raise an exception.
    assert hvd.size() == 2

+   if args.on_gpu:
+       # Test the root_gpu property
+       assert Trainer(gpus=1, distributed_backend='horovod', max_epochs=1).root_gpu == hvd.local_rank()


if __name__ == "__main__":
    args = parser.parse_args()

tests/models/test_horovod.py (7 changes: 5 additions & 2 deletions)

@@ -38,10 +38,12 @@ def _nccl_available():
    return False


-def _run_horovod(trainer_options):
+def _run_horovod(trainer_options, on_gpu=False):
    """Execute the training script across multiple workers in parallel."""
    cmdline = ['horovodrun', '-np', '2', sys.executable, TEST_SCRIPT,
               '--trainer-options', shlex.quote(json.dumps(trainer_options))]
+   if on_gpu:
+       cmdline += ['--on-gpu']
    exit_code = subprocess.call(' '.join(cmdline), shell=True, env=os.environ.copy())
    assert exit_code == 0

@@ -93,7 +95,7 @@ def test_horovod_multi_gpu(tmpdir):
        gpus=1,
        distributed_backend='horovod'
    )
-   _run_horovod(trainer_options)
+   _run_horovod(trainer_options, on_gpu=True)


@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8")
@@ -159,5 +161,6 @@ def get_model_params(model):
    def get_optimizer_params(optimizer):
        return set([p for group in optimizer.param_groups for p in group.get('params', [])])

+   assert get_model_params(model.generator) != get_model_params(model.discriminator)
    assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0])
    assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])
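
For reference, when `on_gpu=True` the command assembled by `_run_horovod()` resolves to roughly the following (the trainer options shown here are illustrative, not copied from this PR):

    horovodrun -np 2 python tests/models/data/horovod/train_default_model.py \
        --trainer-options '{"max_epochs": 1, "gpus": 1, "distributed_backend": "horovod"}' \
        --on-gpu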