From b23ea28b0a56883ab44a59eb3bf3203688bce136 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 14:39:02 -0400 Subject: [PATCH 01/12] updated sync bn --- pytorch_lightning/accelerator_backends/ddp_backend.py | 4 ++-- .../accelerator_backends/ddp_spawn_backend.py | 4 ++-- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/trainer/__init__.py | 8 ++++++++ pytorch_lightning/trainer/trainer.py | 6 +++--- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/accelerator_backends/ddp_backend.py b/pytorch_lightning/accelerator_backends/ddp_backend.py index c2e549c18ef1a..44ad52d34ba2f 100644 --- a/pytorch_lightning/accelerator_backends/ddp_backend.py +++ b/pytorch_lightning/accelerator_backends/ddp_backend.py @@ -177,8 +177,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 self.trainer.optimizer_frequencies = optimizer_frequencies # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_bn: - model = model.configure_sync_bn(model) + if self.trainer.sync_batchnorm: + model = model.configure_sync_batchnorm(model) # MODEL # copy model to each gpu diff --git a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py index 85de2d1b7759e..704fc5558588a 100644 --- a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py @@ -119,8 +119,8 @@ def ddp_train(self, process_idx, mp_queue, model): self.trainer.optimizer_frequencies = optimizer_frequencies # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_bn: - model = model.configure_sync_bn(model) + if self.trainer.sync_batchnorm: + model = model.configure_sync_batchnorm(model) # MODEL # copy model to each gpu diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c09d981d1d5e3..4189c828ed266 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -957,7 +957,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}") torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) - def configure_sync_bn(self, model: 'LightningModule') -> 'LightningModule': + def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule': """ Add global batchnorm for a model spread across multiple GPUs and nodes. diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 0164210c771fb..eea796411553e 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -855,6 +855,14 @@ def on_train_end(self, trainer, pl_module): # default used by the Trainer trainer = Trainer(row_log_interval=50) +sync_batchnorm +^^^^^^^^^^^^^^^^^ + +Enable synchrnization between batchnorm layers across all GPUs. + +.. 
testcode:: + + trainer = Trainer(sync_batchnorm=True) val_percent_check ^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ebfe680fc3372..4b342328df297 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -184,7 +184,7 @@ def __init__( log_save_interval: int = 100, row_log_interval: int = 50, distributed_backend: Optional[str] = None, - sync_bn: bool = False, + sync_batchnorm: bool = False, precision: int = 32, weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT, weights_save_path: Optional[str] = None, @@ -297,7 +297,7 @@ def __init__( distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu) - sync_bn: Synchronize batch norm layers between process groups/whole world. + sync_batchnorm: Synchronize batch norm layers between process groups/whole world. precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs. @@ -431,7 +431,7 @@ def __init__( self.log_gpu_memory = log_gpu_memory # sync-bn backend - self.sync_bn = sync_bn + self.sync_batchnorm = sync_batchnorm self.gradient_clip_val = gradient_clip_val self.check_val_every_n_epoch = check_val_every_n_epoch From 2d260a079cd3dac65df86e0fcbdf96eed0812e58 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 14:41:02 -0400 Subject: [PATCH 02/12] updated sync bn --- pytorch_lightning/trainer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index eea796411553e..5cec30f4a33a0 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -858,7 +858,7 @@ def on_train_end(self, trainer, pl_module): sync_batchnorm ^^^^^^^^^^^^^^^^^ -Enable synchrnization between batchnorm layers across all GPUs. +Enable synchronization between batchnorm layers across all GPUs. .. testcode:: From b3523d904295a94962e121fa3c822cdfb208c31c Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 14:47:40 -0400 Subject: [PATCH 03/12] updated sync bn --- pytorch_lightning/trainer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 5cec30f4a33a0..8dcec8eb30511 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -856,7 +856,7 @@ def on_train_end(self, trainer, pl_module): trainer = Trainer(row_log_interval=50) sync_batchnorm -^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^ Enable synchronization between batchnorm layers across all GPUs. 
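The three patches above only rename the flag; the user-facing behaviour is unchanged. As a hedged sketch (the GPU count and backend below are illustrative choices mirroring the test added later in this series, not something these docs patches prescribe), the renamed flag is only meaningful when training is split across several processes, e.g.::

    from pytorch_lightning import Trainer

    trainer = Trainer(
        gpus=2,                      # multi-process training; a single-process run gains nothing
        distributed_backend='ddp',   # or 'ddp_spawn', as in the test added in patch 09
        sync_batchnorm=True,         # replaces the old `sync_bn=True` argument
    )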
From 9dadb2989997e21724baf872e6208c340df17a19 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 15:11:27 -0400 Subject: [PATCH 04/12] updated sync bn --- pl_examples/basic_examples/sync_bn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/basic_examples/sync_bn.py b/pl_examples/basic_examples/sync_bn.py index bb602a74ea89c..35f2cd4088a44 100644 --- a/pl_examples/basic_examples/sync_bn.py +++ b/pl_examples/basic_examples/sync_bn.py @@ -145,7 +145,7 @@ def main(args, datamodule, bn_outputs): distributed_backend=args.dist_backend, max_epochs=args.epochs, max_steps=args.steps, - sync_bn=args.bn_sync, + sync_batchnorm=args.bn_sync, num_sanity_val_steps=0, replace_sampler_ddp=False, ) From 5a88a6192f89dcb02d89616bd624dbcbe607c6a2 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 14:39:02 -0400 Subject: [PATCH 05/12] updated sync bn --- pytorch_lightning/accelerator_backends/ddp_backend.py | 4 ++-- .../accelerator_backends/ddp_spawn_backend.py | 4 ++-- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/trainer/__init__.py | 8 ++++++++ pytorch_lightning/trainer/trainer.py | 6 +++--- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/accelerator_backends/ddp_backend.py b/pytorch_lightning/accelerator_backends/ddp_backend.py index c2e549c18ef1a..44ad52d34ba2f 100644 --- a/pytorch_lightning/accelerator_backends/ddp_backend.py +++ b/pytorch_lightning/accelerator_backends/ddp_backend.py @@ -177,8 +177,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 self.trainer.optimizer_frequencies = optimizer_frequencies # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_bn: - model = model.configure_sync_bn(model) + if self.trainer.sync_batchnorm: + model = model.configure_sync_batchnorm(model) # MODEL # copy model to each gpu diff --git a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py index 85de2d1b7759e..704fc5558588a 100644 --- a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py @@ -119,8 +119,8 @@ def ddp_train(self, process_idx, mp_queue, model): self.trainer.optimizer_frequencies = optimizer_frequencies # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_bn: - model = model.configure_sync_bn(model) + if self.trainer.sync_batchnorm: + model = model.configure_sync_batchnorm(model) # MODEL # copy model to each gpu diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 80081c0dd446f..d272c23fd9a65 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -957,7 +957,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}") torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) - def configure_sync_bn(self, model: 'LightningModule') -> 'LightningModule': + def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule': """ Add global batchnorm for a model spread across multiple GPUs and nodes. 
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 0164210c771fb..eea796411553e 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -855,6 +855,14 @@ def on_train_end(self, trainer, pl_module): # default used by the Trainer trainer = Trainer(row_log_interval=50) +sync_batchnorm +^^^^^^^^^^^^^^^^^ + +Enable synchrnization between batchnorm layers across all GPUs. + +.. testcode:: + + trainer = Trainer(sync_batchnorm=True) val_percent_check ^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ebfe680fc3372..4b342328df297 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -184,7 +184,7 @@ def __init__( log_save_interval: int = 100, row_log_interval: int = 50, distributed_backend: Optional[str] = None, - sync_bn: bool = False, + sync_batchnorm: bool = False, precision: int = 32, weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT, weights_save_path: Optional[str] = None, @@ -297,7 +297,7 @@ def __init__( distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu) - sync_bn: Synchronize batch norm layers between process groups/whole world. + sync_batchnorm: Synchronize batch norm layers between process groups/whole world. precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs. @@ -431,7 +431,7 @@ def __init__( self.log_gpu_memory = log_gpu_memory # sync-bn backend - self.sync_bn = sync_bn + self.sync_batchnorm = sync_batchnorm self.gradient_clip_val = gradient_clip_val self.check_val_every_n_epoch = check_val_every_n_epoch From 07aa2ec77858f63ba267b7370e06f768a97626f6 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 14:41:02 -0400 Subject: [PATCH 06/12] updated sync bn --- pytorch_lightning/trainer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index eea796411553e..5cec30f4a33a0 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -858,7 +858,7 @@ def on_train_end(self, trainer, pl_module): sync_batchnorm ^^^^^^^^^^^^^^^^^ -Enable synchrnization between batchnorm layers across all GPUs. +Enable synchronization between batchnorm layers across all GPUs. .. testcode:: From ed391e1a24ba579016655997cb7bcf51237be0c5 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 14:47:40 -0400 Subject: [PATCH 07/12] updated sync bn --- pytorch_lightning/trainer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 5cec30f4a33a0..8dcec8eb30511 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -856,7 +856,7 @@ def on_train_end(self, trainer, pl_module): trainer = Trainer(row_log_interval=50) sync_batchnorm -^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^ Enable synchronization between batchnorm layers across all GPUs. 
From 66f948b6395da066f262975542daafbf06bb74f6 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 15:11:27 -0400 Subject: [PATCH 08/12] updated sync bn --- pl_examples/basic_examples/sync_bn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/basic_examples/sync_bn.py b/pl_examples/basic_examples/sync_bn.py index bb602a74ea89c..35f2cd4088a44 100644 --- a/pl_examples/basic_examples/sync_bn.py +++ b/pl_examples/basic_examples/sync_bn.py @@ -145,7 +145,7 @@ def main(args, datamodule, bn_outputs): distributed_backend=args.dist_backend, max_epochs=args.epochs, max_steps=args.steps, - sync_bn=args.bn_sync, + sync_batchnorm=args.bn_sync, num_sanity_val_steps=0, replace_sampler_ddp=False, ) From 984c5db0681dad31454f29a1690edc4623162276 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 17:24:03 -0400 Subject: [PATCH 09/12] added ddp_spawn test --- pl_examples/basic_examples/sync_bn.py | 204 -------------------------- pl_examples/test_examples.py | 14 -- tests/base/datamodules.py | 51 +++++++ tests/models/test_sync_batchnorm.py | 110 ++++++++++++++ 4 files changed, 161 insertions(+), 218 deletions(-) delete mode 100644 pl_examples/basic_examples/sync_bn.py create mode 100644 tests/models/test_sync_batchnorm.py diff --git a/pl_examples/basic_examples/sync_bn.py b/pl_examples/basic_examples/sync_bn.py deleted file mode 100644 index 35f2cd4088a44..0000000000000 --- a/pl_examples/basic_examples/sync_bn.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Sync-bn with DDP (GPU) - -This code is to verify that batch statistics are synchronized across GPUs using sync-bn. -When sync_bn is set to True the training loop should run for 3 iterations. -When sync_bn is set to False, the code should result in an AssertionError. -""" -import os -import math -import numpy as np -from argparse import ArgumentParser - -import torch -import torch.nn as nn -import torch.nn.functional as F -import pytorch_lightning as pl - -import torchvision.transforms as transforms -from torchvision.datasets import MNIST -from torch.utils.data import DataLoader, Dataset -from torch.utils.data.distributed import DistributedSampler - - -pl.seed_everything(234) -FLOAT16_EPSILON = np.finfo(np.float16).eps - - -class MNISTDataModule(pl.LightningDataModule): - def __init__(self, data_dir: str = './', batch_size=32, dist_sampler=False): - super().__init__() - - self.dist_sampler = dist_sampler - self.data_dir = data_dir - self.batch_size = batch_size - - self.transforms = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - - # self.dims is returned when you call dm.size() - # Setting default dims here because we know them. 
- # Could optionally be assigned dynamically in dm.setup() - self.dims = (1, 28, 28) - - def prepare_data(self): - # download only - MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) - MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) - - def setup(self, stage=None): - - # Assign train/val datasets for use in dataloaders - if stage == 'fit' or stage is None: - self.mnist_train = MNIST(self.data_dir, train=True, transform=self.transforms) - - # Assign test dataset for use in dataloader(s) - if stage == 'test' or stage is None: - self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transforms) - - def train_dataloader(self): - dist_sampler = None - if self.dist_sampler: - dist_sampler = DistributedSampler(self.mnist_train, shuffle=False) - - return DataLoader( - self.mnist_train, batch_size=self.batch_size, sampler=dist_sampler, shuffle=False - ) - - def test_dataloader(self): - return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False) - - -class SyncBNModule(pl.LightningModule): - def __init__(self, gpu_count=1, **kwargs): - super().__init__() - - self.gpu_count = gpu_count - self.bn_targets = None - if 'bn_targets' in kwargs: - self.bn_targets = kwargs['bn_targets'] - - self.linear = nn.Linear(28 * 28, 10) - self.bn_layer = nn.BatchNorm1d(28 * 28) - - def forward(self, x, batch_idx): - with torch.no_grad(): - out_bn = self.bn_layer(x.view(x.size(0), -1)) - - if self.bn_targets: - bn_target = self.bn_targets[batch_idx] - - # executes on both GPUs - bn_target = bn_target[self.trainer.local_rank::self.gpu_count] - bn_target = bn_target.to(out_bn.device) - assert torch.sum(torch.abs(bn_target - out_bn)) < FLOAT16_EPSILON - - out = self.linear(out_bn) - - return out, out_bn - - def training_step(self, batch, batch_idx): - x, y = batch - - y_hat, _ = self(x, batch_idx) - loss = F.cross_entropy(y_hat, y) - - return pl.TrainResult(loss) - - def configure_optimizers(self): - return torch.optim.Adam(self.linear.parameters(), lr=0.02) - - @staticmethod - def add_model_specific_argument(parent_parser, root_dir): - """ - Define parameters that only apply to this model - """ - parser = ArgumentParser(parents=[parent_parser]) - - parser.add_argument('--nodes', default=1, type=int) - parser.add_argument('--gpu', default=2, type=int) - parser.add_argument('--dist_backend', default='ddp', type=str) - - parser.add_argument('--epochs', default=1, type=int) - parser.add_argument('--steps', default=3, type=int) - - parser.add_argument('--bn_sync', action='store_true') - - return parser - - -def main(args, datamodule, bn_outputs): - """Main training routine specific for this project.""" - # ------------------------ - # 1 INIT LIGHTNING MODEL - # ------------------------ - model = SyncBNModule(gpu_count=args.gpu, bn_targets=bn_outputs) - - # ------------------------ - # 2 INIT TRAINER - # ------------------------ - trainer = pl.Trainer( - gpus=args.gpu, - num_nodes=args.nodes, - distributed_backend=args.dist_backend, - max_epochs=args.epochs, - max_steps=args.steps, - sync_batchnorm=args.bn_sync, - num_sanity_val_steps=0, - replace_sampler_ddp=False, - ) - - # ------------------------ - # 3 START TRAINING - # ------------------------ - trainer.fit(model, datamodule) - - -def run_cli(): - root_dir = os.path.dirname(os.path.realpath(__file__)) - parent_parser = ArgumentParser(add_help=False) - - # define datamodule and dataloader - dm = MNISTDataModule() - dm.prepare_data() - dm.setup(stage=None) - - train_dataloader = 
dm.train_dataloader() - model = SyncBNModule() - - bn_outputs = [] - - # shuffle is false by default - for batch_idx, batch in enumerate(train_dataloader): - x, y = batch - - out, out_bn = model.forward(x, batch_idx) - bn_outputs.append(out_bn) - - # get 3 steps - if batch_idx == 2: - break - - bn_outputs = [x.cuda() for x in bn_outputs] - - # reset datamodule - # batch-size = 16 because 2 GPUs in DDP - dm = MNISTDataModule(batch_size=16, dist_sampler=True) - dm.prepare_data() - dm.setup(stage=None) - - # each LightningModule defines arguments relevant to it - parser = SyncBNModule.add_model_specific_argument(parent_parser, root_dir=root_dir) - parser = pl.Trainer.add_argparse_args(parser) - args = parser.parse_args() - - # --------------------- - # RUN TRAINING - # --------------------- - main(args, dm, bn_outputs) - - -if __name__ == '__main__': - run_cli() diff --git a/pl_examples/test_examples.py b/pl_examples/test_examples.py index d527354647f07..330135e8ea78a 100644 --- a/pl_examples/test_examples.py +++ b/pl_examples/test_examples.py @@ -25,20 +25,6 @@ def test_gpu_template(cli_args): run_cli() -@pytest.mark.parametrize( - 'cli_args', - ['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2 --dist_backend ddp_spawn --bn_sync'] -) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_sync_bn(cli_args): - """Test running CLI for an example with sync bn.""" - from pl_examples.basic_examples.sync_bn import run_cli - - cli_args = cli_args.split(' ') if cli_args else [] - with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): - run_cli() - - # @pytest.mark.parametrize('cli_args', ['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2']) # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # def test_multi_node_ddp(cli_args): diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index a55a9a718ea9d..37ed5c23ae775 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -1,7 +1,11 @@ +import os +import torchvision.transforms as transforms from torch.utils.data import random_split, DataLoader from pytorch_lightning.core.datamodule import LightningDataModule from tests.base.datasets import TrialMNIST +from torchvision.datasets import MNIST +from torch.utils.data.distributed import DistributedSampler class TrialMNISTDataModule(LightningDataModule): @@ -36,3 +40,50 @@ def val_dataloader(self): def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=32) + + +class MNISTDataModule(LightningDataModule): + def __init__(self, data_dir: str = './', batch_size=32, dist_sampler=False): + super().__init__() + + self.dist_sampler = dist_sampler + self.data_dir = data_dir + self.batch_size = batch_size + + self.transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + + # self.dims is returned when you call dm.size() + # Setting default dims here because we know them. 
+ # Could optionally be assigned dynamically in dm.setup() + self.dims = (1, 28, 28) + + def prepare_data(self): + # download only + MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) + MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) + + def setup(self, stage=None): + + # Assign train/val datasets for use in dataloaders + # TODO: need to split using random_split once updated to torch >= 1.6 + if stage == 'fit' or stage is None: + self.mnist_train = MNIST(self.data_dir, train=True, transform=self.transforms) + + # Assign test dataset for use in dataloader(s) + if stage == 'test' or stage is None: + self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transforms) + + def train_dataloader(self): + dist_sampler = None + if self.dist_sampler: + dist_sampler = DistributedSampler(self.mnist_train, shuffle=False) + + return DataLoader( + self.mnist_train, batch_size=self.batch_size, sampler=dist_sampler, shuffle=False + ) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False) diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py new file mode 100644 index 0000000000000..609e7295611b9 --- /dev/null +++ b/tests/models/test_sync_batchnorm.py @@ -0,0 +1,110 @@ +import os +import math +import numpy as np +from argparse import ArgumentParser + +import pytest +from collections import namedtuple +import tests.base.develop_utils as tutils + +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from pytorch_lightning import Trainer +from tests.base.datamodules import MNISTDataModule + + +pl.seed_everything(234) +FLOAT16_EPSILON = np.finfo(np.float16).eps + + +class SyncBNModule(pl.LightningModule): + def __init__(self, gpu_count=1, **kwargs): + super().__init__() + + self.gpu_count = gpu_count + self.bn_targets = None + if 'bn_targets' in kwargs: + self.bn_targets = kwargs['bn_targets'] + + self.linear = nn.Linear(28 * 28, 10) + self.bn_layer = nn.BatchNorm1d(28 * 28) + + def forward(self, x, batch_idx): + with torch.no_grad(): + out_bn = self.bn_layer(x.view(x.size(0), -1)) + + if self.bn_targets: + bn_target = self.bn_targets[batch_idx] + + # executes on both GPUs + bn_target = bn_target[self.trainer.local_rank::self.gpu_count] + bn_target = bn_target.to(out_bn.device) + assert torch.sum(torch.abs(bn_target - out_bn)) < FLOAT16_EPSILON + + out = self.linear(out_bn) + + return out, out_bn + + def training_step(self, batch, batch_idx): + x, y = batch + + y_hat, _ = self(x, batch_idx) + loss = F.cross_entropy(y_hat, y) + + return pl.TrainResult(loss) + + def configure_optimizers(self): + return torch.optim.Adam(self.linear.parameters(), lr=0.02) + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_sync_batchnorm_ddp(tmpdir): + tutils.set_random_master_port() + + parent_parser = ArgumentParser(add_help=False) + + # define datamodule and dataloader + dm = MNISTDataModule() + dm.prepare_data() + dm.setup(stage=None) + + train_dataloader = dm.train_dataloader() + model = SyncBNModule() + + bn_outputs = [] + + # shuffle is false by default + for batch_idx, batch in enumerate(train_dataloader): + x, _ = batch + + _, out_bn = model.forward(x, batch_idx) + bn_outputs.append(out_bn) + + # get 3 steps + if batch_idx == 2: + break + + bn_outputs = [x.cuda() for x in bn_outputs] + + # reset datamodule + # batch-size = 16 because 2 GPUs in DDP + dm = 
MNISTDataModule(batch_size=16, dist_sampler=True) + dm.prepare_data() + dm.setup(stage=None) + + model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs) + + trainer = Trainer( + gpus=2, + num_nodes=1, + distributed_backend='ddp_spawn', + max_epochs=1, + max_steps=3, + sync_batchnorm=True, + num_sanity_val_steps=0, + replace_sampler_ddp=False, + ) + + result = trainer.fit(model, dm) + assert result == 1, "Sync batchnorm failing with DDP" From 1a587a9451010655cdda13227856059ca25cf31e Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 5 Aug 2020 17:36:25 -0400 Subject: [PATCH 10/12] updated test --- tests/base/datamodules.py | 23 +++++++++-------------- tests/models/test_sync_batchnorm.py | 4 +--- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index 37ed5c23ae775..d1f7fabf8d6b4 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -1,10 +1,8 @@ import os -import torchvision.transforms as transforms from torch.utils.data import random_split, DataLoader from pytorch_lightning.core.datamodule import LightningDataModule -from tests.base.datasets import TrialMNIST -from torchvision.datasets import MNIST +from tests.base.datasets import TrialMNIST, MNIST from torch.utils.data.distributed import DistributedSampler @@ -43,18 +41,15 @@ def test_dataloader(self): class MNISTDataModule(LightningDataModule): - def __init__(self, data_dir: str = './', batch_size=32, dist_sampler=False): + def __init__( + self, data_dir: str = './', batch_size: int = 32, dist_sampler: bool = False + ) -> None: super().__init__() self.dist_sampler = dist_sampler self.data_dir = data_dir self.batch_size = batch_size - self.transforms = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - # self.dims is returned when you call dm.size() # Setting default dims here because we know them. 
# Could optionally be assigned dynamically in dm.setup() @@ -62,19 +57,19 @@ def __init__(self, data_dir: str = './', batch_size=32, dist_sampler=False): def prepare_data(self): # download only - MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) - MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) + MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081)) + MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081)) - def setup(self, stage=None): + def setup(self, stage: str = None): # Assign train/val datasets for use in dataloaders # TODO: need to split using random_split once updated to torch >= 1.6 if stage == 'fit' or stage is None: - self.mnist_train = MNIST(self.data_dir, train=True, transform=self.transforms) + self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081)) # Assign test dataset for use in dataloader(s) if stage == 'test' or stage is None: - self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transforms) + self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081)) def train_dataloader(self): dist_sampler = None diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py index 609e7295611b9..5a614e1bc2425 100644 --- a/tests/models/test_sync_batchnorm.py +++ b/tests/models/test_sync_batchnorm.py @@ -1,7 +1,6 @@ import os import math import numpy as np -from argparse import ArgumentParser import pytest from collections import namedtuple @@ -58,12 +57,11 @@ def training_step(self, batch, batch_idx): def configure_optimizers(self): return torch.optim.Adam(self.linear.parameters(), lr=0.02) + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_sync_batchnorm_ddp(tmpdir): tutils.set_random_master_port() - parent_parser = ArgumentParser(add_help=False) - # define datamodule and dataloader dm = MNISTDataModule() dm.prepare_data() From 2c9d8b41e7ac7e12b0d62b7c1f6916df619513af Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 6 Aug 2020 00:50:48 +0200 Subject: [PATCH 11/12] clean --- tests/models/test_sync_batchnorm.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py index 5a614e1bc2425..828d419ca4aa5 100644 --- a/tests/models/test_sync_batchnorm.py +++ b/tests/models/test_sync_batchnorm.py @@ -1,23 +1,15 @@ -import os -import math -import numpy as np - import pytest -from collections import namedtuple -import tests.base.develop_utils as tutils - import torch import torch.nn as nn import torch.nn.functional as F + import pytorch_lightning as pl -from pytorch_lightning import Trainer +import tests.base.develop_utils as tutils +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.utilities import FLOAT16_EPSILON from tests.base.datamodules import MNISTDataModule -pl.seed_everything(234) -FLOAT16_EPSILON = np.finfo(np.float16).eps - - class SyncBNModule(pl.LightningModule): def __init__(self, gpu_count=1, **kwargs): super().__init__() @@ -60,6 +52,7 @@ def configure_optimizers(self): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_sync_batchnorm_ddp(tmpdir): + seed_everything(234) tutils.set_random_master_port() # define datamodule and dataloader From 6ace35820f0fc9c538a737ae1b1e378322f9b1ab Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 6 Aug 2020 00:55:43 
+0200 Subject: [PATCH 12/12] clean --- tests/models/test_sync_batchnorm.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py index 828d419ca4aa5..5aff30d0aacbd 100644 --- a/tests/models/test_sync_batchnorm.py +++ b/tests/models/test_sync_batchnorm.py @@ -3,14 +3,13 @@ import torch.nn as nn import torch.nn.functional as F -import pytorch_lightning as pl -import tests.base.develop_utils as tutils -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import Trainer, seed_everything, LightningModule, TrainResult from pytorch_lightning.utilities import FLOAT16_EPSILON from tests.base.datamodules import MNISTDataModule +from tests.base.develop_utils import set_random_master_port -class SyncBNModule(pl.LightningModule): +class SyncBNModule(LightningModule): def __init__(self, gpu_count=1, **kwargs): super().__init__() @@ -44,7 +43,7 @@ def training_step(self, batch, batch_idx): y_hat, _ = self(x, batch_idx) loss = F.cross_entropy(y_hat, y) - return pl.TrainResult(loss) + return TrainResult(loss) def configure_optimizers(self): return torch.optim.Adam(self.linear.parameters(), lr=0.02) @@ -53,7 +52,7 @@ def configure_optimizers(self): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_sync_batchnorm_ddp(tmpdir): seed_everything(234) - tutils.set_random_master_port() + set_random_master_port() # define datamodule and dataloader dm = MNISTDataModule()
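None of the patches show the body of the renamed `configure_sync_batchnorm` hook, only its signature in `pytorch_lightning/core/lightning.py`. As an assumption about what such a hook presumably reduces to, and how a user could override it, the standard PyTorch mechanism is `torch.nn.SyncBatchNorm.convert_sync_batchnorm`; the class name `LitModel` below is hypothetical and the body is a sketch, not a copy of Lightning's implementation::

    import torch
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
            # Assumed behaviour: convert_sync_batchnorm swaps every BatchNorm*D
            # layer for torch.nn.SyncBatchNorm, and process_group=None keeps the
            # running statistics synchronized across the whole world.
            return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

As the `ddp_backend.py` and `ddp_spawn_backend.py` hunks above show, the accelerator backends call this hook before `.cuda()`, `configure_apex` and `configure_ddp` whenever `trainer.sync_batchnorm` is set.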