updated sync bn #2838

Merged (13 commits, Aug 5, 2020)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerator_backends/ddp_backend.py
@@ -177,8 +177,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         self.trainer.optimizer_frequencies = optimizer_frequencies

         # call sync_bn before .cuda(), configure_apex and configure_ddp
-        if self.trainer.sync_bn:
-            model = model.configure_sync_bn(model)
+        if self.trainer.sync_batchnorm:
+            model = model.configure_sync_batchnorm(model)

         # MODEL
         # copy model to each gpu
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
@@ -119,8 +119,8 @@ def ddp_train(self, process_idx, mp_queue, model):
         self.trainer.optimizer_frequencies = optimizer_frequencies

         # call sync_bn before .cuda(), configure_apex and configure_ddp
-        if self.trainer.sync_bn:
-            model = model.configure_sync_bn(model)
+        if self.trainer.sync_batchnorm:
+            model = model.configure_sync_batchnorm(model)

         # MODEL
         # copy model to each gpu
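Both backend hunks make the same two-line rename. For orientation, here is a minimal sketch of the ordering the surrounding ddp_train flow relies on: convert batchnorm layers while the model is still an unwrapped module, then move it to its GPU, then wrap it for distributed training. Apart from trainer.sync_batchnorm and configure_sync_batchnorm, the names below (prepare_model_for_ddp, gpu_idx, the direct DistributedDataParallel call) are assumptions for illustration, not the exact code in these files.

    import torch

    def prepare_model_for_ddp(trainer, model, gpu_idx):
        # 1) convert BatchNorm*D layers to their synchronized variant first,
        #    while `model` is still a plain, unwrapped module
        if trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # 2) only then move the model to its assigned GPU ...
        model = model.cuda(gpu_idx)

        # 3) ... and finally wrap it for distributed training
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_idx])
        return model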
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -957,7 +957,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi
         log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
         torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

-    def configure_sync_bn(self, model: 'LightningModule') -> 'LightningModule':
+    def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
         """
         Add global batchnorm for a model spread across multiple GPUs and nodes.

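The body of the renamed hook is collapsed in the diff above. For reference, a hook like this is typically a thin wrapper around torch.nn.SyncBatchNorm.convert_sync_batchnorm; the sketch below illustrates the idea and is not necessarily the exact code merged here.

    import torch

    def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
        # Replace every BatchNorm*D layer with torch.nn.SyncBatchNorm so that batch
        # statistics are reduced across all processes in the default process group.
        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)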
8 changes: 8 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -855,6 +855,14 @@ def on_train_end(self, trainer, pl_module):
     # default used by the Trainer
     trainer = Trainer(row_log_interval=50)

+sync_batchnorm
+^^^^^^^^^^^^^^^^^
+
+Enable synchronization between batchnorm layers across all GPUs.
+
+.. testcode::
+
+    trainer = Trainer(sync_batchnorm=True)

 val_percent_check
 ^^^^^^^^^^^^^^^^^
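The flag is only meaningful when training runs across multiple processes (for example multi-GPU DDP), since there is nothing to synchronize otherwise. A slightly fuller, hedged usage example, where MyModel is a placeholder and the gpus/distributed_backend values are illustrative:

    from pytorch_lightning import Trainer

    model = MyModel()  # any LightningModule that contains BatchNorm layers

    trainer = Trainer(
        gpus=2,
        distributed_backend='ddp',
        sync_batchnorm=True,  # keep batchnorm statistics in sync across the processes
    )
    trainer.fit(model)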
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -184,7 +184,7 @@ def __init__(
         log_save_interval: int = 100,
         row_log_interval: int = 50,
         distributed_backend: Optional[str] = None,
-        sync_bn: bool = False,
+        sync_batchnorm: bool = False,
         precision: int = 32,
         weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
         weights_save_path: Optional[str] = None,
@@ -297,7 +297,7 @@ def __init__(

             distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)

-            sync_bn: Synchronize batch norm layers between process groups/whole world.
+            sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

             precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

@@ -431,7 +431,7 @@ def __init__(
         self.log_gpu_memory = log_gpu_memory

         # sync-bn backend
-        self.sync_bn = sync_bn
+        self.sync_batchnorm = sync_batchnorm

         self.gradient_clip_val = gradient_clip_val
         self.check_val_every_n_epoch = check_val_every_n_epoch