Add Sagemaker DDP Plugin #6271

Closed · wants to merge 47 commits
Changes from 10 commits
Commits (47)
ad77ff2
feat: add smddp plugin & environment
kaushikb11 Mar 1, 2021
270e4df
update accelerator connector
kaushikb11 Mar 1, 2021
33c9891
update smddp plugin
kaushikb11 Mar 1, 2021
0380efb
update smddp plugin
kaushikb11 Mar 1, 2021
c115ece
update lightning distributed
kaushikb11 Mar 1, 2021
50c045d
fix typo
kaushikb11 Mar 1, 2021
eab4f58
update smddp plugin
kaushikb11 Mar 1, 2021
f233b1a
add comment for parallel devices
kaushikb11 Mar 1, 2021
2e9280c
debug
kaushikb11 Mar 1, 2021
c91e9f2
debug
kaushikb11 Mar 1, 2021
a13a675
ddp name consistency
kaushikb11 Mar 8, 2021
88b2b4b
fix global rank
kaushikb11 Mar 8, 2021
ffc85ea
update smdist environment
kaushikb11 Mar 8, 2021
833bb57
Update Type Annotation
kaushikb11 Mar 8, 2021
467e76b
DDP plugin as base class
kaushikb11 Mar 8, 2021
c2a508c
add test
kaushikb11 Mar 8, 2021
f5a2cf4
Merge branch 'master' into smddp
kaushikb11 Mar 10, 2021
f5675e8
add creates_children mthod
kaushikb11 Mar 10, 2021
0dff5c7
change backend to mpi
kaushikb11 Mar 10, 2021
eacf9a8
set broadcast buffers set to False
kaushikb11 Mar 10, 2021
d1bf909
mini refactor
kaushikb11 Mar 10, 2021
da66a19
address reviews
kaushikb11 Mar 10, 2021
fdaeb5b
address reviews
kaushikb11 Mar 10, 2021
01b2d37
add missing init
kaushikb11 Mar 10, 2021
af070e3
change backend
kaushikb11 Mar 10, 2021
b7e5548
change all reduce
kaushikb11 Mar 11, 2021
1373f8f
add type hints
kaushikb11 Mar 11, 2021
9817608
Add missing Import
kaushikb11 Mar 11, 2021
937b50c
broadcast fix
kaushikb11 Mar 11, 2021
c7c16ba
Change SMDDP to DDPSM
kaushikb11 Mar 11, 2021
281231e
Update num gpus
kaushikb11 Mar 16, 2021
6c2f229
fix
kaushikb11 Mar 16, 2021
d514a6f
Update accelerator_connector.py
kaushikb11 Mar 16, 2021
8605f81
Merge branch 'master' into smddp
kaushikb11 Mar 31, 2021
1c4a315
Update changelog
kaushikb11 Mar 31, 2021
4ac0c4e
Merge branch 'master' into smddp
kaushikb11 Mar 31, 2021
9512dc4
Merge branch 'master' into smddp
kaushikb11 Jul 3, 2021
13dac0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2021
e38fb37
Fix circular import
kaushikb11 Jul 8, 2021
7797770
Update environment
kaushikb11 Jul 8, 2021
4a0f78d
Add set_world_ranks
kaushikb11 Jul 8, 2021
574887f
Add updates
kaushikb11 Jul 13, 2021
19fa542
Add updates
kaushikb11 Jul 13, 2021
5699a32
Fix broadcasting
kaushikb11 Jul 13, 2021
7a715a0
Fix broadcasting
kaushikb11 Jul 13, 2021
e43fc1b
Update group
kaushikb11 Jul 13, 2021
1ac88a8
Update logger
kaushikb11 Jul 14, 2021
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -18,6 +18,7 @@
from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.smddp import SMDDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401

@@ -44,4 +45,5 @@
'Plugin',
'DDPShardedPlugin',
'DDPSpawnShardedPlugin',
'SMDDPPlugin',
]
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/environments/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
54 changes: 54 additions & 0 deletions pytorch_lightning/plugins/environments/smdist_environment.py
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import _SMDIST_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _SMDIST_AVAILABLE:
import smdistributed.dataparallel.torch.distributed as dist


class SMDistributedEnvironment(ClusterEnvironment):

def __init__(self):
if not _SMDIST_AVAILABLE:
raise MisconfigurationException("`smdistributed` module is not available.")
Contributor: Add a small description on how to make this work.

Contributor Author: Done!

Member: Suggested change
- raise MisconfigurationException("`smdistributed` module is not available.")
+ raise MisconfigurationException("`smdistributed` package is not available.")

also add how to install it

(See the usage sketch after this file, below.)

super().__init__()

def master_address(self) -> str:
master_address = os.environ['SM_CURRENT_HOST']
log.debug(f"MASTER_ADDR: {master_address}")
return master_address

def master_port(self) -> str:
if "MASTER_PORT" not in os.environ:
rank_zero_warn("MASTER_PORT environment variable is not defined. Defaulting to 12910.")
os.environ["MASTER_PORT"] = "12910"
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

port = os.environ.get('MASTER_PORT')
return port

def world_size(self) -> int:
return dist.get_world_size()

def local_rank(self) -> int:
return dist.get_local_rank()

def node_rank(self) -> int:
return dist.get_rank()
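Responding to the review request above for a short description of how to make this work: a minimal, hypothetical usage sketch, not part of this diff. It assumes the `smddp` accelerator string wired up by this PR is accepted by the Trainer, and that `smdistributed` is provided by the AWS SageMaker Deep Learning Containers (it is not on PyPI), so this only runs inside a SageMaker training job; `ToyModel` is just a stand-in module.

# Hypothetical usage sketch (not part of this diff): select the SageMaker DDP
# plugin via the Trainer inside a SageMaker training job. The "smddp" string and
# the GPU behaviour below are assumptions based on the connector/plugin code in this PR.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    trainer = pl.Trainer(
        accelerator="smddp",  # assumed to resolve to DistributedType.SMDDP -> SMDDPPlugin (see connector diff below)
        gpus=-1,              # smdistributed always uses every GPU on the instance
        max_epochs=1,
    )
    trainer.fit(ToyModel(), DataLoader(dataset, batch_size=8))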
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.smddp import SMDDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/parallel.py
@@ -120,7 +120,7 @@ def block_backward_sync(self):
else:
yield None

def broadcast(self, obj: object, src: int) -> object:
def broadcast(self, obj: object, src: int = 0) -> object:
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
260 changes: 260 additions & 0 deletions pytorch_lightning/plugins/training_type/smddp.py
@@ -0,0 +1,260 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
import os
from typing import Any, Dict, List, Optional, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _SMDIST_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _SMDIST_AVAILABLE:
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel


class SMDDPPlugin(ParallelPlugin):

distributed_backend = "smddp"

def __init__(
self,
cluster_environment: Optional[ClusterEnvironment] = None,
sync_batchnorm: bool = False,
**kwargs: Union[Any, Dict[str, Any]],
):
if not _SMDIST_AVAILABLE:
raise MisconfigurationException("`smdistributed` module is not available.")

# While running smdistributed, all the gpus in the instance are considered
parallel_device_ids = list(range(torch.cuda.device_count()))
self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids]

super().__init__(parallel_devices=self.parallel_devices, cluster_environment=cluster_environment)

self.sync_batchnorm = sync_batchnorm
self.dist = SMLightningDistributed()
# SM_HOSTS is a JSON-encoded list of all host names in the training job
self.num_nodes = len(json.loads(os.environ['SM_HOSTS']))
Member: Is this something we maybe should extend the Environment class by?

cc @awaelchli
self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else self.parallel_devices
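Regarding the review question above about extending the Environment class: a hypothetical sketch, not part of this diff, of exposing the node count on SMDistributedEnvironment so the plugin does not have to parse SM_HOSTS itself. The `num_nodes` method name is illustrative.

# Hypothetical sketch of the reviewer's suggestion (not part of this diff):
# move the SM_HOSTS parsing into the cluster environment.
import json
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment


class SMDistributedEnvironment(ClusterEnvironment):
    ...  # master_address / master_port / world_size / local_rank / node_rank as above

    def num_nodes(self) -> int:
        # SM_HOSTS is a JSON-encoded list of all host names in the training job
        return len(json.loads(os.environ["SM_HOSTS"]))


# ...and in SMDDPPlugin.__init__:
#     self.num_nodes = self.cluster_environment.num_nodes()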

@property
def root_device(self):
return self.parallel_devices[self.local_rank]

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def predict(self, *args, **kwargs):
return self.model(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
self.model.require_backward_grad_sync = True

def barrier(self, *args, **kwargs) -> None:
if dist.is_initialized():
dist.barrier()

def broadcast(self, obj: object, src: int = 0) -> object:
return self.dist.broadcast(obj)

def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
prepare_for_backward(self.model, closure_loss)

def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
"""
Reduces a tensor from several distributed processes to one aggregated tensor.
As this plugin only operates with a single device, the reduction is simply the identity.

Args:
tensor: the tensor to sync and reduce
*args: ignored
**kwargs: ignored

Return:
the unmodified input as reduction is not needed for single process operation
"""
if isinstance(tensor, torch.Tensor):
tensor = self.sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
return tensor

@property
def lightning_module(self):
return self.unwrap_lightning_module()

def setup(self, model):
self._model = model

self.node_rank = self.cluster_environment.node_rank()
self.local_rank = self.cluster_environment.local_rank()
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.cluster_environment.world_size()

rank_zero_only.rank = self.global_rank
self.model_to_device()

def pre_dispatch(self):
Contributor: What's different in this function from the ddp plugin? Can we remove this and inherit it from the DDP Plugin?

Contributor Author: The only difference is calling self.set_world_ranks(), as I am able to get it through the Environment methods using the smdistributed module, e.g. dist.get_world_size()

(See the refactor sketch after this method, below.)

# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# set warning rank
rank_zero_only.rank = self.global_rank

# set up the connection using proc 0's ip address
self.init_smddp_connection(self.global_rank, self.world_size)

# TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

# on global rank 0, let everyone know training is starting
if self.is_global_zero and not dist.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

# move the model to the correct device
self.model_to_device()

self.configure_smddp()

self.barrier()
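Following the review thread above (the only difference from the DDP plugin being the rank setup): a hypothetical refactor sketch, not part of this diff, that inherits from DDPPlugin and overrides only set_world_ranks, reading the ranks from smdistributed itself.

# Hypothetical refactor sketch (not part of this diff): reuse DDPPlugin.pre_dispatch
# and override only the rank bookkeeping, using smdistributed's own queries.
import smdistributed.dataparallel.torch.distributed as dist

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin


class SMDDPPlugin(DDPPlugin):
    distributed_backend = "smddp"

    def set_world_ranks(self) -> None:
        # ranks come straight from the smdistributed runtime rather than env vars
        self.local_rank = dist.get_local_rank()
        self.global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()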

def model_to_device(self):
if self.on_gpu:
torch.cuda.set_device(self.root_device)
self.model.to(self.root_device)

def init_smddp_connection(self, global_rank: int, world_size: int) -> None:

if not dist.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

def configure_smddp(self):
self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
device_ids=[dist.get_local_rank()],
)

def sync_ddp_if_available(
self,
result: torch.Tensor,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce a tensor across worker processes during distributed training
Args:
result: the value to sync and reduce (typically tensor or number)
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum.
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

Return:
reduced value
"""
if dist.is_available() and dist.is_initialized():
return self.sync_ddp(result, group=group, reduce_op=reduce_op)
return result

def sync_ddp(
self,
result: torch.Tensor,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process

Args:
result: the value to sync and reduce (typically tensor or number)
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum.
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

Return:
reduced value
"""
divide_by_world_size = False

if group is None:
group = dist.group.WORLD

op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM

if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
divide_by_world_size = True

# sync all processes before reduction
dist.barrier(group=group)
dist.all_reduce(result, op=op, group=group, async_op=False)

if divide_by_world_size:
result = result / dist.get_world_size(group)

return result

def unwrap_lightning_module(self) -> LightningModule:
model = self._model
if isinstance(model, (DistributedDataParallel)):
model = model.module
if isinstance(model, _LightningModuleWrapperBase):
model = model.module
return model


class SMLightningDistributed(LightningDistributed):

def _broadcast(self, tensor, src, group):
if group is None:
return dist.broadcast(tensor, src=src)
return dist.broadcast(tensor, src=src, group=group)
16 changes: 15 additions & 1 deletion pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -38,11 +38,17 @@
ShardedNativeMixedPrecisionPlugin,
SingleDevicePlugin,
SingleTPUPlugin,
SMDDPPlugin,
TPUHalfPrecisionPlugin,
TPUSpawnPlugin,
TrainingTypePlugin,
)
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.plugins.environments import (
ClusterEnvironment,
SLURMEnvironment,
SMDistributedEnvironment,
TorchElasticEnvironment,
)
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
@@ -299,6 +305,10 @@ def is_using_torchelastic(self) -> bool:
te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ)
return te_flags_passed

@property
def use_smdistributed(self) -> bool:
return self.distributed_backend == DistributedType.SMDDP

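The DistributedType.SMDDP member referenced above is not part of the hunks shown in these 10 commits; presumably the enum in pytorch_lightning/utilities/enums.py gains a member whose value matches the plugin's distributed_backend string, roughly:

# Presumed shape of the enum addition (not shown in the hunks above).
from pytorch_lightning.utilities.enums import LightningEnum


class DistributedType(LightningEnum):
    DP = "dp"
    DDP = "ddp"
    # ... other existing members ...
    SMDDP = "smddp"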
def select_precision_plugin(self) -> PrecisionPlugin:
# set precision type
self.amp_type = AMPType.from_str(self.amp_type)
@@ -396,6 +406,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
elif self.use_horovod:
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
elif self.use_smdistributed:
plugin = SMDDPPlugin(cluster_environment=self.cluster_environment, sync_batchnorm=self.sync_batchnorm)
elif self.on_tpu:
if isinstance(self.tpu_cores, list):
plugin = SingleTPUPlugin(self.tpu_id)
@@ -457,6 +469,8 @@ def select_cluster_environment(self) -> ClusterEnvironment:
# TODO: decouple DDP from TE
# refactor and let generic cluster env hold the information about who spawns the processes
os.environ["PL_IN_DDP_SUBPROCESS"] = "1"
elif self.use_smdistributed:
env = SMDistributedEnvironment()
else:
# TODO: maybe introduce a DefaultEnvironment?
env = TorchElasticEnvironment()