From d69455a466fbf412103b59af2594faa7247e2102 Mon Sep 17 00:00:00 2001
From: Sho Arora
Date: Sun, 1 Mar 2020 19:50:49 -0800
Subject: [PATCH] Use callable object for patching dataloaders (#971)

* Use callable object for patching dataloaders

* Add test for ddp with dataloaders passed to fit()

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec

Co-authored-by: Jirka Borovec
---
 pytorch_lightning/trainer/trainer.py | 30 ++++++++++++++------------
 tests/test_gpu_models.py             | 25 +++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ff1860c0acdf0..fa1d91e9af9a2 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1002,30 +1002,21 @@ def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_d
                 m = 'You called .fit() with a train_dataloader but did not define training_step()'
                 raise MisconfigurationException(m)
 
-            def patch_train_dataloader():
-                return train_dataloader
-
-            model.train_dataloader = patch_train_dataloader
+            model.train_dataloader = _PatchDataLoader(train_dataloader)
 
         if val_dataloaders is not None:
             if not self.is_overriden('validation_step', model):
                 m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
                 raise MisconfigurationException(m)
 
-            def patch_val_dataloader():
-                return val_dataloaders
-
-            model.val_dataloader = patch_val_dataloader
+            model.val_dataloader = _PatchDataLoader(val_dataloaders)
 
         if test_dataloaders is not None:
             if not self.is_overriden('test_step', model):
                 m = 'You called .fit() with a test_dataloaders but did not define test_step()'
                 raise MisconfigurationException(m)
 
-            def patch_test_dataloader():
-                return test_dataloaders
-
-            model.test_dataloader = patch_test_dataloader
+            model.test_dataloader = _PatchDataLoader(test_dataloaders)
 
     def init_optimizers(
             self,
@@ -1189,6 +1180,21 @@ def test(self, model: Optional[LightningModule] = None):
             self.run_evaluation(test_mode=True)
 
 
+class _PatchDataLoader(object):
+    r'''
+    Callable object for patching dataloaders passed into trainer.fit().
+    Use this class to override model.*_dataloader() and be pickle-compatible.
+
+    Args:
+        dataloader: Dataloader object to return when called.
+    '''
+    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
+        self.dataloader = dataloader
+
+    def __call__(self) -> Union[List[DataLoader], DataLoader]:
+        return self.dataloader
+
+
 def _set_dataloader(model, dataloader, attribute):
     r'''
     Check dataloaders passed to .fit() method if they are pytorch DataLoader
diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py
index a47b97f95ed91..5fd37870eee23 100644
--- a/tests/test_gpu_models.py
+++ b/tests/test_gpu_models.py
@@ -66,6 +66,31 @@ def test_multi_gpu_model_ddp(tmpdir):
     tutils.run_model_test(trainer_options, model)
 
 
+def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
+    """Make sure DDP works with dataloaders passed to fit()"""
+    if not tutils.can_run_gpu_test():
+        return
+
+    tutils.reset_seed()
+    tutils.set_random_master_port()
+
+    model, hparams = tutils.get_model()
+    trainer_options = dict(default_save_path=tmpdir,
+                           show_progress_bar=False,
+                           max_epochs=1,
+                           train_percent_check=0.4,
+                           val_percent_check=0.2,
+                           gpus=[0, 1],
+                           distributed_backend='ddp')
+
+    fit_options = dict(train_dataloader=model.train_dataloader(),
+                       val_dataloaders=model.val_dataloader())
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model, **fit_options)
+    assert result == 1, "DDP doesn't work with dataloaders passed to fit()."
+
+
 def test_optimizer_return_options():
     tutils.reset_seed()
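Why a callable object instead of a nested function: DDP spawns worker processes
and needs to pickle the model, including the patched *_dataloader attributes.
A function defined inside another function is pickled only by its qualified
name and cannot be restored in the spawned process, while an instance of a
module-level class pickles fine, which is what the _PatchDataLoader docstring
means by "pickle-compatible". The snippet below is a minimal, self-contained
sketch of that difference; it is not part of the patch, patch_with_closure and
fake_loader are illustrative names, and a plain list stands in for a torch
DataLoader.

import pickle


class _PatchDataLoader:
    # Mirrors the class added by this patch: a picklable callable that
    # simply hands back the dataloader it was constructed with.
    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        return self.dataloader


def patch_with_closure(dataloader):
    # The pre-patch approach: a function defined inside another function.
    def patched():
        return dataloader
    return patched


if __name__ == '__main__':
    fake_loader = [1, 2, 3]  # stands in for a DataLoader

    # The callable object round-trips through pickle.
    restored = pickle.loads(pickle.dumps(_PatchDataLoader(fake_loader)))
    assert restored() == fake_loader

    # The nested function does not: pickle stores only the qualified name
    # 'patch_with_closure.<locals>.patched', which cannot be looked up later.
    try:
        pickle.dumps(patch_with_closure(fake_loader))
    except (pickle.PicklingError, AttributeError) as err:
        print('closure is not picklable:', err)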