Skip to content

Commit

Permalink
Initial commit of Horovod distributed backend implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Apr 22, 2020
1 parent 4d24032 commit a76736a
Show file tree
Hide file tree
Showing 17 changed files with 597 additions and 25 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ references:
run:
name: Install Dependences
command: |
sudo apt-get update && sudo apt-get install -y cmake
pip install "$TORCH_VERSION"
pip install -r requirements.txt -q
sudo pip install pytest pytest-cov pytest-flake8 -q
Expand Down
11 changes: 10 additions & 1 deletion .drone.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,33 @@ name: torch-GPU

steps:
- name: testing
image: pytorch/pytorch:1.4-cuda10.1-cudnn7-runtime
image: pytorch/pytorch:1.4-cuda10.1-cudnn7-devel

environment:
SLURM_LOCALID: 0
CODECOV_TOKEN:
from_secret: codecov_token
HOROVOD_GPU_ALLREDUCE: NCCL
HOROVOD_GPU_BROADCAST: NCCL
HOROVOD_WITH_PYTORCH: 1
HOROVOD_WITHOUT_TENSORFLOW: 1
HOROVOD_WITHOUT_MXNET: 1
HOROVOD_WITH_GLOO: 1
HOROVOD_WITHOUT_MPI: 1

#volumes:
# # Mount pip cache from host
# - name: pip_cache
# path: /opt/conda/lib/python3.7/site-packages

commands:
- export PATH="$PATH:/root/.local/bin"
- python --version
- pip install pip -U
- pip --version
- nvidia-smi
- bash ./tests/install_AMP.sh
- apt-get update && apt-get install -y cmake
- pip install -r requirements.txt --user -q
- pip install coverage pytest pytest-cov pytest-flake8 codecov -q
- pip install -r ./tests/requirements.txt --user -q
Expand Down
11 changes: 9 additions & 2 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,15 @@ jobs:
if: runner.os == 'macOS'
run: |
brew install libomp # https://github.com/pytorch/pytorch/issues/20030
brew install openmpi # Horovod on macOS requires OpenMPI, Gloo not currently supported
# TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved
- name: Setup Windows
if: runner.os == 'windows'
run: |
python -c "lines = [line for line in open('requirements-extra.txt').readlines() if not line.startswith('horovod')] ; open('requirements-extra.txt', 'w').writelines(lines)"
# TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved
- name: Setup Windows on Latest
if: runner.os == 'windows' && matrix.requires == 'latest'
run: |
python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)"
Expand Down Expand Up @@ -75,11 +81,12 @@ jobs:
run: |
# python -m pip install --upgrade --user pip
pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q
pip install -r ./tests/requirements.txt -q
HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install -r ./tests/requirements.txt -q
# pip install tox coverage
python --version
pip --version
pip list
shell: bash

- name: Tests
# env:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `ddp_cpu` backend for testing ddp without GPUs ([#1158](https://github.com/PyTorchLightning/pytorch-lightning/pull/1158))

- Added [Horovod](http://horovod.ai) support as a distributed backend `Trainer(distributed_backend='horovod')` ([#1529](https://github.com/PyTorchLightning/pytorch-lightning/pull/1529))

### Changed

Expand Down
38 changes: 38 additions & 0 deletions docs/source/multi_gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Lightning allows multiple ways of training
- Data Parallel (`distributed_backend='dp'`) (multiple-gpus, 1 machine)
- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines).
- DistributedDataParallel2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines).
- Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime)
- TPUs (`num_tpu_cores=8|x`) (tpu or TPU pod)

Data Parallel (dp)
Expand Down Expand Up @@ -136,6 +137,43 @@ In this case, we can use ddp2 which behaves like dp in a machine and ddp across
# train on 32 GPUs (4 nodes)
trainer = pl.Trainer(gpus=8, distributed_backend='ddp2', num_nodes=4)
Horovod
^^^^^^^
`Horovod <http://horovod.ai>`_ allows the same training script to be used for single-GPU,
multi-GPU, and multi-node training.

