Feature/sg 456 centralize ddp setup #544

Merged (38 commits, Jan 5, 2023)

Commits
9710cf4
first version
Louis-Dupont Dec 5, 2022
cebe631
breaking the code
Louis-Dupont Dec 5, 2022
158c0eb
fix env_helper
Louis-Dupont Dec 5, 2022
92ff519
improve doc
Louis-Dupont Dec 5, 2022
7fc6934
fix tests
Louis-Dupont Dec 5, 2022
b0d974e
Merge branch 'master' into feature/SG-456-centralise_ddp_setup
Louis-Dupont Dec 5, 2022
40554ed
remove multigpu.OFF on tests, and add better exception when using dev…
Louis-Dupont Dec 6, 2022
1d07cd9
improve error raising
Louis-Dupont Dec 6, 2022
132fae0
fix tests
Louis-Dupont Dec 6, 2022
e226196
wip
Louis-Dupont Dec 6, 2022
5107905
refacto of environment names wip
Louis-Dupont Dec 7, 2022
4b3ab3c
Merge branch 'master' into feature/SG-456-centralise_ddp_setup
Louis-Dupont Dec 12, 2022
8b37912
reorganise post merge
Louis-Dupont Dec 12, 2022
736133f
Merge branch 'master' into feature/SG-456-centralise_ddp_setup
Louis-Dupont Dec 12, 2022
14e6e80
fix
Louis-Dupont Dec 12, 2022
bdb6abc
wip
Louis-Dupont Dec 12, 2022
11d9ecd
wip
Louis-Dupont Dec 12, 2022
cc9676e
Merge branch 'master' into feature/SG-456-centralise_ddp_setup
Louis-Dupont Dec 13, 2022
c2e85de
add option to set device
Louis-Dupont Dec 13, 2022
f5369b0
done
Louis-Dupont Dec 13, 2022
27cad03
Merge branch 'hotgix/SG-000-fix_multigpu_OFF' into feature/SG-456-cen…
Louis-Dupont Dec 13, 2022
dc2d565
wip
Louis-Dupont Dec 13, 2022
b847222
wip
Louis-Dupont Dec 13, 2022
a200b71
support torch.distributed.launch
Louis-Dupont Dec 13, 2022
4598dfd
Merge branch 'master' into feature/SG-456-centralise_ddp_setup
Louis-Dupont Dec 13, 2022
2f06889
reorganise pop_local_rank
Louis-Dupont Dec 13, 2022
e170b1a
remove unused logger
Louis-Dupont Dec 13, 2022
b19b77a
undo useless change
Louis-Dupont Dec 13, 2022
907b174
run on CPU if no CUDA
Louis-Dupont Dec 13, 2022
98bbe9c
Merge branch 'master' into feature/SG-456-centralise_ddp_setup
Louis-Dupont Dec 13, 2022
43451ac
fix
Louis-Dupont Dec 13, 2022
33daa93
fix
Louis-Dupont Dec 13, 2022
fcff5b5
fix
Louis-Dupont Dec 13, 2022
77f07ff
fix kd trainer
Louis-Dupont Dec 13, 2022
cc94d8e
Merge branch 'master' into feature/SG-456-centralise_ddp_setup
Louis-Dupont Dec 26, 2022
99a61a2
Merge branch 'master' into feature/SG-456-centralise_ddp_setup
Louis-Dupont Jan 5, 2023
a9f61d8
remove unwanted change
Louis-Dupont Jan 5, 2023
0c064c3
remove unwanted change
Louis-Dupont Jan 5, 2023
2 changes: 2 additions & 0 deletions src/super_gradients/__init__.py
@@ -3,6 +3,7 @@
from super_gradients.examples.train_from_recipe_example import train_from_recipe
from super_gradients.examples.train_from_kd_recipe_example import train_from_kd_recipe
from super_gradients.sanity_check import env_sanity_check
from super_gradients.training.utils.distributed_training_utils import setup_device

__all__ = [
    "ARCHITECTURES",
@@ -18,6 +19,7 @@
    "train_from_recipe",
    "train_from_kd_recipe",
    "env_sanity_check",
    "setup_device",
]

__version__ = "3.0.5"
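
The newly exported setup_device is intended as the single entry point for device/DDP configuration, replacing the per-Trainer multi_gpu argument. A minimal usage sketch (the multi_gpu/num_gpus keyword names are assumptions based on this PR, not a documented contract):

# Hypothetical usage sketch: configure the device/DDP mode once, before building the Trainer.
from super_gradients import Trainer, setup_device

setup_device(multi_gpu="DDP", num_gpus=4)  # assumed keywords; e.g. multi_gpu="OFF" would force single-device training
trainer = Trainer("my_experiment")  # no multi_gpu argument needed anymore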
11 changes: 11 additions & 0 deletions src/super_gradients/common/environment/argparse_utils.py
@@ -1,8 +1,11 @@
import argparse
import sys
from typing import Any
from super_gradients.common.abstractions.abstract_logger import get_logger


logger = get_logger(__name__)

EXTRA_ARGS = []


@@ -18,3 +21,11 @@ def pop_arg(arg_name: str, default_value: Any = None) -> Any:
        EXTRA_ARGS.append(val)
        sys.argv.remove(val)
    return vars(args)[arg_name]


def pop_local_rank() -> int:
    """Pop the python arg "local_rank". If it exists, inform the user with a log; otherwise return -1."""
    local_rank = pop_arg("local_rank", default_value=-1)
    if local_rank != -1:
        logger.info("local_rank was automatically parsed from your config.")
    return local_rank
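
For context, this argument-popping pattern exists because torch.distributed.launch injects --local_rank=<n> into every subprocess, which would otherwise collide with Hydra/argparse parsing downstream. A standalone sketch of the same idea (illustrative only, not the library code):

import argparse
import sys


def pop_cli_arg(arg_name: str, default_value=None):
    """Parse --<arg_name> if present, then strip it from sys.argv so later parsers never see it."""
    parser = argparse.ArgumentParser()
    parser.add_argument(f"--{arg_name}", type=int, default=default_value)
    args, _ = parser.parse_known_args()
    # Drop every "--<arg_name>=<value>" token so downstream CLIs do not choke on it.
    sys.argv = [tok for tok in sys.argv if not tok.startswith(f"--{arg_name}")]
    return getattr(args, arg_name)


if __name__ == "__main__":
    sys.argv += ["--local_rank=2"]  # simulate what torch.distributed.launch would inject
    print(pop_cli_arg("local_rank", default_value=-1))  # -> 2
    print(sys.argv)  # the flag is gone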
35 changes: 10 additions & 25 deletions src/super_gradients/common/environment/ddp_utils.py
@@ -2,41 +2,24 @@
import socket
from functools import wraps

from super_gradients.common.environment.argparse_utils import pop_arg
from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers


DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=-1))
INIT_TRAINER = False
from super_gradients.common.environment.argparse_utils import pop_local_rank


def init_trainer():
    """
    Initialize the super_gradients environment.

    This function should be the first thing to be called by any code running super_gradients.
    It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.
    """
    global INIT_TRAINER, DDP_LOCAL_RANK

    if not INIT_TRAINER:
        register_hydra_resolvers()

        # We pop local_rank if it was specified in the args, because it would break
        args_local_rank = pop_arg("local_rank", default_value=-1)

        # Set local_rank with priority order (env variable > args.local_rank > args.default_value)
        DDP_LOCAL_RANK = int(os.getenv("LOCAL_RANK", default=args_local_rank))
        INIT_TRAINER = True
    register_hydra_resolvers()
    pop_local_rank()


def is_distributed() -> bool:
    return DDP_LOCAL_RANK >= 0


def is_rank_0() -> bool:
    """Check if the node was launched with torch.distributed.launch and if the node is of rank 0"""
    return os.getenv("LOCAL_RANK") == "0"
    """Check if current process is a DDP subprocess."""
    return device_config.assigned_rank >= 0


def is_launched_using_sg():
@@ -55,7 +38,9 @@ def is_main_process():
    """
    if not is_distributed():  # If no DDP, or DDP launching process
        return True
    elif is_rank_0() and not is_launched_using_sg():  # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
    elif (
        device_config.assigned_rank == 0 and not is_launched_using_sg()
    ):  # If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0
        return True
    else:
        return False
@@ -74,7 +59,7 @@ def do_nothing(*args, **kwargs):

    @wraps(func)
    def wrapper(*args, **kwargs):
        if DDP_LOCAL_RANK <= 0:
        if device_config.assigned_rank <= 0:
            return func(*args, **kwargs)
        else:
            return do_nothing(*args, **kwargs)
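
The wrapper above is the usual rank-zero guard: the decorated function runs only in the main process, since assigned_rank <= 0 covers both single-process runs (-1) and DDP rank 0. The decorator's name is not visible in this hunk, so the following self-contained sketch uses an illustrative name rather than the library's:

import os
from functools import wraps


def main_process_only(func):
    """Run func only when this process is rank 0, or not part of a DDP run at all."""
    rank = int(os.getenv("LOCAL_RANK", default=-1))

    @wraps(func)
    def wrapper(*args, **kwargs):
        if rank <= 0:  # -1 = no DDP, 0 = DDP rank 0
            return func(*args, **kwargs)
        return None  # silently skip on every other rank

    return wrapper


@main_process_only
def log_to_console(msg: str) -> None:
    print(msg)


log_to_console("only the main process prints this")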
28 changes: 28 additions & 0 deletions src/super_gradients/common/environment/device_utils.py
@@ -0,0 +1,28 @@
import os
import dataclasses

import torch

from super_gradients.common.environment.argparse_utils import pop_local_rank


__all__ = ["device_config"]


def _get_assigned_rank() -> int:
    """Get the rank assigned by DDP launcher. If not DDP subprocess, return -1."""
    if os.getenv("LOCAL_RANK") is not None:
        return int(os.getenv("LOCAL_RANK"))
    else:
        return pop_local_rank()


@dataclasses.dataclass
class DeviceConfig:
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    multi_gpu: str = None
    assigned_rank: str = dataclasses.field(default=_get_assigned_rank(), init=False)


# Singleton holding the device information
device_config = DeviceConfig()
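
Because device_config is a module-level singleton, any part of the code base can read the current device state without threading it through function arguments, and the setup utilities can mutate it in one place. A hedged usage sketch, assuming the package is installed with this file in place:

from super_gradients.common.environment.device_utils import device_config

print(device_config.device)         # "cuda" if a GPU is visible, else "cpu"
print(device_config.multi_gpu)      # None until a setup utility (e.g. setup_device) assigns a mode
print(device_config.assigned_rank)  # -1 outside DDP, otherwise the rank given by the launcher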
@@ -12,7 +12,6 @@

from super_gradients import Trainer
from super_gradients.training.metrics.classification_metrics import Accuracy, Top5
from super_gradients.training import MultiGPUMode
from torch.optim import ASGD
from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau
from torch.nn import CrossEntropyLoss
@@ -49,7 +48,7 @@
]

# Bring everything together with Trainer and start training
trainer = Trainer("Cifar10_external_objects_example", multi_gpu=MultiGPUMode.OFF)
trainer = Trainer("Cifar10_external_objects_example")

train_params = {
"max_epochs": 300,
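
The change in this example reflects the point of the PR: the multi-GPU mode is no longer chosen per Trainer but configured globally before training starts. A hedged before/after sketch (keyword names assumed from this PR):

# Before: every Trainer received its own mode.
# trainer = Trainer("Cifar10_external_objects_example", multi_gpu=MultiGPUMode.OFF)

# After: configure the device once, then build the Trainer without a multi_gpu argument.
from super_gradients import Trainer, setup_device

setup_device(multi_gpu="OFF")  # assumed keyword; single-device training
trainer = Trainer("Cifar10_external_objects_example")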