Adds back the slow spawn ddp implementation that people want #2115

Merged · 28 commits · Jun 8, 2020

Changes from all commits
27b7d8d  training batch clean up (williamFalcon, Jun 8, 2020)
d210d19  training batch clean up (williamFalcon, Jun 8, 2020)
3b57580  training batch clean up (williamFalcon, Jun 8, 2020)
9c69793  training batch clean up (williamFalcon, Jun 8, 2020)
32f7e26  training batch clean up (williamFalcon, Jun 8, 2020)
fca185f  training batch clean up (williamFalcon, Jun 8, 2020)
a7fb179  training batch clean up (williamFalcon, Jun 8, 2020)
40f7d0a  training batch clean up (williamFalcon, Jun 8, 2020)
24c905e  training batch clean up (williamFalcon, Jun 8, 2020)
6962d2e  training batch clean up (williamFalcon, Jun 8, 2020)
7e36504  training batch clean up (williamFalcon, Jun 8, 2020)
8d08b4d  training batch clean up (williamFalcon, Jun 8, 2020)
3f81631  training batch clean up (williamFalcon, Jun 8, 2020)
661bfa3  training batch clean up (williamFalcon, Jun 8, 2020)
4a7f4fa  training batch clean up (williamFalcon, Jun 8, 2020)
ae7ffee  training batch clean up (williamFalcon, Jun 8, 2020)
aa78f87  training batch clean up (williamFalcon, Jun 8, 2020)
a352478  training batch clean up (williamFalcon, Jun 8, 2020)
8368b03  training batch clean up (williamFalcon, Jun 8, 2020)
6fde56a  training batch clean up (williamFalcon, Jun 8, 2020)
77a442e  adding spawn (williamFalcon, Jun 8, 2020)
7256e27  adding spawn (williamFalcon, Jun 8, 2020)
26d7e9e  adding spawn (williamFalcon, Jun 8, 2020)
34954e7  adding spawn (williamFalcon, Jun 8, 2020)
6d285eb  adding spawn (williamFalcon, Jun 8, 2020)
c890a70  adding spawn (williamFalcon, Jun 8, 2020)
0f97fe0  adding spawn (williamFalcon, Jun 8, 2020)
cf1de18  adding spawn (williamFalcon, Jun 8, 2020)
112 changes: 111 additions & 1 deletion docs/source/multi_gpu.rst
@@ -200,7 +200,8 @@ Distributed modes
Lightning allows multiple ways of training

- Data Parallel (`distributed_backend='dp'`) (multiple-gpus, 1 machine)
- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines).
- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines (python script based)).
- DistributedDataParallel (`distributed_backend='ddp_spawn'`) (multiple-gpus across many machines (spawn based)).
- DistributedDataParallel 2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines).
- Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime)
- TPUs (`tpu_cores=8|x`) (tpu or TPU pod)
@@ -253,6 +254,26 @@ Distributed Data Parallel
# train on 32 GPUs (4 nodes)
trainer = Trainer(gpus=8, distributed_backend='ddp', num_nodes=4)

This Lightning implementation of ddp calls your script under the hood multiple times with the correct environment
variables. If your code does not support this (i.e., a Jupyter notebook, Colab, or a nested script without a root package),
use `dp` or `ddp_spawn`.

.. code-block:: bash

# example for 3 GPUs ddp
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python my_file.py --gpus 3 --etc
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python my_file.py --gpus 3 --etc

We use ddp this way because `ddp_spawn` has a few limitations (due to Python and PyTorch):

1. Since `.spawn()` trains the model in subprocesses, the model on the main process does not get updated.
2. Dataloader(num_workers=N), where N is large, bottlenecks training with ddp...
   i.e., it will be VERY slow or will not work at all. This is a PyTorch limitation.
3. It forces everything to be picklable.

However, if you don't mind these limitations, please use `ddp_spawn`.
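
For example, a minimal sketch (assuming `model` is your LightningModule; the GPU count is only illustrative):

.. code-block:: python

# train on 8 GPUs on a single machine with the spawn-based backend
trainer = Trainer(gpus=8, distributed_backend='ddp_spawn')
trainer.fit(model)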

Distributed Data Parallel 2
^^^^^^^^^^^^^^^^^^^^^^^^^^^
In certain cases, it's advantageous to use all batches on the same machine instead of a subset.
@@ -275,6 +296,75 @@ In this case, we can use ddp2 which behaves like dp in a machine and ddp across
# train on 32 GPUs (4 nodes)
trainer = Trainer(gpus=8, distributed_backend='ddp2', num_nodes=4)

Distributed Data Parallel Spawn
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
`ddp_spawn` is exactly like `ddp` except that it uses `.spawn()` to start the training processes.

.. warning:: It is STRONGLY recommended to use `ddp` for speed and performance.

.. code-block:: python

mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))

Here's how to call it:

.. code-block:: python

# train on 8 GPUs (same machine, i.e. one node)
trainer = Trainer(gpus=8, distributed_backend='ddp_spawn')

Use this method if your script does not support being called from the command line (i.e., it is nested without a root
project module). However, we STRONGLY discourage this use because it has limitations (due to Python and PyTorch):

1. The model you pass in will not update. Please save a checkpoint and restore from there (see the sketch below).
2. Set Dataloader(num_workers=0) or it will bottleneck training.
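
A minimal sketch of working around both limitations (`YourModel`, `dataset`, `model`, and the checkpoint path are placeholders):

.. code-block:: python

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

# num_workers=0 avoids the dataloader bottleneck under ddp_spawn
train_loader = DataLoader(dataset, num_workers=0)

trainer = Trainer(gpus=8, distributed_backend='ddp_spawn')
trainer.fit(model, train_loader)

# the in-memory `model` is NOT updated by the spawned subprocesses,
# so reload the trained weights from the checkpoint they saved
model = YourModel.load_from_checkpoint('path/to/best_checkpoint.ckpt')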

`ddp` is MUCH faster than `ddp_spawn`. We recommend you install a top-level module for your project using setup.py:

.. code-block:: python

# setup.py
#!/usr/bin/env python

from setuptools import setup, find_packages

setup(name='src',
      version='0.0.1',
      description='Describe Your Cool Project',
      author='',
      author_email='',
      url='https://github.com/YourSeed',  # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
      install_requires=[
          'pytorch-lightning'
      ],
      packages=find_packages()
      )

Then set up your project like so:

.. code-block:: bash

/project
    /src
        some_file.py
        /or_a_folder
    setup.py

Then install it as a root-level package:

.. code-block:: bash

cd /project
pip install -e .

Now you can call your scripts from anywhere:

.. code-block:: bash

cd /project/src
python some_file.py --distributed_backend 'ddp' --gpus 8


Horovod
^^^^^^^
`Horovod <http://horovod.ai>`_ allows the same training script to be used for single-GPU,
@@ -516,3 +606,23 @@ And then launch the elastic job with:

See the official `PytorchElastic documentation <https://pytorch.org/elastic>`_ for details
on installation and more use cases.

Jupyter Notebooks
-----------------
Unfortunately, the `ddp`-based backends are not supported in Jupyter notebooks. Please use `dp` for multiple GPUs. This is a known
Jupyter issue. If you feel like taking a stab at adding this support, feel free to submit a PR!
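
For example, in a notebook cell (a minimal sketch; assumes `model` is already defined):

.. code-block:: python

# fall back to DataParallel inside the notebook
trainer = Trainer(gpus=2, distributed_backend='dp')
trainer.fit(model)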

Pickle Errors
--------------
Multi-GPU training sometimes requires your model to be pickled. If you run into an issue with pickling,
try the following to figure out the issue:

.. code-block:: python

import pickle

model = YourModel()
pickle.dumps(model)

However, if you use `ddp` the pickling requirement is not there and you should be fine. If you use `ddp_spawn` the
pickling requirement remains. This is a limitation of Python.
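
A common culprit is an attribute that Python cannot serialize, such as a lambda. This hypothetical sketch (assuming the usual `import pytorch_lightning as pl`) shows the failure mode:

.. code-block:: python

import pickle

import pytorch_lightning as pl


class YourModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # lambdas (and other locally defined functions) cannot be pickled,
        # so this attribute will break `ddp_spawn`
        self.scale_fn = lambda x: x * 2


model = YourModel()
pickle.dumps(model)  # raises a pickling error that names the offending lambda
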
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/data_loading.py
@@ -139,6 +139,7 @@ def _get_distributed_sampler(self, dataloader):
else:
world_size = {
'ddp': self.num_nodes * self.num_processes,
'ddp_spawn': self.num_nodes * self.num_processes,
'ddp2': self.num_nodes,
'ddp_cpu': self.num_processes * self.num_nodes
}
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -221,7 +221,7 @@ def set_distributed_mode(self, distributed_backend):
elif self.num_gpus > 1:
self.use_dp = True

elif distributed_backend == "ddp":
elif distributed_backend in ['ddp', 'ddp_spawn']:
if self.num_gpus == 0:
if self.num_nodes > 1 or self.num_processes > 1:
self.use_ddp = True # ddp_cpu
@@ -378,6 +378,7 @@ def spawn_ddp_children(self, model):

self.interactive_ddp_procs = []
for local_rank in range(1, self.num_processes):
print('launching local_rank', local_rank)
env_copy = os.environ.copy()
env_copy['LOCAL_RANK'] = f'{local_rank}'

@@ -394,14 +395,17 @@ def ddp_train(self, process_idx, model, is_master=False):
local_rank = 0
self.ddp_train(local_rank, model, is_master=True)

def ddp_train(self, process_idx, model, is_master=False):
def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
"""
Entry point into a DP thread
:param gpu_idx:
:param model:
:param cluster_obj:
:return:
"""
# offset the process id if requested
process_idx = process_idx + proc_offset

# show progressbar only on progress_rank 0
if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()
@@ -454,7 +458,7 @@ def ddp_train(self, process_idx, model, is_master=False):
self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

# DDP2 uses all GPUs on the machine
if self.distributed_backend == 'ddp':
if self.distributed_backend == 'ddp' or self.distributed_backend == 'ddp_spawn':
device_ids = [self.root_gpu]
elif self.use_ddp2:
device_ids = self.data_parallel_device_ids
9 changes: 8 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -246,7 +246,7 @@ def __init__(

Use `row_log_interval` instead. Will remove 0.9.0.

distributed_backend: The distributed backend to use.
distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn)

use_amp:
.. warning:: .. deprecated:: 0.7.0
@@ -876,9 +876,16 @@ def fit(
self.ddp_train(task, model)

elif self.distributed_backend == 'cpu_ddp':
self.__set_random_port()
self.model = model
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

elif self.distributed_backend == 'ddp_spawn':
model.share_memory()

# spin up peers
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))

elif self.distributed_backend == 'ddp':
self.spawn_ddp_children(model)
