From 69121f40049d42fac4bc26fd564a625ad83a2a24 Mon Sep 17 00:00:00 2001 From: Sho Arora Date: Thu, 27 Feb 2020 16:34:23 -0800 Subject: [PATCH 1/4] Use callable object for patching dataloaders --- pytorch_lightning/trainer/trainer.py | 30 +++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ff1860c0acdf0..75f3700bd62ba 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1002,30 +1002,21 @@ def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_d m = 'You called .fit() with a train_dataloader but did not define training_step()' raise MisconfigurationException(m) - def patch_train_dataloader(): - return train_dataloader - - model.train_dataloader = patch_train_dataloader + model.train_dataloader = _PatchDataLoader(train_dataloader) if val_dataloaders is not None: if not self.is_overriden('validation_step', model): m = 'You called .fit() with a val_dataloaders but did not define validation_step()' raise MisconfigurationException(m) - def patch_val_dataloader(): - return val_dataloaders - - model.val_dataloader = patch_val_dataloader + model.val_dataloader = _PatchDataLoader(val_dataloaders) if test_dataloaders is not None: if not self.is_overriden('test_step', model): m = 'You called .fit() with a test_dataloaders but did not define test_step()' raise MisconfigurationException(m) - def patch_test_dataloader(): - return test_dataloaders - - model.test_dataloader = patch_test_dataloader + model.test_dataloader = _PatchDataLoader(test_dataloaders) def init_optimizers( self, @@ -1189,6 +1180,21 @@ def test(self, model: Optional[LightningModule] = None): self.run_evaluation(test_mode=True) +class _PatchDataLoader(object): + r''' + Callable object for patching dataloaders passed into trainer.fit(). + Use this class to override model.*_dataloader() and be pickle-compatible. + + Args: + dataloader: Dataloader object to return when called. + ''' + def __init__(self, dataloader): + self.dataloader = dataloader + + def __call__(self): + return self.dataloader + + def _set_dataloader(model, dataloader, attribute): r''' Check dataloaders passed to .fit() method if they are pytorch DataLoader From 864875dc7749b664c7eaa30564de7c3498b52316 Mon Sep 17 00:00:00 2001 From: Sho Arora Date: Fri, 28 Feb 2020 17:12:09 -0800 Subject: [PATCH 2/4] Add test for ddp with dataloaders passed to fit() --- tests/test_gpu_models.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py index a47b97f95ed91..5fd37870eee23 100644 --- a/tests/test_gpu_models.py +++ b/tests/test_gpu_models.py @@ -66,6 +66,31 @@ def test_multi_gpu_model_ddp(tmpdir): tutils.run_model_test(trainer_options, model) +def test_ddp_all_dataloaders_passed_to_fit(tmpdir): + """Make sure DDP works with dataloaders passed to fit()""" + if not tutils.can_run_gpu_test(): + return + + tutils.reset_seed() + tutils.set_random_master_port() + + model, hparams = tutils.get_model() + trainer_options = dict(default_save_path=tmpdir, + show_progress_bar=False, + max_epochs=1, + train_percent_check=0.4, + val_percent_check=0.2, + gpus=[0, 1], + distributed_backend='ddp') + + fit_options = dict(train_dataloader=model.train_dataloader(), + val_dataloaders=model.val_dataloader()) + + trainer = Trainer(**trainer_options) + result = trainer.fit(model, **fit_options) + assert result == 1, "DDP doesn't work with dataloaders passed to fit()." + + def test_optimizer_return_options(): tutils.reset_seed() From 3e2409e4b01101e501405e7713301bcfb90ff55a Mon Sep 17 00:00:00 2001 From: Sho Arora Date: Sun, 1 Mar 2020 16:46:54 -0800 Subject: [PATCH 3/4] Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Jirka Borovec --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 75f3700bd62ba..ecbec886dd7d4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1188,7 +1188,7 @@ class _PatchDataLoader(object): Args: dataloader: Dataloader object to return when called. ''' - def __init__(self, dataloader): + def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader def __call__(self): From 0d2ea79000492da324566015d739ee6ec811e61c Mon Sep 17 00:00:00 2001 From: Sho Arora Date: Sun, 1 Mar 2020 16:47:07 -0800 Subject: [PATCH 4/4] Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Jirka Borovec --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ecbec886dd7d4..fa1d91e9af9a2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1191,7 +1191,7 @@ class _PatchDataLoader(object): def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader - def __call__(self): + def __call__(self) -> Union[List[DataLoader], DataLoader]: return self.dataloader