Skip to content
This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Commit

Permalink
Supporting Adding DDP Communication Hooks (Lightning-AI#6736)
Browse files Browse the repository at this point in the history
* Fix some test errors
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* checkpoint consolidation

* Update ddp_spawn.py

* Update test_metric_result_integration.py

* Update test_results.py

* Update utils.py

* Update utils.py

* Update test_all_gather_grad.py

* Update test_all_gather_grad.py

* Update test_results.py

* Revert "Update test_results.py"

This reverts commit 9d4a2b8.

* Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate"

This reverts commit c5053da, reversing
changes made to 0d23d75.

* Revert "Update test_all_gather_grad.py"

This reverts commit 0d23d75.

* Revert "Update utils.py"

This reverts commit 70fe5da.

* Revert "Update utils.py"

This reverts commit a9aae99.

* Revert "Update test_results.py"

This reverts commit ea74906.

* Revert "Update test_metric_result_integration.py"

This reverts commit bf70e43.

* Revert "Update ddp_spawn.py"

This reverts commit f172101.

* Revert "checkpoint consolidation"

This reverts commit 536c132.

* Revert "Revert "checkpoint consolidation""

This reverts commit 3a9fde9.

* Revert "Revert "Revert "checkpoint consolidation"""

This reverts commit 7a369f4.

* Revert "Revert "Update ddp_spawn.py""

This reverts commit 8222dc9.

* Revert "Revert "Update test_metric_result_integration.py""

This reverts commit 6c095b2.

* Revert "Revert "Update test_results.py""

This reverts commit 250d0aa.

* Revert "Revert "Update utils.py""

This reverts commit 8651d54.

* Revert "Revert "Update test_all_gather_grad.py""

This reverts commit dcdcd29.

* modify distributed environment to make test pass

* add DDP communication hook

* remove test related setting

* remove more test related setting

* fix ddp comm hook util import issue

* comments

* one more fix for test_custom_plugin

* fix ddp spwan

* fix sgd

* address comments and add tests

* 1. add is gpu checking 2. modify test a bit 3. formatting

* formatting nit

* fix conda 3.7 1.7 issue for no torch.distributed.algorithms module

* need at least 1.8.0

* minor fix

* modify changelog

* changelog should link to PR number instead of issue number

* refine a bit on doc for register_ddp_comm_hook function, like ddp_comm_wrapper explanation and add hyperparameter for power sgd states in example usge

* move single device checking before call register_ddp_comm_hook

* formatting

* comments

* typo

* pre-commit formatting
  • Loading branch information
shuyingsunshine21 authored Apr 7, 2021
1 parent 86e1d9f commit 313e816
Show file tree
Hide file tree
Showing 9 changed files with 309 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))

- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/issues/6736))

- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677))

Expand Down
31 changes: 30 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
from pytorch_lightning.utilities import (
_HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

log = logging.getLogger(__name__)

Expand All @@ -58,6 +65,9 @@ def __init__(
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
Expand All @@ -70,6 +80,9 @@ def __init__(
self.task_idx = None
self.node_rank = 0
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper

@property
def root_device(self):
Expand All @@ -80,6 +93,10 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

@property
def _is_single_process_single_device(self) -> bool:
return True

def setup_environment(self):
# start the other scripts
if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
Expand Down Expand Up @@ -218,13 +235,25 @@ def pre_configure_ddp(self):
)
self._ddp_kwargs["find_unused_parameters"] = True

def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device):
register_ddp_comm_hook(
model=self._model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)
self._register_ddp_hooks()

def determine_ddp_device_ids(self):
if self.root_device.type == "cpu":
Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank)
return distributed_sampler_kwargs

@property
def _is_single_process_single_device(self) -> bool:
return False

def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()
Expand Down
30 changes: 27 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.seed import seed_everything

if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

log = logging.getLogger(__name__)


