Add Sagemaker DDP Plugin #6271

Status: Closed. Wanted to merge 47 commits.

Commits (47)
ad77ff2  feat: add smddp plugin & environment  (kaushikb11, Mar 1, 2021)
270e4df  update accelerator connector  (kaushikb11, Mar 1, 2021)
33c9891  update smddp plugin  (kaushikb11, Mar 1, 2021)
0380efb  update smddp plugin  (kaushikb11, Mar 1, 2021)
c115ece  update lightning distributed  (kaushikb11, Mar 1, 2021)
50c045d  fix typo  (kaushikb11, Mar 1, 2021)
eab4f58  update smddp plugin  (kaushikb11, Mar 1, 2021)
f233b1a  add comment for parallel devices  (kaushikb11, Mar 1, 2021)
2e9280c  debug  (kaushikb11, Mar 1, 2021)
c91e9f2  debug  (kaushikb11, Mar 1, 2021)
a13a675  ddp name consistency  (kaushikb11, Mar 8, 2021)
88b2b4b  fix global rank  (kaushikb11, Mar 8, 2021)
ffc85ea  update smdist environment  (kaushikb11, Mar 8, 2021)
833bb57  Update Type Annotation  (kaushikb11, Mar 8, 2021)
467e76b  DDP plugin as base class  (kaushikb11, Mar 8, 2021)
c2a508c  add test  (kaushikb11, Mar 8, 2021)
f5a2cf4  Merge branch 'master' into smddp  (kaushikb11, Mar 10, 2021)
f5675e8  add creates_children mthod  (kaushikb11, Mar 10, 2021)
0dff5c7  change backend to mpi  (kaushikb11, Mar 10, 2021)
eacf9a8  set broadcast buffers set to False  (kaushikb11, Mar 10, 2021)
d1bf909  mini refactor  (kaushikb11, Mar 10, 2021)
da66a19  address reviews  (kaushikb11, Mar 10, 2021)
fdaeb5b  address reviews  (kaushikb11, Mar 10, 2021)
01b2d37  add missing init  (kaushikb11, Mar 10, 2021)
af070e3  change backend  (kaushikb11, Mar 10, 2021)
b7e5548  change all reduce  (kaushikb11, Mar 11, 2021)
1373f8f  add type hints  (kaushikb11, Mar 11, 2021)
9817608  Add missing Import  (kaushikb11, Mar 11, 2021)
937b50c  broadcast fix  (kaushikb11, Mar 11, 2021)
c7c16ba  Change SMDDP to DDPSM  (kaushikb11, Mar 11, 2021)
281231e  Update num gpus  (kaushikb11, Mar 16, 2021)
6c2f229  fix  (kaushikb11, Mar 16, 2021)
d514a6f  Update accelerator_connector.py  (kaushikb11, Mar 16, 2021)
8605f81  Merge branch 'master' into smddp  (kaushikb11, Mar 31, 2021)
1c4a315  Update changelog  (kaushikb11, Mar 31, 2021)
4ac0c4e  Merge branch 'master' into smddp  (kaushikb11, Mar 31, 2021)
9512dc4  Merge branch 'master' into smddp  (kaushikb11, Jul 3, 2021)
13dac0b  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 3, 2021)
e38fb37  Fix circular import  (kaushikb11, Jul 8, 2021)
7797770  Update environment  (kaushikb11, Jul 8, 2021)
4a0f78d  Add set_world_ranks  (kaushikb11, Jul 8, 2021)
574887f  Add updates  (kaushikb11, Jul 13, 2021)
19fa542  Add updates  (kaushikb11, Jul 13, 2021)
5699a32  Fix broadcasting  (kaushikb11, Jul 13, 2021)
7a715a0  Fix broadcasting  (kaushikb11, Jul 13, 2021)
e43fc1b  Update group  (kaushikb11, Jul 13, 2021)
1ac88a8  Update logger  (kaushikb11, Jul 14, 2021)
2 changes: 1 addition & 1 deletion pytorch_lightning/distributed/__init__.py
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from pytorch_lightning.distributed.dist import LightningDistributed  # noqa: F401
+from pytorch_lightning.distributed.dist import LightningDistributed, SMLightningDistributed  # noqa: F401
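Note: the added export, like the dist.py changes below, is gated behind a `_SMDIST_AVAILABLE` flag imported from `pytorch_lightning.utilities`. The flag's definition is not part of the hunks shown in this diff; the following is only a minimal sketch of the kind of availability check such a flag typically wraps (the helper name here is made up):

import importlib.util


def _package_available(name: str) -> bool:
    # True if `name` can be imported, checked without actually importing it.
    return importlib.util.find_spec(name) is not None


# e.g. how a flag like _SMDIST_AVAILABLE could be defined
_SMDIST_AVAILABLE = _package_available("smdistributed")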
100 changes: 99 additions & 1 deletion pytorch_lightning/distributed/dist.py
@@ -13,9 +13,20 @@
# limitations under the License.
from typing import Any

from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
import torch

from pytorch_lightning.overrides.torch_distributed import (
    _object_to_tensor,
    _rank_not_in_group,
    _tensor_to_object,
    broadcast_object_list,
)
from pytorch_lightning.utilities import _SMDIST_AVAILABLE
from pytorch_lightning.utilities.distributed import group as _group

if _SMDIST_AVAILABLE:
    import smdistributed.dataparallel.torch.distributed as sm_dist


class LightningDistributed:

@@ -33,3 +44,90 @@ def broadcast(self, obj: Any, group=_group.WORLD):
        broadcast_object_list(obj, 0, group=group or _group.WORLD)

        return obj[0]


class SMLightningDistributed(LightningDistributed):

    def broadcast(self, obj: Any, group=sm_dist.group.WORLD):

[Review comment] `sm_dist` is conditionally imported, so this line fails when the package is not present, i.e. in cases where you train without it.

        # always wrap into a list so the list can be broadcasted
        obj = [obj]

        if self.rank != 0:
            obj = [None] * len(obj)

        _broadcast_object_list(obj, self.rank, 0, group=group)

        return obj[0]

    # def _broadcast(self, tensor: torch.Tensor, src: int, group: Optional[Any] = None):
    #     if group is None:
    #         return sm_dist.broadcast(tensor, src=src)
    #     return sm_dist.broadcast(tensor, src=0, group=group)

    # def _emit(self, obj: Any, group=_group.WORLD):
    #     buffer = io.BytesIO()
    #     torch.save(obj, buffer)
    #     data = bytearray(buffer.getbuffer())
    #     length_tensor = torch.tensor([len(data)]).long().to(self.device)
    #     self._broadcast(length_tensor, src=0, group=group)
    #     data_tensor = torch.ByteTensor(data).to(self.device)
    #     self._broadcast(data_tensor, src=0, group=group)

    # def _receive(self, group=_group.WORLD):
    #     length_tensor = torch.tensor([0]).long().to(self.device)
    #     self._broadcast(length_tensor, src=0, group=group)
    #     data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device)
    #     self._broadcast(data_tensor, src=0, group=group)
    #     buffer = io.BytesIO(data_tensor.cpu().numpy())
    #     obj = torch.load(buffer)
    #     return obj


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
def _broadcast_object_list(object_list, rank, src=0, group=None):
    if _rank_not_in_group(group):
        return

    my_rank = rank
    # Serialize object_list elements to tensors on src rank.
    if my_rank == src:
        tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
        object_sizes_tensor = torch.cat(size_list)
    else:
        object_sizes_tensor = torch.LongTensor(len(object_list))

    # group_backend = get_backend(group)
    # is_nccl_backend = group_backend == Backend.NCCL
    # current_device = torch.device("cpu")
    # if is_nccl_backend:
    #     # See note about using torch.cuda.current_device() here in docstring.
    #     # We cannot simply use my_rank since rank == device is not necessarily
    #     # true.
    #     current_device = torch.device('cuda', torch.cuda.current_device())
    #     object_sizes_tensor = object_sizes_tensor.to(current_device)

    # Broadcast object sizes
    sm_dist.broadcast(object_sizes_tensor, src=src, group=group)

    # Concatenate and broadcast serialized object tensors
    if my_rank == src:
        object_tensor = torch.cat(tensor_list)
    else:
        object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())

    # if is_nccl_backend:
    #     object_tensor = object_tensor.to(current_device)

    sm_dist.broadcast(object_tensor, src=src, group=group)

    # Deserialize objects using their stored sizes.
    offset = 0
    if my_rank != src:
        for i, obj_size in enumerate(object_sizes_tensor):
            obj_view = object_tensor[offset:offset + obj_size]
            obj_view = obj_view.type(torch.ByteTensor)  # type: ignore[call-overload]
            offset += obj_size
            object_list[i] = _tensor_to_object(obj_view, obj_size)
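Regarding the review comment above about `group=sm_dist.group.WORLD`: that default argument is evaluated when the class body runs, i.e. at import time, so importing this module without `smdistributed` installed would raise a NameError. Below is a minimal sketch of one import-safe alternative, defaulting the group to None and resolving it inside the method. It reuses the names defined in this module, the class name is made up, and it is only an illustration, not what the PR ships:

from typing import Any, Optional


class SMLightningDistributedSafe(LightningDistributed):

    def broadcast(self, obj: Any, group: Optional[Any] = None) -> Any:
        # Resolve the default group lazily so the module imports cleanly even
        # when `smdistributed` is not installed (calling broadcast still requires it).
        if group is None:
            group = sm_dist.group.WORLD

        obj = [obj]  # wrap into a list so it can be broadcasted
        if self.rank != 0:
            obj = [None] * len(obj)

        _broadcast_object_list(obj, self.rank, 0, group=group)
        return obj[0]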
8 changes: 3 additions & 5 deletions pytorch_lightning/loggers/tensorboard.py
@@ -266,16 +266,14 @@ def version(self) -> int:
         return self._version

     def _get_next_version(self):
-        root_dir = self.root_dir
+        root_dir = os.path.join(self.save_dir, self.name)

-        try:
-            listdir_info = self._fs.listdir(root_dir)
-        except OSError:
+        if not self._fs.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0

         existing_versions = []
-        for listing in listdir_info:
+        for listing in self._fs.listdir(root_dir):
             d = listing["name"]
             bn = os.path.basename(d)
             if self._fs.isdir(d) and bn.startswith("version_"):
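For context on the tensorboard.py hunk: `_get_next_version` scans the logger's root directory for `version_<n>` subfolders and returns the next index, and this change goes back to an explicit isdir check instead of catching OSError from listdir. A small standalone sketch of that scanning behaviour using plain `os` (the real implementation goes through the logger's fsspec filesystem and its `save_dir`/`name` attributes):

import os


def get_next_version(root_dir: str) -> int:
    # Next `version_<n>` index under root_dir; 0 if the folder or any versions don't exist yet.
    if not os.path.isdir(root_dir):
        return 0

    existing_versions = []
    for name in os.listdir(root_dir):
        if os.path.isdir(os.path.join(root_dir, name)) and name.startswith("version_"):
            existing_versions.append(int(name.split("_")[1]))

    return max(existing_versions) + 1 if existing_versions else 0


# e.g. with lightning_logs/version_0 and lightning_logs/version_1 on disk:
# get_next_version("lightning_logs") -> 2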
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -16,6 +16,7 @@
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_sm import DDPSMPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
@@ -56,6 +57,7 @@
    "Plugin",
    "DDPShardedPlugin",
    "DDPSpawnShardedPlugin",
    "DDPSMPlugin",
]

from pathlib import Path
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/environments/__init__.py
@@ -15,4 +15,5 @@
from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
70 changes: 70 additions & 0 deletions pytorch_lightning/plugins/environments/smdist_environment.py
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import _SMDIST_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _SMDIST_AVAILABLE:
    import smdistributed.dataparallel.torch.distributed as dist


class SMDistributedEnvironment(ClusterEnvironment):

    def __init__(self):
        if not _SMDIST_AVAILABLE:
            raise MisconfigurationException("`smdistributed` module is not available.")

[Review comment, Contributor] Add a small description on how to make this work.

[Reply, Contributor Author] Done!

[Review comment, Member] Suggested change:
-            raise MisconfigurationException("`smdistributed` module is not available.")
+            raise MisconfigurationException("`smdistributed` package is not available.")
Also add how to install it.

        super().__init__()

    def creates_children(self) -> bool:
        return False

    def master_address(self) -> str:
        master_address = os.environ["SM_CURRENT_HOST"]
        log.debug(f"MASTER_ADDR: {master_address}")
        return master_address

    def master_port(self) -> str:
        if "MASTER_PORT" not in os.environ:
            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        port = os.environ["MASTER_PORT"]
        return port

    def world_size(self) -> int:
        return dist.get_world_size()

    def set_world_size(self, size: int) -> None:
        log.debug("SMDistributedEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def local_rank(self) -> int:
        return dist.get_local_rank()

    def node_rank(self) -> int:
        hosts = os.environ["SM_HOSTS"]
        current_host = os.environ["SM_CURRENT_HOST"]
        return hosts.index(current_host) if current_host in hosts else 0

    def global_rank(self) -> int:
        return dist.get_rank()

    def set_global_rank(self, rank: int) -> None:
        log.debug(
            "SMDistributedEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
        )
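Putting the pieces together, here is a rough usage sketch for a SageMaker data-parallel job. The exact Trainer wiring (constructor arguments, whether a shorthand accelerator string is added by this PR) is an assumption, not something this diff shows:

# Hypothetical usage sketch: assumes an 8-GPU SageMaker training container with
# `smdistributed` available, and that the plugin accepts a cluster_environment
# argument like the other DDP-style plugins. MyLightningModule is made up.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSMPlugin
from pytorch_lightning.plugins.environments import SMDistributedEnvironment

model = MyLightningModule()  # a LightningModule defined elsewhere
trainer = Trainer(
    gpus=8,
    plugins=[DDPSMPlugin(cluster_environment=SMDistributedEnvironment())],
)
trainer.fit(model)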
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
@@ -1,5 +1,6 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_sm import DDPSMPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401