Like Distributed Data Parallel, every process in Horovod operates on a single GPU with a fixed
subset of the data. Gradients are averaged across all GPUs in parallel during the backward pass,
then synchronously applied before beginning the next step.

The number of worker processes is configured by a driver application (`horovodrun` or `mpirun`). In
the training script, Horovod will detect the number of workers from the environment, and automatically
scale the learning rate to compensate for the increased total batch size.

Horovod can be configured in the training script to run with any number of GPUs / processes as follows:

.. code-block:: python
# train Horovod on GPU (number of GPUs / machines provided on command-line)
trainer = pl.Trainer(distributed_backend='horovod', gpus=1)
# train Horovod on CPU (number of processes / machines provided on command-line)
trainer = pl.Trainer(distributed_backend='horovod')
When starting the training job, the driver application will then be used to specify the total
number of worker processes:

.. code-block:: bash
# run training with 4 GPUs on a single machine
horovodrun -np 4 python train.py
# run training with 8 GPUs on two machines (4 GPUs each)
horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py
See the official `Horovod documentation <https://horovod.readthedocs.io/en/stable>`_ for details
on installation and performance tuning.

DP/DDP2 caveats
^^^^^^^^^^^^^^^
In DP and DDP2 each GPU within a machine sees a portion of a batch.
Expand Down
18 changes: 17 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
except ImportError:
HOROVOD_AVAILABLE = False
else:
HOROVOD_AVAILABLE = True


def _has_len(dataloader: DataLoader) -> bool:
""" Checks if a given Dataloader has __len__ method implemented i.e. if
Expand All @@ -47,6 +54,7 @@ class TrainerDataLoadingMixin(ABC):
proc_rank: int
use_ddp: bool
use_ddp2: bool
use_horovod: bool
shown_warnings: ...
val_check_interval: float
use_tpu: bool
Expand Down Expand Up @@ -89,7 +97,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
# don't do anything if it's not a dataloader
if not isinstance(dataloader, DataLoader):
return dataloader
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_tpu)
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
if self.replace_sampler_ddp and need_dist_sampler:

skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
Expand All @@ -104,6 +112,10 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
)
elif self.use_horovod:
sampler = DistributedSampler(dataloader.dataset,
num_replicas=hvd.size(),
rank=hvd.rank())
else:
world_size = {
'ddp': self.num_nodes * self.num_processes,
Expand Down Expand Up @@ -254,6 +266,10 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
# all processes wait until data download has happened
torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

elif self.use_horovod:
# all processes wait until data download has happened
hvd.join()

return dataloader

def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float,
Expand Down
36 changes: 35 additions & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def train_fx(trial_hparams, cluster_manager, _):
else:
APEX_AVAILABLE = True

try:
import horovod.torch as hvd
except ImportError:
HOROVOD_AVAILABLE = False
else:
HOROVOD_AVAILABLE = True


class TrainerDDPMixin(ABC):

Expand Down Expand Up @@ -178,10 +185,14 @@ def set_distributed_mode(self, distributed_backend):
self.use_dp = False
self.use_ddp = False
self.use_ddp2 = False
self.use_horovod = False
self.single_gpu = False

if distributed_backend is None:
if self.num_gpus == 0:
if self.has_horovodrun():
self.check_horovod()
self.use_horovod = True
elif self.num_gpus == 0:
if self.num_nodes > 1 or self.num_processes > 1:
self.use_ddp = True # ddp_cpu
elif self.num_gpus == 1:
Expand Down Expand Up @@ -219,6 +230,9 @@ def set_distributed_mode(self, distributed_backend):
self.use_ddp = True
self.data_parallel_device_ids = None
self.on_gpu = False
elif distributed_backend == 'horovod':
self.check_horovod()
self.use_horovod = True

# throw error to force user ddp or ddp2 choice
if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
Expand Down Expand Up @@ -402,3 +416,23 @@ def resolve_root_node_address(self, root_node):
root_node = name + number

return root_node

def check_horovod(self):
"""Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
if not HOROVOD_AVAILABLE:
raise MisconfigurationException(
'Requested `distributed_backend="horovod"`, but Horovod is not available. See: '
'https://horovod.readthedocs.io/en/stable/install_include.html for installation '
'instructions.'
)

