sets default ddp mode to spawn (#2168)
* set ddp_spawn as default

* spawn message

* spawn message

* spawn message

* spawn message

* spawn message

* spawn message

* spawn message

* spawn message
williamFalcon committed Jun 13, 2020
1 parent bb32ae5 commit 9df2b20
Showing 6 changed files with 37 additions and 18 deletions.
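What the change means in practice (a hedged usage sketch, not part of the diff): after this commit, requesting multiple GPUs without naming a backend makes the Trainer fall back to ddp_spawn instead of ddp, and passing distributed_backend explicitly restores the previous behaviour. The snippet assumes the 0.8-era Trainer keyword arguments shown in the diffs and a multi-GPU machine.

    import pytorch_lightning as pl

    # No backend given: the Trainer now warns and sets distributed_backend='ddp_spawn'.
    trainer = pl.Trainer(gpus=2)

    # Explicitly request ddp to keep script-launched processes
    # (and to use DataLoader workers without the spawn bottleneck).
    trainer = pl.Trainer(gpus=2, distributed_backend='ddp')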
8 changes: 4 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -942,18 +942,18 @@ def init_ddp_connection(
             self._init_slurm_connection()
 
         if 'MASTER_ADDR' not in os.environ:
-            log.warning("MASTER_ADDR environment variable is not defined. Set as localhost")
+            rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ['MASTER_ADDR'] = '127.0.0.1'
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
 
         if 'MASTER_PORT' not in os.environ:
-            log.warning("MASTER_PORT environment variable is not defined. Set as 12910")
+            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
            os.environ['MASTER_PORT'] = '12910'
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
         if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
-            log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
-                        f"is not equal to the computed world size ({world_size}). Ignored.")
+            rank_zero_warn(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
+                           f"is not equal to the computed world size ({world_size}). Ignored.")
 
         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
         log.info(f"initializing ddp: LOCAL_RANK: {proc_rank}/{world_size - 1} WORLD_SIZE:{world_size}")
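The hunk above only swaps log.warning for rank_zero_warn; the fallback logic itself is unchanged. As a minimal illustration of that pattern (not library code), the DDP rendezvous address and port default to localhost and 12910 when the launcher has not exported them:

    import os

    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = '127.0.0.1'  # single-node default used above
    if 'MASTER_PORT' not in os.environ:
        os.environ['MASTER_PORT'] = '12910'      # same fallback port as in the diff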
15 changes: 14 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -96,11 +96,24 @@ def _percent_range_check(self, name: str) -> None:
     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         on_windows = platform.system() == 'Windows'
 
-        if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2 and not on_windows:
+        # ddp_spawn + num_workers > 0 don't mix! tell the user
+        is_dataloader = isinstance(dataloader, DataLoader)
+        using_spawn = self.distributed_backend == 'ddp_spawn'
+        if is_dataloader and dataloader.num_workers > 0 and not on_windows and using_spawn:
+            rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well! '
+                           'Your performance might suffer dramatically. '
+                           'Please consider setting distributed_backend=ddp to use num_workers > 0 '
+                           '(this is a bottleneck of Python .spawn() and PyTorch')
+
+        elif is_dataloader and dataloader.num_workers <= 2 and not on_windows and not using_spawn:
             rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                            ' Consider increasing the value of the `num_workers` argument`'
                            ' in the `DataLoader` init to improve performance.')
 
+        elif is_dataloader and dataloader.num_workers == 0 and not on_windows and using_spawn:
+            rank_zero_warn('You are using `distributed_backend=ddp_spawn` with num_workers=0. '
+                           'For much faster performance, switch to `distributed_backend=ddp` and set `num_workers>0`')
+
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         # don't do anything if it's not a dataloader
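A hedged sketch of how the new _worker_check branches map to user setups (DataLoader and TensorDataset are the standard torch.utils.data classes; the dataset here is only a stand-in):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(64, 3))

    # distributed_backend='ddp_spawn' + num_workers>0  -> "do not mix well" warning
    spawn_loader = DataLoader(dataset, batch_size=8, num_workers=4)

    # distributed_backend='ddp_spawn' + num_workers=0  -> hint to switch to ddp
    slow_loader = DataLoader(dataset, batch_size=8, num_workers=0)

    # distributed_backend='ddp'       + num_workers<=2 -> "consider more workers" warning
    small_loader = DataLoader(dataset, batch_size=8, num_workers=2)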
16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -128,7 +128,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
 
 try:
     from apex import amp
@@ -220,9 +220,9 @@ def set_distributed_mode(self, distributed_backend):
         elif self.num_gpus > 1:
             rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.'
                            ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
-                           ' Setting distributed_backend=ddp for you.')
-            self.distributed_backend = 'ddp'
-            distributed_backend = 'ddp'
+                           ' Setting distributed_backend=ddp_spawn for you.')
+            self.distributed_backend = 'ddp_spawn'
+            distributed_backend = 'ddp_spawn'
 
         if distributed_backend == "dp":
             # do nothing if num_gpus == 0
@@ -264,7 +264,7 @@ def set_distributed_mode(self, distributed_backend):
                 'To silence this warning set distributed_backend=ddp or distributed_backend=ddp2'
             )
 
-        log.info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
+        rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
 
     def configure_slurm_ddp(self, num_gpu_nodes):
         self.is_slurm_managing_tasks = False
@@ -298,7 +298,7 @@ def configure_slurm_ddp(self, num_gpu_nodes):
 
         # notify user the that slurm is managing tasks
         if self.is_slurm_managing_tasks:
-            log.info('Multi-processing is handled by Slurm.')
+            rank_zero_info('Multi-processing is handled by Slurm.')
 
     def determine_ddp_node_rank(self):
         if self.is_slurm_managing_tasks:
@@ -316,7 +316,7 @@ def determine_ddp_node_rank(self):
             log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
                         f"Using the first one.")
         k, rank = node_ids.pop()
-        log.info(f"Using environment variable {k} for node rank ({rank}).")
+        rank_zero_info(f"Using environment variable {k} for node rank ({rank}).")
         return int(rank)
 
     def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
@@ -336,7 +336,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str
 
         # don't make this debug... this is good UX
-        log.info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
+        rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
 
     def __set_random_port(self):
         """
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -31,7 +31,7 @@
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_warn, parsing
+from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info
 
 try:
     from apex import amp
@@ -362,8 +362,8 @@ def __init__(
         if self.fast_dev_run:
             self.num_sanity_val_steps = 0
             self.max_epochs = 1
-            log.info('Running in fast_dev_run mode: will run a full train,'
-                     ' val and test loop using a single batch')
+            rank_zero_info('Running in fast_dev_run mode: will run a full train,'
+                           ' val and test loop using a single batch')
 
         # set default save path if user didn't provide one
         self.default_root_dir = default_root_dir
@@ -838,7 +838,7 @@ def fit(
             self.single_gpu_train(model)
 
         elif self.use_tpu:  # pragma: no-cover
-            log.info(f'training on {self.tpu_cores} TPU cores')
+            rank_zero_info(f'training on {self.tpu_cores} TPU cores')
 
             # COLAB_GPU is an env var available by default in Colab environments.
             start_method = 'fork' if self.on_colab_kaggle else 'spawn'
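For context on the fast_dev_run hunk (a usage sketch, not part of the change): fast_dev_run runs a single batch through the train, validation and test loops as a smoke test, and its startup message is now emitted only from rank zero.

    import pytorch_lightning as pl

    trainer = pl.Trainer(fast_dev_run=True)
    # trainer.fit(model)  # `model` is assumed to be a LightningModule defined elsewhere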
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -1,5 +1,5 @@
 """General utilities"""
 
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.parsing import AttributeDict
6 changes: 6 additions & 0 deletions pytorch_lightning/utilities/distributed.py
@@ -1,5 +1,6 @@
 from functools import wraps
 import warnings
+from pytorch_lightning import _logger as log
 
 
 def rank_zero_only(fn):
@@ -23,4 +24,9 @@ def _warn(*args, **kwargs):
     warnings.warn(*args, **kwargs)
 
 
+def _info(*args, **kwargs):
+    log.info(*args, **kwargs)
+
+
+rank_zero_info = rank_zero_only(_info)
 rank_zero_warn = rank_zero_only(_warn)

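A hedged sketch of how the new helper behaves (the import path matches the diff; the decorated function is illustrative): rank_zero_only wraps a callable so it only executes in the process whose rank is 0, and rank_zero_info / rank_zero_warn are the pre-wrapped logging variants used throughout this commit.

    from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_info, rank_zero_warn

    rank_zero_info('logged once, by rank 0 only')
    rank_zero_warn('warned once across all DDP processes')

    @rank_zero_only
    def write_summary(path):
        # runs only on the main process, so N workers don't race on the same file
        with open(path, 'w') as f:
            f.write('done\n')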