
Added Horovod distributed backend #1529

Merged (17 commits) on Apr 22, 2020
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -10,6 +10,7 @@ references:
run:
name: Install Dependences
command: |
sudo apt-get update && sudo apt-get install -y cmake
pip install "$TORCH_VERSION"
pip install -r requirements.txt -q
sudo pip install pytest pytest-cov pytest-flake8 -q
13 changes: 11 additions & 2 deletions .drone.yml
@@ -6,24 +6,33 @@ name: torch-GPU

steps:
- name: testing
image: pytorch/pytorch:1.4-cuda10.1-cudnn7-runtime
image: pytorch/pytorch:1.4-cuda10.1-cudnn7-devel

environment:
SLURM_LOCALID: 0
CODECOV_TOKEN:
from_secret: codecov_token
HOROVOD_GPU_ALLREDUCE: NCCL
HOROVOD_GPU_BROADCAST: NCCL
HOROVOD_WITH_PYTORCH: 1
HOROVOD_WITHOUT_TENSORFLOW: 1
HOROVOD_WITHOUT_MXNET: 1
HOROVOD_WITH_GLOO: 1
HOROVOD_WITHOUT_MPI: 1

#volumes:
# # Mount pip cache from host
# - name: pip_cache
# path: /opt/conda/lib/python3.7/site-packages

commands:
- export PATH="$PATH:/root/.local/bin"
- python --version
- pip install pip -U
- pip --version
- nvidia-smi
- bash ./tests/install_AMP.sh
# - bash ./tests/install_AMP.sh
- apt-get update && apt-get install -y cmake
- pip install -r requirements.txt --user -q
- pip install coverage pytest pytest-cov pytest-flake8 codecov -q
- pip install -r ./tests/requirements.txt --user -q
11 changes: 9 additions & 2 deletions .github/workflows/ci-testing.yml
@@ -41,9 +41,15 @@ jobs:
if: runner.os == 'macOS'
run: |
brew install libomp # https://github.com/pytorch/pytorch/issues/20030
brew install openmpi # Horovod on macOS requires OpenMPI, Gloo not currently supported
# TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved
- name: Setup Windows
if: runner.os == 'windows'
run: |
python -c "lines = [line for line in open('requirements-extra.txt').readlines() if not line.startswith('horovod')] ; open('requirements-extra.txt', 'w').writelines(lines)"
# TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved
- name: Setup Windows on Latest
if: runner.os == 'windows' && matrix.requires == 'latest'
run: |
python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)"
@@ -75,11 +81,12 @@ jobs:
run: |
# python -m pip install --upgrade --user pip
pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q
pip install -r ./tests/requirements.txt -q
HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install -r ./tests/requirements.txt -q
# pip install tox coverage
python --version
pip --version
pip list
shell: bash

- name: Tests
# env:
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `ddp_cpu` backend for testing ddp without GPUs ([#1158](https://github.com/PyTorchLightning/pytorch-lightning/pull/1158))

- Added [Horovod](http://horovod.ai) support as a distributed backend `Trainer(distributed_backend='horovod')` ([#1529](https://github.com/PyTorchLightning/pytorch-lightning/pull/1529))

### Changed

38 changes: 38 additions & 0 deletions docs/source/multi_gpu.rst
@@ -76,6 +76,7 @@ Lightning allows multiple ways of training
- Data Parallel (`distributed_backend='dp'`) (multiple-gpus, 1 machine)
- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines).
- DistributedDataParallel2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines).
- Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime)
- TPUs (`num_tpu_cores=8|x`) (tpu or TPU pod)

Data Parallel (dp)
@@ -136,6 +137,43 @@ In this case, we can use ddp2 which behaves like dp in a machine and ddp across
    # train on 32 GPUs (4 nodes)
    trainer = pl.Trainer(gpus=8, distributed_backend='ddp2', num_nodes=4)

Horovod
^^^^^^^
`Horovod <http://horovod.ai>`_ allows the same training script to be used for single-GPU,
multi-GPU, and multi-node training.

Like Distributed Data Parallel, every process in Horovod operates on a single GPU with a fixed
subset of the data. Gradients are averaged across all GPUs in parallel during the backward pass,
then synchronously applied before beginning the next step.
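
Conceptually, this averaging is an allreduce of each gradient across workers before the
optimizer step. A minimal sketch of the equivalent manual operation (assuming a `model`, an
`optimizer`, and a computed `loss` already exist; Lightning and Horovod perform this for you):

.. code-block:: python

    import horovod.torch as hvd

    hvd.init()

    loss.backward()
    for name, p in model.named_parameters():
        if p.grad is not None:
            # hvd.allreduce averages across workers by default
            p.grad.data = hvd.allreduce(p.grad.data, name=name)
    optimizer.step()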

The number of worker processes is configured by a driver application (`horovodrun` or `mpirun`). In
the training script, Horovod will detect the number of workers from the environment, and automatically
scale the learning rate to compensate for the increased total batch size.

Horovod can be configured in the training script to run with any number of GPUs / processes as follows:

.. code-block:: python

    # train Horovod on GPU (number of GPUs / machines provided on command-line)
    trainer = pl.Trainer(distributed_backend='horovod', gpus=1)

    # train Horovod on CPU (number of processes / machines provided on command-line)
    trainer = pl.Trainer(distributed_backend='horovod')

When starting the training job, the driver application will then be used to specify the total
number of worker processes:

.. code-block:: bash

    # run training with 4 GPUs on a single machine
    horovodrun -np 4 python train.py

    # run training with 8 GPUs on two machines (4 GPUs each)
    horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py

See the official `Horovod documentation <https://horovod.readthedocs.io/en/stable>`_ for details
on installation and performance tuning.
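
Putting the pieces together, a minimal `train.py` could look like the following sketch
(`MyLightningModule` is a placeholder for your own LightningModule):

.. code-block:: python

    import pytorch_lightning as pl

    model = MyLightningModule()

    # one GPU per process; horovodrun decides how many processes to launch
    trainer = pl.Trainer(distributed_backend='horovod', gpus=1)
    trainer.fit(model)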

DP/DDP2 caveats
^^^^^^^^^^^^^^^
In DP and DDP2 each GPU within a machine sees a portion of a batch.
18 changes: 17 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -26,6 +26,13 @@
else:
    XLA_AVAILABLE = True

try:
    import horovod.torch as hvd
except ImportError:
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


def _has_len(dataloader: DataLoader) -> bool:
""" Checks if a given Dataloader has __len__ method implemented i.e. if
@@ -47,6 +54,7 @@ class TrainerDataLoadingMixin(ABC):
    proc_rank: int
    use_ddp: bool
    use_ddp2: bool
    use_horovod: bool
Comment on lines 55 to +57

Member: Just thinking that this may become quite messy soon: these are all booleans, but really it should be an enum, since you cannot use dp and ddp2 at the same time.

Contributor: Yes, we can do a refactor for this later.
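
A rough sketch of the enum-style alternative being suggested (hypothetical, not part of this PR):

    from enum import Enum

    class DistributedBackend(Enum):
        """Possible replacement for the mutually exclusive use_dp / use_ddp / use_ddp2 / use_horovod flags."""
        DP = 'dp'
        DDP = 'ddp'
        DDP2 = 'ddp2'
        HOROVOD = 'horovod'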

    shown_warnings: ...
    val_check_interval: float
    use_tpu: bool
@@ -89,7 +97,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
# don't do anything if it's not a dataloader
if not isinstance(dataloader, DataLoader):
    return dataloader
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_tpu)
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
if self.replace_sampler_ddp and need_dist_sampler:

    skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
@@ -104,6 +112,10 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
    )
elif self.use_horovod:
    sampler = DistributedSampler(dataloader.dataset,
                                 num_replicas=hvd.size(),
                                 rank=hvd.rank())
else:
    world_size = {
        'ddp': self.num_nodes * self.num_processes,
@@ -254,6 +266,10 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
    # all processes wait until data download has happened
    torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

elif self.use_horovod:
    # all processes wait until data download has happened
    hvd.join()

return dataloader

def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float,
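As a standalone illustration of the sharding that the Horovod branch above sets up, the sketch
below builds the same kind of sampler outside the Trainer (the toy dataset and batch size are
invented for the example):

    import horovod.torch as hvd
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    hvd.init()

    dataset = TensorDataset(torch.arange(100).float())

    # each rank iterates over a disjoint 1/hvd.size() shard of the dataset
    sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
    loader = DataLoader(dataset, batch_size=10, sampler=sampler)
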
35 changes: 34 additions & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -131,6 +131,13 @@ def train_fx(trial_hparams, cluster_manager, _):
else:
    APEX_AVAILABLE = True

try:
    import horovod.torch as hvd
except ImportError:
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


class TrainerDDPMixin(ABC):

@@ -178,10 +185,14 @@ def set_distributed_mode(self, distributed_backend):
self.use_dp = False
self.use_ddp = False
self.use_ddp2 = False
self.use_horovod = False
self.single_gpu = False

if distributed_backend is None:
    if self.num_gpus == 0:
    if self.has_horovodrun():
        self.check_horovod()
        self.use_horovod = True
    elif self.num_gpus == 0:
        if self.num_nodes > 1 or self.num_processes > 1:
            self.use_ddp = True # ddp_cpu
    elif self.num_gpus == 1:
@@ -219,6 +230,9 @@ def set_distributed_mode(self, distributed_backend):
    self.use_ddp = True
    self.data_parallel_device_ids = None
    self.on_gpu = False
elif distributed_backend == 'horovod':
Contributor: It would be nice to be transparent to the user. Can we automate setting this, so the abstraction doesn't bleed?

Contributor: (the mpirun thing)

Contributor Author: Just to make sure I understand you correctly: is the idea that when running via horovodrun or mpirun, if the user has not specified distributed_backend, then we will automatically set distributed_backend='horovod' here?

We could certainly do that when running with horovodrun + our Gloo backend, as we have special environment variables we can check (HOROVOD_RANK, for example). Doing so with mpirun is trickier, because different MPI implementations set different environment variables. Also, in the future there might be another distributed backend besides Horovod that uses MPI.

So maybe we could automate it for horovodrun but still require it to be set explicitly for mpirun? (Let me know if I misunderstood your suggestion.)

Contributor: > Just to make sure I understand you correctly: is the idea that when running via horovodrun or mpirun, if the user has not specified distributed_backend, then we will automatically set distributed_backend='horovod' here?

Yes!

> So maybe we could automate it for horovodrun but still require it to be set explicitly for mpirun?

Let's do this for now (v1); for v2 maybe we set it explicitly for mpirun. I just don't know enough about mpirun yet, but if mpirun can run any backend then the user should be forced to set it.

Contributor Author: Sounds good! I added a has_horovodrun() check in distrib_data_parallel.py that checks for the Gloo or OpenMPI environment variables set by horovodrun, and added a test. Let me know if that aligns with what you were thinking.

    self.check_horovod()
    self.use_horovod = True

# throw error to force user ddp or ddp2 choice
if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
@@ -402,3 +416,22 @@ def resolve_root_node_address(self, root_node):
    root_node = name + number

    return root_node

def check_horovod(self):
    """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
    if not HOROVOD_AVAILABLE:
        raise MisconfigurationException(
            'Requested `distributed_backend="horovod"`, but Horovod is not installed.'
            'Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]'
        )

    if self.num_gpus > 1 or self.num_nodes > 1:
        raise MisconfigurationException(
            'Horovod does not support setting num_nodes / num_gpus explicitly. Use '
            'horovodrun / mpirun to configure the number of processes.'
        )

@staticmethod
def has_horovodrun():
    """Returns True if running with `horovodrun` using Gloo or OpenMPI."""
    return 'OMPI_COMM_WORLD_RANK' in os.environ or 'HOROVOD_RANK' in os.environ
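
Taken together with the review discussion above, the intended behaviour is roughly the
following (a usage sketch, not code from this diff):

    import pytorch_lightning as pl

    # Launched via `horovodrun -np 4 python train.py`: HOROVOD_RANK (Gloo) or
    # OMPI_COMM_WORLD_RANK (OpenMPI) is set, so has_horovodrun() returns True and
    # the Horovod backend is selected automatically.
    trainer = pl.Trainer(gpus=1)

    # Launched via a bare `mpirun`: the backend must still be requested explicitly.
    trainer = pl.Trainer(distributed_backend='horovod', gpus=1)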
69 changes: 69 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -337,13 +337,16 @@

"""

from contextlib import ExitStack
import os
from abc import ABC, abstractmethod
import time
import random
import torch
from typing import Union

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import (
    LightningDistributedDataParallel,
    LightningDataParallel,
Expand All @@ -365,6 +368,13 @@
else:
    XLA_AVAILABLE = True

try:
    import horovod.torch as hvd
except ImportError:
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


class TrainerDPMixin(ABC):

@@ -385,6 +395,7 @@ class TrainerDPMixin(ABC):
    tpu_global_core_rank: int
    use_tpu: bool
    data_parallel_device_ids: ...
    logger: Union[LightningLoggerBase, bool]

    @property
    @abstractmethod
@@ -540,6 +551,64 @@ def dp_train(self, model):

self.run_pretrain_routine(model)

def horovod_train(self, model):
    # Horovod: initialize library
    hvd.init()

    if torch.cuda.is_available() and self.on_gpu:
        # Horovod: pin GPU to local rank
        torch.cuda.set_device(hvd.local_rank())
        model.cuda(hvd.local_rank())

    # Only show progress bar from the first worker
    self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if hvd.rank() == 0 else 0

    # CHOOSE OPTIMIZER
    # allow for lr schedulers as well
    self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

    # Horovod: scale the learning rate by the number of workers to account for
    # increased total batch size
    for optimizer in self.optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= hvd.size()

    if self.use_amp:
        # An example
        model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
        self.optimizers = optimizers

    # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    for optimizer in self.optimizers:
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    def filter_named_parameters(model, optimizer):
        opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])])
        return [(name, p) for name, p in model.named_parameters() if p in opt_params]

    # Horovod: wrap optimizers to perform gradient aggregation via allreduce
    self.optimizers = [
        hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer))
        for optimizer in self.optimizers
    ]

    # Update logger rank info from Horovod to avoid race conditions from different ranks
    # creating directories / writing files in the same locations.
    self.proc_rank = hvd.rank()
    set_proc_rank(self.proc_rank)
    if self.logger:
        self.logger.rank = self.proc_rank
    if model.logger:
        model.logger.rank = self.proc_rank
Comment on lines +600 to +603

Member: With the global rank set, the loggers shouldn't need it, right?

Contributor: Let's refactor this later in a separate PR.

Member: Sure, it wasn't meant to be done now, but it would be good...


    with ExitStack() as stack:
        for optimizer in self.optimizers:
            # Synchronization will be performed explicitly following backward()
            stack.enter_context(optimizer.skip_synchronize())

        self.run_pretrain_routine(model)


def normalize_parse_gpu_string_input(s):
if isinstance(s, str):
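The `skip_synchronize()` contexts entered above rely on Horovod's standard pattern of deferring
the gradient allreduce and triggering it explicitly after `backward()`; the explicit call lives in
the training loop, outside this diff. A minimal sketch of that pattern on its own (assuming
`model`, `base_optimizer`, and a computed `loss` already exist):

    import horovod.torch as hvd

    hvd.init()
    optimizer = hvd.DistributedOptimizer(base_optimizer,
                                         named_parameters=model.named_parameters())

    loss.backward()
    optimizer.synchronize()              # run the deferred gradient allreduce now
    with optimizer.skip_synchronize():   # step without a second, implicit allreduce
        optimizer.step()
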
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -145,6 +145,13 @@
else:
    XLA_AVAILABLE = True

try:
    import horovod.torch as hvd
except ImportError:
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


class TrainerEvaluationLoopMixin(ABC):

@@ -153,9 +160,11 @@ class TrainerEvaluationLoopMixin(ABC):
    test_progress_bar: ...
    val_progress_bar: ...
    main_progress_bar: ...
    on_gpu: bool
    use_ddp: bool
    use_dp: bool
    use_ddp2: bool
    use_horovod: bool
    single_gpu: bool
    data_parallel_device_ids: ...
    model: LightningModule
@@ -429,6 +438,11 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
    output = model(*args)
    return output

# Horovod
if self.use_horovod and self.on_gpu:
    batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
    args[0] = batch

# single GPU data transfer
if self.single_gpu:
    # for single GPU put inputs on gpu manually