Skip to content

Commit

Permalink
Adds back the slow spawn ddp implementation that people want (#2115)
Browse files Browse the repository at this point in the history
* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* training batch clean up

* adding spawn

* adding spawn

* adding spawn

* adding spawn

* adding spawn

* adding spawn

* adding spawn

* adding spawn
  • Loading branch information
williamFalcon authored and justusschock committed Jun 29, 2020
1 parent ee05ee1 commit 42b6fd2
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 5 deletions.
112 changes: 111 additions & 1 deletion docs/source/multi_gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ Distributed modes
Lightning allows multiple ways of training

- Data Parallel (`distributed_backend='dp'`) (multiple-gpus, 1 machine)
- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines).
- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines (python script based)).
- DistributedDataParallel (`distributed_backend='ddp_spawn'`) (multiple-gpus across many machines (spawn based)).
- DistributedDataParallel 2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines).
- Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime)
- TPUs (`tpu_cores=8|x`) (tpu or TPU pod)
Expand Down Expand Up @@ -253,6 +254,26 @@ Distributed Data Parallel
# train on 32 GPUs (4 nodes)
trainer = Trainer(gpus=8, distributed_backend='ddp', num_nodes=4)
This Lightning implementation of ddp calls your script under the hood multiple times with the correct environment
variables. If your code does not support this (ie: jupyter notebook, colab, or a nested script without a root package),
use `dp` or `ddp_spawn`

.. code-block:: bash
# example for 3 GPUs ddp
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
The reason we use ddp this way is because `ddp_spawn` has a few limitations (because of Python and PyTorch):

1. Since `.spawn()` trains the model in subprocesses, the model on the main process does not get updated.
2. Dataloader(num_workers=N) where N is large bottlenecks training with ddp...
ie: it will be VERY slow or not work at all. This is a PyTorch limitation.
3. Forces everything to be picklable.

However, if you don't mind these limitations, please use `ddp_spawn`.

Distributed Data Parallel 2
^^^^^^^^^^^^^^^^^^^^^^^^^^^
In certain cases, it's advantageous to use all batches on the same machine instead of a subset.
Expand All @@ -275,6 +296,75 @@ In this case, we can use ddp2 which behaves like dp in a machine and ddp across
# train on 32 GPUs (4 nodes)
trainer = Trainer(gpus=8, distributed_backend='ddp2', num_nodes=4)
Distributed Data Parallel Spawn
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
`ddp_spawn` is exactly like `ddp` except that it uses .spawn to start the training processes.

.. warning:: It is STRONGLY recommended to use `ddp` for speed and performance.

.. code-block:: python
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))
Here's how to call this.

.. code-block:: python
# train on 8 GPUs (same machine (ie: node))
trainer = Trainer(gpus=8, distributed_backend='ddp')
Use this method if your script does not support being called from the command line (ie: it is nested without a root
project module). However, we STRONGLY discourage this use because it has limitations (because of Python and PyTorch):

1. The model you pass in will not update. Please save a checkpoint and restore from there.
2. Set Dataloader(num_workers=0) or it will bottleneck training.

`ddp` is MUCH faster than `ddp_spawn`. We recommend you install a top-level module for your project using setup.py

.. code-block:: python
# setup.py
#!/usr/bin/env python
from setuptools import setup, find_packages
setup(name='src',
version='0.0.1',
description='Describe Your Cool Project',
author='',
author_email='',
url='https://github.com/YourSeed', # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
install_requires=[
'pytorch-lightning'
],
packages=find_packages()
)
Then setup your project like so:

.. code-block:: bash
/project
/src
some_file.py
/or_a_folder
setup.py
Then install as a root-level package

.. code-block:: bash
cd /project
pip install -e .
Now you can call your scripts anywhere

.. code-block:: bash
cd /project/src
python some_file.py --distributed_backend 'ddp' --gpus 8
Horovod
^^^^^^^
`Horovod <http://horovod.ai>`_ allows the same training script to be used for single-GPU,
Expand Down Expand Up @@ -516,3 +606,23 @@ And then launch the elastic job with:
See the official `PytorchElastic documentation <https://pytorch.org/elastic>`_ for details
on installation and more use cases.

Jupyter Notebooks
-----------------
Unfortunately any `ddp_` is not supported in jupyter notebooks. Please use `dp` for multiple GPUs. This is a known
Jupyter issue. If you feel like taking a stab at adding this support, feel free to submit a PR!

Pickle Errors
--------------
Multi-GPU training sometimes requires your model to be pickled. If you run into an issue with pickling
try the following to figure out the issue

.. code-block:: python
import pickle
model = YourModel()
pickle.dumps(model)
However, if you use `ddp` the pickling requirement is not there and you should be fine. If you use `ddp_spawn` the
pickling requirement remains. This is a limitation of Python.
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def _get_distributed_sampler(self, dataloader):
else:
world_size = {
'ddp': self.num_nodes * self.num_processes,
'ddp_spawn': self.num_nodes * self.num_processes,
'ddp2': self.num_nodes,
'ddp_cpu': self.num_processes * self.num_nodes
}
Expand Down
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def set_distributed_mode(self, distributed_backend):
elif self.num_gpus > 1:
self.use_dp = True

elif distributed_backend == "ddp":
elif distributed_backend in ['ddp', 'ddp_spawn']:
if self.num_gpus == 0:
if self.num_nodes > 1 or self.num_processes > 1:
self.use_ddp = True # ddp_cpu
Expand Down Expand Up @@ -378,6 +378,7 @@ def spawn_ddp_children(self, model):

self.interactive_ddp_procs = []
for local_rank in range(1, self.num_processes):
print('launching local_rank', local_rank)
env_copy = os.environ.copy()
env_copy['LOCAL_RANK'] = f'{local_rank}'

Expand All @@ -394,14 +395,17 @@ def spawn_ddp_children(self, model):
local_rank = 0
self.ddp_train(local_rank, model, is_master=True)

def ddp_train(self, process_idx, model, is_master=False):
def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
"""
Entry point into a DP thread
:param gpu_idx:
:param model:
:param cluster_obj:
:return:
"""
# offset the process id if requested
process_idx = process_idx + proc_offset

# show progressbar only on progress_rank 0
if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()
Expand Down Expand Up @@ -454,7 +458,7 @@ def ddp_train(self, process_idx, model, is_master=False):
self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

# DDP2 uses all GPUs on the machine
if self.distributed_backend == 'ddp':
if self.distributed_backend == 'ddp' or self.distributed_backend == 'ddp_spawn':
device_ids = [self.root_gpu]
elif self.use_ddp2:
device_ids = self.data_parallel_device_ids
Expand Down
9 changes: 8 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(
Use `row_log_interval` instead. Will remove 0.9.0.
distributed_backend: The distributed backend to use.
distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn)
use_amp:
.. warning:: .. deprecated:: 0.7.0
Expand Down Expand Up @@ -876,9 +876,16 @@ def fit(
self.ddp_train(task, model)

elif self.distributed_backend == 'cpu_ddp':
self.__set_random_port()
self.model = model
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

elif self.distributed_backend == 'ddp_spawn':
model.share_memory()

# spin up peers
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))

elif self.distributed_backend == 'ddp':
self.spawn_ddp_children(model)

Expand Down

0 comments on commit 42b6fd2

Please sign in to comment.