diff --git a/train.py b/train.py index ee1295ca3a59..e31c0dbd3e69 100644 --- a/train.py +++ b/train.py @@ -37,9 +37,9 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume logger = logging.getLogger(__name__) -LOCAL_RANK = int(getattr(os.environ, 'LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html -RANK = int(getattr(os.environ, 'RANK', -1)) -WORLD_SIZE = int(getattr(os.environ, 'WORLD_SIZE', 1)) +LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html +RANK = int(os.getenv('RANK', -1)) +WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) def train(hyp, # path/to/hyp.yaml or hyp dictionary