-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Sagemaker DDP Plugin #6271
Add Sagemaker DDP Plugin #6271
Changes from all commits
ad77ff2
270e4df
33c9891
0380efb
c115ece
50c045d
eab4f58
f233b1a
2e9280c
c91e9f2
a13a675
88b2b4b
ffc85ea
833bb57
467e76b
c2a508c
f5a2cf4
f5675e8
0dff5c7
eacf9a8
d1bf909
da66a19
fdaeb5b
01b2d37
af070e3
b7e5548
1373f8f
9817608
937b50c
c7c16ba
281231e
6c2f229
d514a6f
8605f81
1c4a315
4ac0c4e
9512dc4
13dac0b
e38fb37
7797770
4a0f78d
574887f
19fa542
5699a32
7a715a0
e43fc1b
1ac88a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,70 @@ | ||||||
# Copyright The PyTorch Lightning team. | ||||||
# | ||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
# you may not use this file except in compliance with the License. | ||||||
# You may obtain a copy of the License at | ||||||
# | ||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||
# | ||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
import os | ||||||
|
||||||
from pytorch_lightning import _logger as log | ||||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment | ||||||
from pytorch_lightning.utilities import _SMDIST_AVAILABLE, rank_zero_warn | ||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||||||
|
||||||
if _SMDIST_AVAILABLE: | ||||||
import smdistributed.dataparallel.torch.distributed as dist | ||||||
|
||||||
|
||||||
class SMDistributedEnvironment(ClusterEnvironment):
    """Cluster environment for multi-node training on AWS SageMaker via ``smdistributed``.

    Reads the SageMaker-provided environment variables (``SM_CURRENT_HOST``,
    ``SM_HOSTS``, ``MASTER_PORT``) and delegates rank / world-size queries to
    ``smdistributed.dataparallel.torch.distributed``.

    Raises:
        MisconfigurationException: If the ``smdistributed`` module is not installed.
    """

    def __init__(self) -> None:
        if not _SMDIST_AVAILABLE:
            raise MisconfigurationException("`smdistributed` module is not available.")
        super().__init__()

    def creates_children(self) -> bool:
        # SageMaker itself launches one process per device; Lightning must not spawn.
        return False

    def master_address(self) -> str:
        master_address = os.environ["SM_CURRENT_HOST"]
        log.debug(f"MASTER_ADDR: {master_address}")
        return master_address

    def master_port(self) -> str:
        if "MASTER_PORT" not in os.environ:
            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        port = os.environ["MASTER_PORT"]
        return port

    def world_size(self) -> int:
        return dist.get_world_size()

    def set_world_size(self, size: int) -> None:
        log.debug("SMDistributedEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def local_rank(self) -> int:
        return dist.get_local_rank()

    def node_rank(self) -> int:
        """Return this host's position within the SageMaker host list.

        ``SM_HOSTS`` is a JSON-encoded list of host names, e.g.
        ``'["algo-1", "algo-2"]'``. The previous implementation called
        ``str.index`` on the raw string, which yields a *character* offset
        (e.g. ``2`` for the first host), not the host's index in the list —
        parse the JSON before looking up the position.
        """
        # Local import keeps this fix self-contained without altering the
        # module-level import block.
        import json

        hosts = json.loads(os.environ["SM_HOSTS"])
        current_host = os.environ["SM_CURRENT_HOST"]
        return hosts.index(current_host) if current_host in hosts else 0

    def global_rank(self) -> int:
        return dist.get_rank()

    def set_global_rank(self, rank: int) -> None:
        log.debug(
            "SMDistributedEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
        )
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
`smdistributed` is conditionally imported, so this line fails when the package is not present, i.e. in cases where you train without it.