Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated sync bn #2838

Merged
merged 13 commits into from
Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 0 additions & 204 deletions pl_examples/basic_examples/sync_bn.py

This file was deleted.

14 changes: 0 additions & 14 deletions pl_examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,6 @@ def test_gpu_template(cli_args):
run_cli()


@pytest.mark.parametrize(
'cli_args',
['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2 --dist_backend ddp_spawn --bn_sync']
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_bn(cli_args):
"""Test running CLI for an example with sync bn."""
from pl_examples.basic_examples.sync_bn import run_cli

cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
run_cli()


# @pytest.mark.parametrize('cli_args', ['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2'])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# def test_multi_node_ddp(cli_args):
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerator_backends/ddp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
self.trainer.optimizer_frequencies = optimizer_frequencies

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_bn:
model = model.configure_sync_bn(model)
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.optimizer_frequencies = optimizer_frequencies

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_bn:
model = model.configure_sync_bn(model)
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

def configure_sync_bn(self, model: 'LightningModule') -> 'LightningModule':
def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.

Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,14 @@ def on_train_end(self, trainer, pl_module):
# default used by the Trainer
trainer = Trainer(row_log_interval=50)

sync_batchnorm
^^^^^^^^^^^^^^

Enable synchronization between batchnorm layers across all GPUs.

.. testcode::

trainer = Trainer(sync_batchnorm=True)

val_percent_check
^^^^^^^^^^^^^^^^^
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(
log_save_interval: int = 100,
row_log_interval: int = 50,
distributed_backend: Optional[str] = None,
sync_bn: bool = False,
sync_batchnorm: bool = False,
precision: int = 32,
weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
weights_save_path: Optional[str] = None,
Expand Down Expand Up @@ -297,7 +297,7 @@ def __init__(

distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)

sync_bn: Synchronize batch norm layers between process groups/whole world.
sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

Expand Down Expand Up @@ -431,7 +431,7 @@ def __init__(
self.log_gpu_memory = log_gpu_memory

# sync-bn backend
self.sync_bn = sync_bn
self.sync_batchnorm = sync_batchnorm

self.gradient_clip_val = gradient_clip_val
self.check_val_every_n_epoch = check_val_every_n_epoch
Expand Down
48 changes: 47 additions & 1 deletion tests/base/datamodules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from torch.utils.data import random_split, DataLoader

from pytorch_lightning.core.datamodule import LightningDataModule
from tests.base.datasets import TrialMNIST
from tests.base.datasets import TrialMNIST, MNIST
from torch.utils.data.distributed import DistributedSampler


class TrialMNISTDataModule(LightningDataModule):
Expand Down Expand Up @@ -36,3 +38,47 @@ def val_dataloader(self):

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)


class MNISTDataModule(LightningDataModule):
def __init__(
self, data_dir: str = './', batch_size: int = 32, dist_sampler: bool = False
) -> None:
super().__init__()

self.dist_sampler = dist_sampler
self.data_dir = data_dir
self.batch_size = batch_size

# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)

def prepare_data(self):
# download only
MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double check, do we need TrialMNIST or MNIST?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not want to change TrialMNIST structure for this PR as I had to include a DistributedSampler with shuffle

MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081))

def setup(self, stage: str = None):

# Assign train/val datasets for use in dataloaders
# TODO: need to split using random_split once updated to torch >= 1.6
if stage == 'fit' or stage is None:
self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081))

# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081))

def train_dataloader(self):
dist_sampler = None
if self.dist_sampler:
dist_sampler = DistributedSampler(self.mnist_train, shuffle=False)

return DataLoader(
self.mnist_train, batch_size=self.batch_size, sampler=dist_sampler, shuffle=False
)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)
Loading