Skip to content

Commit

Permalink
updated sync bn (#2838)
Browse files Browse the repository at this point in the history
* updated sync bn

* updated sync bn

* updated sync bn

* updated sync bn

* updated sync bn

* updated sync bn

* updated sync bn

* updated sync bn

* added ddp_spawn test

* updated test

* clean

* clean

Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
  • Loading branch information
ananyahjha93 and Borda committed Aug 5, 2020
1 parent 633cf76 commit a5f2b89
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 227 deletions.
204 changes: 0 additions & 204 deletions pl_examples/basic_examples/sync_bn.py

This file was deleted.

14 changes: 0 additions & 14 deletions pl_examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,6 @@ def test_gpu_template(cli_args):
run_cli()


@pytest.mark.parametrize(
'cli_args',
['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2 --dist_backend ddp_spawn --bn_sync']
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_bn(cli_args):
"""Test running CLI for an example with sync bn."""
from pl_examples.basic_examples.sync_bn import run_cli

cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
run_cli()


# @pytest.mark.parametrize('cli_args', ['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2'])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# def test_multi_node_ddp(cli_args):
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerator_backends/ddp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
self.trainer.optimizer_frequencies = optimizer_frequencies

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_bn:
model = model.configure_sync_bn(model)
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.optimizer_frequencies = optimizer_frequencies

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_bn:
model = model.configure_sync_bn(model)
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

def configure_sync_bn(self, model: 'LightningModule') -> 'LightningModule':
def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,14 @@ def on_train_end(self, trainer, pl_module):
# default used by the Trainer
trainer = Trainer(row_log_interval=50)
sync_batchnorm
^^^^^^^^^^^^^^
Enable synchronization between batchnorm layers across all GPUs.
.. testcode::
trainer = Trainer(sync_batchnorm=True)
val_percent_check
^^^^^^^^^^^^^^^^^
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(
log_save_interval: int = 100,
row_log_interval: int = 50,
distributed_backend: Optional[str] = None,
sync_bn: bool = False,
sync_batchnorm: bool = False,
precision: int = 32,
weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
weights_save_path: Optional[str] = None,
Expand Down Expand Up @@ -297,7 +297,7 @@ def __init__(
distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)
sync_bn: Synchronize batch norm layers between process groups/whole world.
sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.
Expand Down Expand Up @@ -431,7 +431,7 @@ def __init__(
self.log_gpu_memory = log_gpu_memory

# sync-bn backend
self.sync_bn = sync_bn
self.sync_batchnorm = sync_batchnorm

self.gradient_clip_val = gradient_clip_val
self.check_val_every_n_epoch = check_val_every_n_epoch
Expand Down
48 changes: 47 additions & 1 deletion tests/base/datamodules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from torch.utils.data import random_split, DataLoader

from pytorch_lightning.core.datamodule import LightningDataModule
from tests.base.datasets import TrialMNIST
from tests.base.datasets import TrialMNIST, MNIST
from torch.utils.data.distributed import DistributedSampler


class TrialMNISTDataModule(LightningDataModule):
Expand Down Expand Up @@ -36,3 +38,47 @@ def val_dataloader(self):

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)


class MNISTDataModule(LightningDataModule):
def __init__(
self, data_dir: str = './', batch_size: int = 32, dist_sampler: bool = False
) -> None:
super().__init__()

self.dist_sampler = dist_sampler
self.data_dir = data_dir
self.batch_size = batch_size

# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)

def prepare_data(self):
# download only
MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081))
MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081))

def setup(self, stage: str = None):

# Assign train/val datasets for use in dataloaders
# TODO: need to split using random_split once updated to torch >= 1.6
if stage == 'fit' or stage is None:
self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081))

# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081))

def train_dataloader(self):
dist_sampler = None
if self.dist_sampler:
dist_sampler = DistributedSampler(self.mnist_train, shuffle=False)

return DataLoader(
self.mnist_train, batch_size=self.batch_size, sampler=dist_sampler, shuffle=False
)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)
Loading

0 comments on commit a5f2b89

Please sign in to comment.