if self.num_gpus > 1 or self.num_nodes > 1:
raise MisconfigurationException(
'Horovod does not support setting num_nodes / num_gpus explicitly. Use '
'horovodrun / mpirun to configure the number of processes.'
)

@staticmethod
def has_horovodrun():
"""Returns True if running with `horovodrun` using Gloo or OpenMPI."""
return 'OMPI_COMM_WORLD_RANK' in os.environ or 'HOROVOD_RANK' in os.environ
69 changes: 69 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,16 @@
"""

from contextlib import ExitStack
import os
from abc import ABC, abstractmethod
import time
import random
import torch
from typing import Union

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
Expand All @@ -365,6 +368,13 @@
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
except ImportError:
HOROVOD_AVAILABLE = False
else:
HOROVOD_AVAILABLE = True


class TrainerDPMixin(ABC):

Expand All @@ -385,6 +395,7 @@ class TrainerDPMixin(ABC):
tpu_global_core_rank: int
use_tpu: bool
data_parallel_device_ids: ...
logger: Union[LightningLoggerBase, bool]

@property
@abstractmethod
Expand Down Expand Up @@ -540,6 +551,64 @@ def dp_train(self, model):

self.run_pretrain_routine(model)

def horovod_train(self, model):
# Horovod: initialize library
hvd.init()

if torch.cuda.is_available() and self.on_gpu:
# Horovod: pin GPU to local rank
torch.cuda.set_device(hvd.local_rank())
model.cuda(hvd.local_rank())

# Only show progress bar from the first worker
self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if hvd.rank() == 0 else 0

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

# Horovod: scale the learning rate by the number of workers to account for
# increased total batch size
for optimizer in self.optimizers:
for param_group in optimizer.param_groups:
param_group['lr'] *= hvd.size()

if self.use_amp:
# An example
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers

# Horovod: broadcast parameters & optimizer state to ensure consistent initialization
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
for optimizer in self.optimizers:
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

def filter_named_parameters(model, optimizer):
opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])])
return [(name, p) for name, p in model.named_parameters() if p in opt_params]

# Horovod: wrap optimizers to perform gradient aggregation via allreduce
self.optimizers = [
hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer))
for optimizer in self.optimizers
]

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.proc_rank = hvd.rank()
set_proc_rank(self.proc_rank)
if self.logger:
self.logger.rank = self.proc_rank
if model.logger:
model.logger.rank = self.proc_rank

with ExitStack() as stack:
for optimizer in self.optimizers:
# Synchronization will be performed explicitly following backward()
stack.enter_context(optimizer.skip_synchronize())

self.run_pretrain_routine(model)


def normalize_parse_gpu_string_input(s):
if isinstance(s, str):
Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
except ImportError:
HOROVOD_AVAILABLE = False
else:
HOROVOD_AVAILABLE = True


class TrainerEvaluationLoopMixin(ABC):

Expand All @@ -153,9 +160,11 @@ class TrainerEvaluationLoopMixin(ABC):
test_progress_bar: ...
val_progress_bar: ...
main_progress_bar: ...
on_gpu: bool
use_ddp: bool
use_dp: bool
use_ddp2: bool
use_horovod: bool
single_gpu: bool
data_parallel_device_ids: ...
model: LightningModule
Expand Down Expand Up @@ -429,6 +438,11 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
output = model(*args)
return output

# Horovod
if self.use_horovod and self.on_gpu:
batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
args[0] = batch

# single GPU data transfer
if self.single_gpu:
# for single GPU put inputs on gpu manually
Expand Down
Loading

0 comments on commit a76736a

Please sign in to comment.