Add ddp_cpu backend for testing ddp without GPUs (#1158)
* Add tests for distributed backend config

* Refactor set_distributed_mode

* Use gloo backend on cpu

* Use 127.0.0.1 instead of 127.0.0.2

Not totally clear on why this is necessary, but it seems to work

* Update LightningDDP so that it works with CPU

* Add ddp_cpu backend and num_processes Trainer arg

* PEP8

* Fix test skipping. Inequalities are hard :/

* Skip ddp_cpu test on Windows

* Make a few more cases fall back to ddp_cpu

* New function name

* Flake8

* Don't test distributed on MacOS with torch < 1.3

Support for distributed training on macOS was added in Torch 1.3.0

* Add ddp_cpu and num_processes to docs

* Parametrize trainer config tests

* Tweak warning

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Remove redundant test

* Replace pass branches with comments

* Add missing warnings import

* save_path -> root_dir

* Use new rank_zero_warn

* Whitespace

* Apply suggestions from code review

* formatting

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
3 people authored Apr 16, 2020
1 parent 3431c62 commit e3001a0
Showing 7 changed files with 209 additions and 54 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -944,11 +944,12 @@ def init_ddp_connection(self):
try:
root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
except Exception:
root_node = '127.0.0.2'
root_node = '127.0.0.1'

root_node = self.trainer.resolve_root_node_address(root_node)
os.environ['MASTER_ADDR'] = root_node
torch_distrib.init_process_group('nccl', rank=proc_rank, world_size=world_size)
torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)

def configure_apex(
self,
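The hunk above replaces the hard-coded `nccl` backend with a choice between `nccl` (GPU) and `gloo` (CPU), driven by `self.trainer.on_gpu`. As a standalone sketch of the same pattern with plain `torch.distributed` (illustrative only, not part of the commit; the port value is just an arbitrary free port):

```python
import os

import torch.distributed as torch_distrib


def init_connection(proc_rank: int, world_size: int, on_gpu: bool):
    # Rendezvous address: SLURM_NODELIST normally supplies this; 127.0.0.1
    # is the single-node fallback used in the diff above.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '12910')
    # nccl requires GPUs; gloo is the CPU-capable backend used by ddp_cpu
    backend = 'nccl' if on_gpu else 'gloo'
    torch_distrib.init_process_group(backend, rank=proc_rank, world_size=world_size)
```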
9 changes: 8 additions & 1 deletion pytorch_lightning/overrides/data_parallel.py
@@ -99,7 +99,14 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover
output = self.gather(outputs, self.output_device)
else:
# normal
output = self.module(*inputs, **kwargs)
# output = self.module(*inputs, **kwargs)
# lightning (ddp_cpu)
if self.module.training:
output = self.module.training_step(*inputs, **kwargs)
elif self.module.testing:
output = self.module.test_step(*inputs, **kwargs)
else:
output = self.module.validation_step(*inputs, **kwargs)

if torch.is_grad_enabled():
# We'll return the output object verbatim since it is a freeform
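When no GPUs are attached there is nothing to scatter, so the override calls the LightningModule's step methods directly instead of `forward`. A minimal sketch of that dispatch, assuming a module that exposes the `training` and `testing` flags used above:

```python
def run_step(module, *inputs, **kwargs):
    # Route the call to the step matching the module's current mode,
    # mirroring the ddp_cpu branch added in the diff above.
    if module.training:
        return module.training_step(*inputs, **kwargs)
    if getattr(module, 'testing', False):
        return module.test_step(*inputs, **kwargs)
    return module.validation_step(*inputs, **kwargs)
```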
19 changes: 19 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -251,6 +251,10 @@ def on_train_end(self):
- (```dp```) is DataParallel (split batch among GPUs of same machine)
- (```ddp```) is DistributedDataParallel (each gpu on each node trains, and syncs grads)
- (```ddp_cpu```) is DistributedDataParallel on CPU (same as `ddp`, but does not use GPUs.
Useful for multi-node CPU training or single-node debugging. Note that this will **not** give
a speedup on a single node, since Torch already makes efficient use of multiple CPUs on a single
machine.)
- (```ddp2```) dp on node, ddp across nodes. Useful for things like increasing
the number of negative samples
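As a quick illustration of the backends listed above (a sketch, not taken from the diff; the keyword arguments are the Trainer options documented in this file):

```python
from pytorch_lightning import Trainer

Trainer(distributed_backend='dp', gpus=2)                 # split each batch across 2 GPUs on one node
Trainer(distributed_backend='ddp', gpus=2, num_nodes=2)   # one process per GPU, gradients synced across nodes
Trainer(distributed_backend='ddp_cpu', num_processes=2)   # ddp semantics, CPU-only processes
Trainer(distributed_backend='ddp2', gpus=2, num_nodes=2)  # dp within a node, ddp across nodes
```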
@@ -510,6 +514,21 @@ def on_train_end(self):
Use `num_nodes` instead. Will remove 0.8.0.
num_processes
^^^^^^^^^^^^^
Number of processes to train with. Automatically set to the number of GPUs
when using ``distributed_backend="ddp"``. Set to a number greater than 1 when
using ``distributed_backend="ddp_cpu"`` to mimic distributed training on a
machine without GPUs. This is useful for debugging, but **will not** provide
any speedup, since single-process Torch already makes efficient use of multiple
CPUs.
Example::
# Simulate DDP for debugging on your GPU-less laptop
trainer = Trainer(distributed_backend="ddp_cpu", num_processes=2)
num_sanity_val_steps
^^^^^^^^^^^^^^^^^^^^
100 changes: 55 additions & 45 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -174,44 +174,54 @@ def init_tpu(self):
# enable tpu
self.use_tpu = True

def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
# skip for CPU
if self.num_gpus == 0:
return

# single GPU case
# in single gpu case we allow ddp so we can train on multiple
# nodes, 1 gpu per node
if self.num_gpus == 1:
self.single_gpu = True

if distributed_backend is not None:
self.use_dp = distributed_backend == 'dp'
self.use_ddp = distributed_backend == 'ddp'
self.use_ddp2 = distributed_backend == 'ddp2'

# disable single gpu when using ddp2
if self.use_ddp2:
self.single_gpu = False

# multiple GPU case
elif self.num_gpus > 1:
if distributed_backend is not None:
# DP, DDP case
self.use_dp = distributed_backend == 'dp'
self.use_ddp = distributed_backend == 'ddp'
self.use_ddp2 = distributed_backend == 'ddp2'

elif distributed_backend is None:
def set_distributed_mode(self, distributed_backend):
self.use_dp = False
self.use_ddp = False
self.use_ddp2 = False
self.single_gpu = False

if distributed_backend is None:
if self.num_gpus == 0:
if self.num_nodes > 1 or self.num_processes > 1:
self.use_ddp = True # ddp_cpu
elif self.num_gpus == 1:
self.single_gpu = True
elif self.num_gpus > 1:
rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.'
' Trainer(distributed_backend=dp) (or ddp, ddp2).'
' Setting distributed_backend=dp for you.')
self.use_dp = True
self.use_ddp = False
self.use_ddp2 = False
elif distributed_backend == "dp":
# do nothing if num_gpus == 0
if self.num_gpus == 1:
self.single_gpu = True
self.use_dp = True
elif self.num_gpus > 1:
self.use_dp = True
elif distributed_backend == "ddp":
if self.num_gpus == 0:
if self.num_nodes > 1 or self.num_processes > 1:
self.use_ddp = True # ddp_cpu
elif self.num_gpus == 1:
self.single_gpu = True
self.use_ddp = True
elif self.num_gpus > 1:
self.use_ddp = True
self.num_processes = self.num_gpus
elif distributed_backend == "ddp2":
# do nothing if num_gpus == 0
if self.num_gpus >= 1:
self.use_ddp2 = True
elif distributed_backend == "ddp_cpu":
if self.num_gpus > 0:
rank_zero_warn('You requested one or more GPUs, but set the backend to `ddp_cpu`.'
' Training will not use GPUs.')
self.use_ddp = True
self.data_parallel_device_ids = None
self.on_gpu = False

# throw error to force user ddp or ddp2 choice
if num_gpu_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
raise MisconfigurationException(
'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
'To silence this warning set distributed_backend=ddp or distributed_backend=ddp2'
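To summarize the refactored decision logic, here is a hedged sketch of the flags it resolves for two CPU-only configurations (the expected values mirror the parametrized cases added to tests/trainer/test_trainer.py below):

```python
from pytorch_lightning import Trainer

trainer = Trainer(distributed_backend='ddp_cpu', num_processes=2)
assert trainer.use_ddp and not trainer.on_gpu and trainer.num_processes == 2

trainer = Trainer(distributed_backend='ddp', num_processes=2, gpus=None)
assert trainer.use_ddp and trainer.num_gpus == 0  # falls back to the ddp_cpu path
```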
@@ -267,7 +277,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):

log.info(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}')

def ddp_train(self, gpu_idx, model):
def ddp_train(self, process_idx, model):
"""
Entry point into a DP thread
:param gpu_idx:
@@ -284,16 +294,18 @@ def ddp_train(self, gpu_idx, model):
self.node_rank = 0

# show progressbar only on progress_rank 0
self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if self.node_rank == 0 and gpu_idx == 0 else 0
self.progress_bar_refresh_rate = (
self.progress_bar_refresh_rate if self.node_rank == 0 and process_idx == 0 else 0
)

# determine which process we are and world size
if self.use_ddp:
self.proc_rank = self.node_rank * self.num_gpus + gpu_idx
self.world_size = self.num_gpu_nodes * self.num_gpus
self.proc_rank = self.node_rank * self.num_processes + process_idx
self.world_size = self.num_nodes * self.num_processes

elif self.use_ddp2:
self.proc_rank = self.node_rank
self.world_size = self.num_gpu_nodes
self.world_size = self.num_nodes
# set warning rank
set_proc_rank(self.proc_rank)

@@ -313,16 +325,14 @@ def ddp_train(self, gpu_idx, model):

# MODEL
# copy model to each gpu
if self.distributed_backend == 'ddp':
torch.cuda.set_device(gpu_idx)
model.cuda(gpu_idx)
if self.on_gpu:
self.root_gpu = self.data_parallel_device_ids[process_idx]
torch.cuda.set_device(self.root_gpu)
model.cuda(self.root_gpu)

# set model properties before going into wrapper
self.copy_trainer_model_properties(model)

# override root GPU
self.root_gpu = gpu_idx

# AMP
# run through amp wrapper before going to distributed DP
if self.use_amp:
Expand All @@ -332,10 +342,10 @@ def ddp_train(self, gpu_idx, model):

# DDP2 uses all GPUs on the machine
if self.distributed_backend == 'ddp':
device_ids = [gpu_idx]
device_ids = [self.root_gpu]
elif self.use_ddp2:
device_ids = self.data_parallel_device_ids
else:
else: # includes ddp_cpu
device_ids = None

# allow user to configure ddp
13 changes: 7 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -92,6 +92,7 @@ def __init__(
gradient_clip_val: float = 0,
process_position: int = 0,
num_nodes: int = 1,
num_processes: int = 1,
gpus: Optional[Union[List[int], str, int]] = None,
auto_select_gpus: bool = False,
num_tpu_cores: Optional[int] = None,
@@ -321,6 +322,10 @@ def __init__(
self.num_tpu_cores = num_tpu_cores
assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1 or 8'

if num_processes != 1 and distributed_backend != "ddp_cpu":
rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
self.num_processes = num_processes

self.process_position = process_position
self.weights_summary = weights_summary
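A hedged illustration of the guard added a few lines up (sketch only; the warning text is the one shown in the diff):

```python
from pytorch_lightning import Trainer

Trainer(distributed_backend='dp', num_processes=4)       # warns: num_processes is only used for ddp_cpu
Trainer(distributed_backend='ddp_cpu', num_processes=4)  # fit() will spawn 4 CPU processes
```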

@@ -441,12 +446,8 @@ def __init__(
self.tpu_global_core_rank = None

# distributed backend choice
self.use_ddp = False
self.use_ddp2 = False
self.use_dp = False
self.single_gpu = False
self.distributed_backend = distributed_backend
self.set_distributed_mode(distributed_backend, self.num_nodes)
self.set_distributed_mode(distributed_backend)

# override dist backend when using tpus
if self.on_tpu:
@@ -732,7 +733,7 @@ def fit(
self.model = model

# train
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

# load weights if not interrupted
self.load_spawn_weights(model)
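The fit() change above spawns `num_processes` workers rather than one per GPU. For readers new to the pattern, a self-contained sketch of spawn plus gloo on a CPU-only machine (illustrative only; the worker below is not Lightning code, and the port is arbitrary):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(process_idx: int, world_size: int):
    # Each spawned process joins the same gloo process group.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=process_idx, world_size=world_size)
    t = torch.ones(1) * process_idx
    dist.all_reduce(t)  # default op is SUM: with 2 ranks, every process ends up with 1.0
    print(f'rank {process_idx}: {t.item()}')
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    # mp.spawn passes the process index as the first argument, just as ddp_train expects
    mp.spawn(worker, nprocs=world_size, args=(world_size,))
```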
27 changes: 27 additions & 0 deletions tests/models/test_cpu.py
@@ -1,7 +1,9 @@
import platform
import warnings

import pytest
import torch
from packaging.version import parse as version_parse

import tests.base.utils as tutils
from pytorch_lightning import Trainer
@@ -40,6 +42,31 @@ def test_early_stopping_cpu_model(tmpdir):
model.unfreeze()


@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif((platform.system() == "Darwin" and
version_parse(torch.__version__) < version_parse("1.3.0")),
reason="Distributed training is not supported on MacOS before Torch 1.3.0")
def test_multi_cpu_model_ddp(tmpdir):
"""Make sure DDP works."""
tutils.reset_seed()
tutils.set_random_master_port()

model, hparams = tutils.get_default_model()
trainer_options = dict(
default_root_dir=tmpdir,
show_progress_bar=False,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
gpus=None,
num_processes=2,
distributed_backend='ddp_cpu'
)

tutils.run_model_test(trainer_options, model, on_gpu=False)


def test_lbfgs_cpu_model(tmpdir):
"""Test each of the trainer options."""
tutils.reset_seed()
90 changes: 90 additions & 0 deletions tests/trainer/test_trainer.py
@@ -702,3 +702,93 @@ def test_gpu_choice(tmpdir):

with pytest.raises(RuntimeError, match=r'.*No GPUs available.*'):
Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True)


@pytest.mark.parametrize("trainer_kwargs,expected", [
pytest.param(
dict(distributed_backend=None, gpus=None),
dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1)
),
pytest.param(
dict(distributed_backend="dp", gpus=None),
dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1)
),
pytest.param(
dict(distributed_backend="dp", gpus=None),
dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1)
),
pytest.param(
dict(distributed_backend="ddp", gpus=None),
dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1)
),
pytest.param(
dict(distributed_backend="ddp", num_processes=2, gpus=None),
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=2)
),
pytest.param(
dict(distributed_backend="ddp", num_nodes=2, gpus=None),
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1)
),
pytest.param(
dict(distributed_backend="ddp_cpu", num_processes=2, gpus=None),
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=2)
),
pytest.param(
dict(distributed_backend="ddp2", gpus=None),
dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1)
),
pytest.param(
dict(distributed_backend=None, gpus=1),
dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=1, on_gpu=True, single_gpu=True, num_processes=1),
marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")]
),
pytest.param(
dict(distributed_backend="dp", gpus=1),
dict(use_dp=True, use_ddp=False, use_ddp2=False, num_gpus=1, on_gpu=True, single_gpu=True, num_processes=1),
marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")]
),
pytest.param(
dict(distributed_backend="ddp", gpus=1),
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=1, on_gpu=True, single_gpu=True, num_processes=1),
marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")]
),
pytest.param(
dict(distributed_backend="ddp_cpu", num_processes=2, gpus=1),
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=2),
marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")]
),
pytest.param(
dict(distributed_backend="ddp2", gpus=1),
dict(use_dp=False, use_ddp=False, use_ddp2=True, num_gpus=1, on_gpu=True, single_gpu=False, num_processes=1),
marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")]
),
pytest.param(
dict(distributed_backend=None, gpus=2),
dict(use_dp=True, use_ddp=False, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1),
marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")]
),
pytest.param(
dict(distributed_backend="dp", gpus=2),
dict(use_dp=True, use_ddp=False, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1),
marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")]
),
pytest.param(
dict(distributed_backend="ddp", gpus=2),
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=2),
marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")]
),
pytest.param(
dict(distributed_backend="ddp2", gpus=2),
dict(use_dp=False, use_ddp=False, use_ddp2=True, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1),
marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")]
),
])
def test_trainer_config(trainer_kwargs, expected):
trainer = Trainer(**trainer_kwargs)
assert trainer.use_dp is expected["use_dp"]
assert trainer.use_ddp is expected["use_ddp"]
assert trainer.use_ddp2 is expected["use_ddp2"]
assert trainer.num_gpus == expected["num_gpus"]
assert trainer.on_gpu is expected["on_gpu"]
assert trainer.single_gpu is expected["single_gpu"]
assert trainer.num_processes == expected["num_processes"]
