
Add Sagemaker DDP Plugin #6271

Closed. Wants to merge 47 commits.

Changes from 39 commits

Commits (47)
ad77ff2  feat: add smddp plugin & environment  (kaushikb11, Mar 1, 2021)
270e4df  update accelerator connector  (kaushikb11, Mar 1, 2021)
33c9891  update smddp plugin  (kaushikb11, Mar 1, 2021)
0380efb  update smddp plugin  (kaushikb11, Mar 1, 2021)
c115ece  update lightning distributed  (kaushikb11, Mar 1, 2021)
50c045d  fix typo  (kaushikb11, Mar 1, 2021)
eab4f58  update smddp plugin  (kaushikb11, Mar 1, 2021)
f233b1a  add comment for parallel devices  (kaushikb11, Mar 1, 2021)
2e9280c  debug  (kaushikb11, Mar 1, 2021)
c91e9f2  debug  (kaushikb11, Mar 1, 2021)
a13a675  ddp name consistency  (kaushikb11, Mar 8, 2021)
88b2b4b  fix global rank  (kaushikb11, Mar 8, 2021)
ffc85ea  update smdist environment  (kaushikb11, Mar 8, 2021)
833bb57  Update Type Annotation  (kaushikb11, Mar 8, 2021)
467e76b  DDP plugin as base class  (kaushikb11, Mar 8, 2021)
c2a508c  add test  (kaushikb11, Mar 8, 2021)
f5a2cf4  Merge branch 'master' into smddp  (kaushikb11, Mar 10, 2021)
f5675e8  add creates_children mthod  (kaushikb11, Mar 10, 2021)
0dff5c7  change backend to mpi  (kaushikb11, Mar 10, 2021)
eacf9a8  set broadcast buffers set to False  (kaushikb11, Mar 10, 2021)
d1bf909  mini refactor  (kaushikb11, Mar 10, 2021)
da66a19  address reviews  (kaushikb11, Mar 10, 2021)
fdaeb5b  address reviews  (kaushikb11, Mar 10, 2021)
01b2d37  add missing init  (kaushikb11, Mar 10, 2021)
af070e3  change backend  (kaushikb11, Mar 10, 2021)
b7e5548  change all reduce  (kaushikb11, Mar 11, 2021)
1373f8f  add type hints  (kaushikb11, Mar 11, 2021)
9817608  Add missing Import  (kaushikb11, Mar 11, 2021)
937b50c  broadcast fix  (kaushikb11, Mar 11, 2021)
c7c16ba  Change SMDDP to DDPSM  (kaushikb11, Mar 11, 2021)
281231e  Update num gpus  (kaushikb11, Mar 16, 2021)
6c2f229  fix  (kaushikb11, Mar 16, 2021)
d514a6f  Update accelerator_connector.py  (kaushikb11, Mar 16, 2021)
8605f81  Merge branch 'master' into smddp  (kaushikb11, Mar 31, 2021)
1c4a315  Update changelog  (kaushikb11, Mar 31, 2021)
4ac0c4e  Merge branch 'master' into smddp  (kaushikb11, Mar 31, 2021)
9512dc4  Merge branch 'master' into smddp  (kaushikb11, Jul 3, 2021)
13dac0b  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 3, 2021)
e38fb37  Fix circular import  (kaushikb11, Jul 8, 2021)
7797770  Update environment  (kaushikb11, Jul 8, 2021)
4a0f78d  Add set_world_ranks  (kaushikb11, Jul 8, 2021)
574887f  Add updates  (kaushikb11, Jul 13, 2021)
19fa542  Add updates  (kaushikb11, Jul 13, 2021)
5699a32  Fix broadcasting  (kaushikb11, Jul 13, 2021)
7a715a0  Fix broadcasting  (kaushikb11, Jul 13, 2021)
e43fc1b  Update group  (kaushikb11, Jul 13, 2021)
1ac88a8  Update logger  (kaushikb11, Jul 14, 2021)
2 changes: 1 addition & 1 deletion pytorch_lightning/distributed/__init__.py
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.distributed.dist import LightningDistributed # noqa: F401
from pytorch_lightning.distributed.dist import LightningDistributed, SMLightningDistributed # noqa: F401
42 changes: 41 additions & 1 deletion pytorch_lightning/distributed/dist.py
@@ -11,11 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import io
from typing import Any, Optional

import torch

from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.utilities import _SMDIST_AVAILABLE
from pytorch_lightning.utilities.distributed import group as _group

if _SMDIST_AVAILABLE:
    import smdistributed.dataparallel.torch.distributed as sm_dist


class LightningDistributed:

@@ -33,3 +40,36 @@ def broadcast(self, obj: Any, group=_group.WORLD):
        broadcast_object_list(obj, 0, group=group or _group.WORLD)

        return obj[0]


class SMLightningDistributed(LightningDistributed):

    def broadcast(self, obj: Any, group=_group.WORLD):
        if self.rank == 0:
            self._emit(obj, group)
        else:
            obj = self._receive(group)
        return obj

    def _broadcast(self, tensor: torch.Tensor, src: int, group: Optional[Any] = None):
        if group is None:
            return sm_dist.broadcast(tensor, src=src)
        return sm_dist.broadcast(tensor, src=0, group=group)

    def _emit(self, obj: Any, group=_group.WORLD):
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.tensor([len(data)]).long().to(self.device)
        self._broadcast(length_tensor, src=0, group=group)
        data_tensor = torch.ByteTensor(data).to(self.device)
        self._broadcast(data_tensor, src=0, group=group)

    def _receive(self, group=_group.WORLD):
        length_tensor = torch.tensor([0]).long().to(self.device)
        self._broadcast(length_tensor, src=0, group=group)
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device)
        self._broadcast(data_tensor, src=0, group=group)
        buffer = io.BytesIO(data_tensor.cpu().numpy())
        obj = torch.load(buffer)
        return obj
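
For readers skimming the diff: `_emit` and `_receive` implement a length-prefixed broadcast. Rank 0 serializes the object with `torch.save`, broadcasts the byte length so every other rank can allocate a matching buffer, and then broadcasts the payload itself. Below is a minimal standalone sketch of the same protocol written against plain `torch.distributed` instead of `smdistributed` (an illustration, not part of the PR; it assumes a process group is already initialized and that `device` matches what the backend expects, e.g. a CUDA device for NCCL).

import io

import torch
import torch.distributed as dist


def broadcast_object(obj, rank: int, device: torch.device):
    """Broadcast an arbitrary picklable object from rank 0 to all ranks."""
    if rank == 0:
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        # 1) broadcast the payload length so receivers can size their buffer
        length = torch.tensor([len(data)], dtype=torch.long, device=device)
        dist.broadcast(length, src=0)
        # 2) broadcast the serialized payload
        payload = torch.ByteTensor(data).to(device)
        dist.broadcast(payload, src=0)
        return obj
    # receiving ranks: length first, then payload, then deserialize
    length = torch.tensor([0], dtype=torch.long, device=device)
    dist.broadcast(length, src=0)
    payload = torch.empty(int(length.item()), dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=0)
    return torch.load(io.BytesIO(payload.cpu().numpy()))

Broadcasting the length first is what lets every rank allocate its receive buffer before the payload arrives; `SMLightningDistributed` follows the same two-step sequence through `sm_dist.broadcast`.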
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -16,6 +16,7 @@
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_sm import DDPSMPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
@@ -56,6 +57,7 @@
    "Plugin",
    "DDPShardedPlugin",
    "DDPSpawnShardedPlugin",
    "DDPSMPlugin",
]

from pathlib import Path
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/environments/__init__.py
@@ -15,4 +15,5 @@
from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
62 changes: 62 additions & 0 deletions pytorch_lightning/plugins/environments/smdist_environment.py
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import _SMDIST_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _SMDIST_AVAILABLE:
    import smdistributed.dataparallel.torch.distributed as dist


class SMDistributedEnvironment(ClusterEnvironment):

    def __init__(self):
        if not _SMDIST_AVAILABLE:
            raise MisconfigurationException("`smdistributed` module is not available.")
        super().__init__()

Review thread on the `MisconfigurationException` above:
    Contributor: Add a small description on how to make this work.
    Contributor Author: Done!
    Member: Suggested change
        - raise MisconfigurationException("`smdistributed` module is not available.")
        + raise MisconfigurationException("`smdistributed` package is not available.")
        Also add how to install it.

    def creates_children(self) -> bool:
        return False

    def master_address(self) -> str:
        master_address = os.environ["SM_CURRENT_HOST"]
        log.debug(f"MASTER_ADDR: {master_address}")
        return master_address

    def master_port(self) -> str:
        if "MASTER_PORT" not in os.environ:
            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        port = os.environ["MASTER_PORT"]
        return port

    def world_size(self) -> int:
        return dist.get_world_size()

    def local_rank(self) -> int:
        return dist.get_local_rank()

    def node_rank(self) -> int:
        hosts = os.environ["SM_HOSTS"]
        current_host = os.environ["SM_CURRENT_HOST"]
        return hosts.index(current_host) if current_host in hosts else 0

    def global_rank(self) -> int:
        return dist.get_rank()
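
An aside on the SageMaker environment variables used above (not part of the diff): inside a SageMaker training container, SM_HOSTS is typically a JSON-encoded list of host names such as ["algo-1", "algo-2"] and SM_CURRENT_HOST is one of those entries, so deriving the node rank usually means parsing that JSON first. Note that `node_rank()` in the diff calls `.index()` on the raw environment string, which returns a character offset rather than a list position; the sketch below shows the parsed variant, under the assumptions just stated.

import json
import os


def sagemaker_node_rank() -> int:
    # SM_HOSTS is assumed to be a JSON list, e.g. '["algo-1", "algo-2"]'
    hosts = json.loads(os.environ.get("SM_HOSTS", '["algo-1"]'))
    current_host = os.environ.get("SM_CURRENT_HOST", "algo-1")
    return hosts.index(current_host) if current_host in hosts else 0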
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
@@ -1,5 +1,6 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_sm import DDPSMPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
223 changes: 223 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_sm.py
@@ -0,0 +1,223 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional, Union

import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.distributed import SMLightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _SMDIST_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _SMDIST_AVAILABLE:
    import smdistributed.dataparallel.torch.distributed as dist
    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel


class DDPSMPlugin(DDPPlugin):

    distributed_backend = "ddp_sm"

    def __init__(
        self,
        cluster_environment: Optional[SMDistributedEnvironment] = None,
        sync_batchnorm: bool = False,
        **kwargs: Union[Any, Dict[str, Any]],
    ) -> None:
        if not _SMDIST_AVAILABLE:
            raise MisconfigurationException(
                "`smdistributed` module is not available."
                " You would need to enable distributed=smdistributed"
                " in the Sagemaker Estimator Object."
            )

Review comment (Member), on the error message above: do you mean adding it to DDPSMPlugin, or where?
Review comment (Member), on lines +47 to +49: Suggested change
    - "`smdistributed` module is not available."
    - " You would need to enable distributed=smdistributed"
    - " in the Sagemaker Estimator Object."
    + "`smdistributed` package is not available."
    + " You would need to enable `distributed=smdistributed` in the Sagemaker Estimator Object."

        # While running smdistributed, all the gpus in the instance are considered
        parallel_device_ids = list(range(torch.cuda.device_count()))
        self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids]
        num_nodes = len(os.environ['SM_HOSTS'].split(","))

        super().__init__(
            parallel_devices=self.parallel_devices,
            num_nodes=num_nodes,
            cluster_environment=cluster_environment,
            sync_batchnorm=sync_batchnorm
        )

        self._ddp_kwargs = kwargs
        self.dist = SMLightningDistributed()

    def setup(self, model):
        self._model = model

        self.node_rank = self.cluster_environment.node_rank()
        self.local_rank = self.cluster_environment.local_rank()
        self.global_rank = self.cluster_environment.global_rank()
        self.world_size = self.cluster_environment.world_size()

        rank_zero_only.rank = self.global_rank
        self.model_to_device()

    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=[dist.get_local_rank()],
            broadcast_buffers=False,
        )

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:

        if not dist.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

    def pre_dispatch(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        self.init_ddp_connection(self.global_rank, self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_dispatch
        # ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)

        # on world_size=0 let everyone know training is starting
        if self.is_global_zero and not dist.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)
Review comment (Member), on the logging lines above: Suggested change
    - log.info("-" * 100)
    - log.info(f"distributed_backend={self.distributed_backend}")
    - log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
    - log.info("-" * 100)
    + log.info(
    +     "-" * 100 + '\n',
    +     f"distributed_backend={self.distributed_backend}" + '\n',
    +     f"All DDP processes registered. Starting ddp with {self.world_size} processes" + '\n',
    +     "-" * 100,
    + )


        # # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # move the model to the correct device
        self.model_to_device()

        self.configure_ddp()

        self.barrier("configure ddp")

    def sync_ddp_if_available(
        self,
        result: Union[torch.Tensor],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = None
    ) -> torch.Tensor:
        """
        Function to reduce a tensor across worker processes during distributed training
        Args:
            result: the value to sync and reduce (typically tensor or number)
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to sum.
                Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

        Return:
            reduced value
        """
        if dist.is_available() and dist.is_initialized():
            return self.sync_ddp(result, group=group, reduce_op=reduce_op)
        return result

    def sync_ddp(
        self,
        result: Union[torch.Tensor],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = None
    ) -> torch.Tensor:
        """
        Function to reduce the tensors from several ddp processes to one master process

        Args:
            result: the value to sync and reduce (typically tensor or number)
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to sum.
                Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

        Return:
            reduced value
        """
        divide_by_world_size = False

        if group is None:
            group = dist.group.WORLD

        op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            divide_by_world_size = True

        # sync all processes before reduction
        dist.barrier(group=group)
        dist.all_reduce(result, op=op, group=group, async_op=False)

        if divide_by_world_size:
            result = result / dist.get_world_size(group)

        return result

    @property
    def lightning_module(self):
        return self.unwrap_lightning_module()

    def unwrap_lightning_module(self):
        model = self._model
        if isinstance(model, (DistributedDataParallel)):
            model = model.module
        if isinstance(model, _LightningModuleWrapperBase):
            model = model.module
        return model

    def barrier(self, *args, **kwargs) -> None:
        if dist.is_initialized():
            dist.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to reduce across. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to "mean"

        Return:
            the reduced tensor
        """
        if isinstance(tensor, torch.Tensor):
            tensor = self.sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
        return tensor
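
To close the loop on how this plugin would be used, here is a hypothetical usage sketch, not taken from the PR. `MyLightningModule` is a placeholder for any LightningModule, and the `gpus`/`max_epochs` values are arbitrary.

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPSMPlugin

model = MyLightningModule()  # placeholder: any LightningModule defined in the training script

trainer = pl.Trainer(
    gpus=-1,                  # the plugin claims every GPU in the instance anyway
    plugins=[DDPSMPlugin()],  # or select it through the "ddp_sm" distributed backend
    max_epochs=3,
)
trainer.fit(model)

On the launcher side, the SageMaker PyTorch Estimator is usually configured with the data-parallel distribution switch (e.g. distribution={"smdistributed": {"dataparallel": {"enabled": True}}}) so that `smdistributed` and the SM_* environment variables are available inside the container; the error message's `distributed=smdistributed` wording appears to refer to that Estimator-level setting.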