Expand All @@ -47,16 +50,22 @@ def __init__(
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Union[Any, Dict[str, Any]],
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
self.num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
self.num_processes = len(parallel_devices)
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self.node_rank = 0
self.mp_queue = None
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper

def __getstate__(self):
""" Makes this plugin pickleable without destroying the queue in the current process. """
Expand All @@ -76,9 +85,12 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

@property
def _is_single_process_single_device(self):
return True

def setup(self, model):
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

# pass in a state q
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()
Expand Down Expand Up @@ -181,13 +193,25 @@ def pre_configure_ddp(self):
)
self._ddp_kwargs["find_unused_parameters"] = True

def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device):
register_ddp_comm_hook(
model=self._model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)
self._register_ddp_hooks()

def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
# TODO: this code is duplicated in DDP and DDPSpawn, make this a function
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
_RPC_AVAILABLE,
_TORCH_GREATER_EQUAL_1_6,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
_TORCH_LOWER_EQUAL_1_4,
_TORCH_QUANTIZE_AVAILABLE,
_TORCHTEXT_AVAILABLE,
Expand Down
110 changes: 110 additions & 0 deletions pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
import warnings
from functools import partial, wraps
from typing import Any, Optional, Union
from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
)

import torch

from torch.nn.parallel.distributed import DistributedDataParallel

log = logging.getLogger(__name__)

if torch.distributed.is_available():
Expand Down Expand Up @@ -208,3 +214,107 @@ def all_gather_ddp_if_available(
with torch.no_grad():
return AllGatherGrad.apply(tensor, group)
return tensor


def register_ddp_comm_hook(
model: DistributedDataParallel,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
) -> None:
"""
Function to register communication hook for DDP model
https://pytorch.org/docs/master/ddp_comm_hooks.html
Args:
model:
DDP model
ddp_comm_state:
state is passed to the hook and can be used to maintain
and update any state information that users would like to
maintain as part of the training process. Examples: error
feedback in gradient compression, peers to communicate with
next in GossipGrad etc.
ddp_comm_hook:
hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
This callable function is called once the bucket is ready. The
hook can perform whatever processing is needed and return
a Future indicating completion of any async work (ex: allreduce).
If the hook doesn't perform any communication, it can also
just return a completed Future. The Future should hold the
new value of grad bucket's tensors. Once a bucket is ready,
c10d reducer would call this hook and use the tensors returned
by the Future and copy grads to individual parameters.
ddp_comm_wrapper:
communication hook wraper to support a communication hook such
as FP16 compression as wrapper, which could be combined with
ddp_comm_hook
.. warning ::
DDP communication hook needs pytorch version at least 1.8.0
.. warning ::
DDP communication wrapper needs pytorch version at least 1.9.0
Example:
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
powerSGD_hook as powerSGD,
)
# fp16_compress_hook for compress gradients
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_hook=default.fp16_compress_hook,
)
# powerSGD_hook
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_state=powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
start_powerSGD_iter=5000,
),
ddp_comm_hook=powerSGD.powerSGD_hook,
)
# fp16_compress_wrapper combined with other communication hook
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_state=powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
start_powerSGD_iter=5000,
),
ddp_comm_hook=powerSGD.powerSGD_hook,
ddp_comm_wrapper=default.fp16_compress_wrapper,
)
"""
if not _TORCH_GREATER_EQUAL_1_8:
rank_zero_warn(
"Not registering DDP comm hook. "
"To use communication hooks, please use pytorch>=1.8.0."
)
return
if ddp_comm_hook is None:
return
if ddp_comm_wrapper is not None:
if not _TORCH_GREATER_EQUAL_1_9:
rank_zero_warn(
"Not applying DDP comm wrapper. "
"To use communication wrapper, please use pytorch>=1.9.0."
)
else:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
)
ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
model.register_comm_hook(
state=ddp_comm_state,
hook=ddp_comm_hook,
)
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")

_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
Expand Down
1 change: 0 additions & 1 deletion tests/plugins/test_custom_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@


class CustomParallelPlugin(DDPPlugin):

def __init__(self, **kwargs):
super().__init__(**kwargs)
# Set to None so it will be overwritten by the accelerator connector.
Expand Down
Loading

0 comments on commit 313e816

Please sign in to comment.