[ddp] Support multi-node distributed execution under torchelastic (#1811)

The changes are local and limited in scope: we check a few indicator environment
variables to determine the node rank. We check (SLURM_NODEID, NODE_RANK, GROUP_RANK)
in that order. If more than one is set, a warning is logged.
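
For orientation, here is a simplified sketch of that lookup order (illustration only; the actual implementation is `determine_ddp_node_rank` in the diff below, which also warns when none or several of these variables are set):

import os

def resolve_node_rank(is_slurm_managing_tasks: bool) -> int:
    """Simplified mirror of the lookup order described above."""
    if is_slurm_managing_tasks:
        return int(os.environ['SLURM_NODEID'])      # set per node by SLURM
    for var in ('NODE_RANK', 'GROUP_RANK'):         # GROUP_RANK is set by torchelastic
        if var in os.environ:
            return int(os.environ[var])             # take the first one that is set
    return 0                                        # default to node rank 0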

This patch also fixes a minor bug in the comparison against the `WORLD_SIZE` environment
variable: the value read from the environment is a string and must be cast to an int
before it is compared with the computed world size.
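
A minimal illustration of that pitfall (not part of the patch):

import os

os.environ['WORLD_SIZE'] = '2'   # environment variables are always strings
world_size = 2

print(os.environ['WORLD_SIZE'] != world_size)        # True: '2' != 2 would trigger a spurious warning
print(int(os.environ['WORLD_SIZE']) != world_size)   # False: the comparison the patch switches to
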
ashwinb committed May 13, 2020
1 parent b1d9656 commit aefc531
Showing 4 changed files with 30 additions and 14 deletions.
CHANGELOG.md (2 additions, 0 deletions)
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added support for Pytorch elastic distributed launch environment ([#1811](https://github.com/PyTorchLightning/pytorch-lightning/pull/1811))
+
 - Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))

 - Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564))
pytorch_lightning/core/lightning.py (4 additions, 3 deletions)
@@ -944,11 +944,12 @@ def init_ddp_connection(
             os.environ['MASTER_PORT'] = '12910'
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

-        if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size:
-            log.warning("WORLD_SIZE environment variable is not equal to the computed "
-                        "world size. Ignored.")
+        if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
+            log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
+                        f"is not equal to the computed world size ({world_size}). Ignored.")

         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
+        log.info(f"initializing proc_rank {proc_rank} world {world_size}")
         torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)

     def configure_apex(
pytorch_lightning/trainer/distrib_data_parallel.py (19 additions, 9 deletions)
@@ -279,6 +279,25 @@ def configure_slurm_ddp(self, num_gpu_nodes):
         if self.is_slurm_managing_tasks:
             log.info('Multi-processing is handled by Slurm.')

+    def determine_ddp_node_rank(self):
+        if self.is_slurm_managing_tasks:
+            return int(os.environ['SLURM_NODEID'])
+
+        # torchelastic uses the envvar GROUP_RANK, whereas other systems(?) use NODE_RANK.
+        # otherwise use given node rank or default to node rank 0
+        env_vars = ['NODE_RANK', 'GROUP_RANK']
+        node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
+        node_ids = [(k, v) for k, v in node_ids if v is not None]
+        if len(node_ids) == 0:
+            log.warning("No environment variable for node rank defined. Set as 0.")
+            return 0
+        if len(node_ids) > 1:
+            log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
+                        f"Using the first one.")
+        k, rank = node_ids.pop()
+        log.info(f"Using environment variable {k} for node rank ({rank}).")
+        return int(rank)
+
     def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         if data_parallel_device_ids is None:
             return
@@ -305,15 +324,6 @@ def ddp_train(self, process_idx, model):
         :param cluster_obj:
         :return:
         """
-        # node rank using relative slurm id if under slurm management
-        # otherwise use given node rank or default to node rank 0
-        try:
-            node_id = os.environ['SLURM_NODEID'] if self.is_slurm_managing_tasks else os.environ['NODE_RANK']
-            self.node_rank = int(node_id)
-        except KeyError:
-            log.warning("SLURM_NODEID or NODE_RANK environment variable is not defined. Set as 0.")
-            self.node_rank = 0
-
         # show progressbar only on progress_rank 0
         if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
             self.progress_bar_callback.disable()
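
A quick way to sanity-check the new fallback (a sketch, assuming a build that includes this patch; the Trainer defaults do not matter here):

import os
from pytorch_lightning import Trainer

os.environ.pop('NODE_RANK', None)    # ensure only one indicator variable is present
os.environ['GROUP_RANK'] = '3'       # torchelastic-style group (node) rank

trainer = Trainer()
print(trainer.node_rank)             # 3, resolved from GROUP_RANK

os.environ['NODE_RANK'] = '1'        # now two indicator variables are set
print(Trainer().determine_ddp_node_rank())   # logs the warning about multiple variables
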
pytorch_lightning/trainer/trainer.py (5 additions, 2 deletions)
@@ -483,8 +483,8 @@ def __init__(
         # init flags for SLURM+ddp to work
         self.proc_rank = 0
         self.world_size = 1
-        self.node_rank = 0
         self.configure_slurm_ddp(self.num_nodes)
+        self.node_rank = self.determine_ddp_node_rank()

         # nvidia setup
         self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
@@ -796,11 +796,14 @@ def fit(
         if self.use_ddp2:
             task = int(os.environ['SLURM_LOCALID'])
             self.ddp_train(task, model)
-
         elif self.use_ddp:
             if self.is_slurm_managing_tasks:
                 task = int(os.environ['SLURM_LOCALID'])
                 self.ddp_train(task, model)
+            # torchelastic
+            elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
+                task = int(os.environ['LOCAL_RANK'])
+                self.ddp_train(task, model)
             else:
                 self.__set_random_port()
                 # track for predict
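
Taken together, the ddp branch of fit() now resolves the per-process index as follows (an illustrative summary, not code from the patch):

import os

def describe_ddp_launch(is_slurm_managing_tasks: bool) -> str:
    """Summarise where the per-process index comes from in the ddp branch above."""
    if is_slurm_managing_tasks:
        return "SLURM: process index from SLURM_LOCALID"
    if 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
        return "torchelastic: process index from LOCAL_RANK"
    return "no scheduler detected: local fallback (random MASTER_PORT via __set_random_port)"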
