made ddp the default if no backend specified with multiple GPUs (#1789)
* made ddp the default if no backend specified with multiple GPUs

* fix

* spawn

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
williamFalcon and Borda committed May 12, 2020
1 parent acab068 commit 10b16db
Showing 4 changed files with 6 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/source/multi_gpu.rst
@@ -132,6 +132,8 @@ Lightning allows multiple ways of training
 - Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime)
 - TPUs (`num_tpu_cores=8|x`) (tpu or TPU pod)
 
+.. note:: If you request multiple GPUs without setting a mode, ddp will be automatically used.
+
 Data Parallel (dp)
 ^^^^^^^^^^^^^^^^^^
 `DataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel>`_ splits a batch across k GPUs. That is, if you have a batch of 32 and use dp with 2 gpus,
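A minimal usage sketch of the behavior the note above documents (hypothetical, assuming a machine with at least two visible GPUs; `use_ddp` and `use_dp` are the same Trainer flags checked in the tests below):

import pytorch_lightning as pl

# Assumes at least two visible GPUs; otherwise Trainer(gpus=2) raises a
# misconfiguration error before any backend is selected.
trainer = pl.Trainer(gpus=2)  # no distributed_backend passed

# After this commit the Trainer warns and falls back to ddp rather than dp.
assert trainer.use_ddp and not trainer.use_dp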
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -203,8 +203,8 @@ def set_distributed_mode(self, distributed_backend):
             elif self.num_gpus > 1:
                 rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.'
                                ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
-                               ' Setting distributed_backend=dp for you.')
-                self.use_dp = True
+                               ' Setting distributed_backend=ddp for you.')
+                self.use_ddp = True
         elif distributed_backend == "dp":
             # do nothing if num_gpus == 0
             if self.num_gpus == 1:
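For reference, a simplified, standalone sketch of the default-selection logic touched above (a hypothetical helper written for illustration, not Lightning's actual API):

import warnings

def pick_backend(distributed_backend, num_gpus):
    """Hypothetical, simplified version of the branch changed above."""
    use_dp = use_ddp = False
    if distributed_backend is None and num_gpus > 1:
        warnings.warn(
            'You requested multiple GPUs but did not specify a backend, e.g.'
            ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
            ' Setting distributed_backend=ddp for you.'
        )
        use_ddp = True  # previously this defaulted to dp
    elif distributed_backend == 'dp' and num_gpus >= 1:
        use_dp = True
    elif distributed_backend == 'ddp' and num_gpus >= 1:
        use_ddp = True
    return use_dp, use_ddp

# Two GPUs, no backend requested: ddp is now the default.
assert pick_backend(None, 2) == (False, True)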
1 change: 1 addition & 0 deletions tests/models/test_gpu.py
@@ -130,6 +130,7 @@ def assert_pred_same():
     trainer.fit(model)
 
 
+@pytest.mark.spawn
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_multi_gpu_none_backend(tmpdir):
     """Make sure when using multiple GPUs the user can't use `distributed_backend = None`."""
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -712,7 +712,7 @@ def test_gpu_choice(tmpdir):
     ),
     pytest.param(
         dict(distributed_backend=None, gpus=2),
-        dict(use_dp=True, use_ddp=False, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1),
+        dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1),
         marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")]
     ),
     pytest.param(
