From bd3dc788cccfdd4e9028b656aa73f8acf71d5ced Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 1 Mar 2020 22:41:39 -0500
Subject: [PATCH 1/2] Update README.md

---
 README.md | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 1c29d267b1d048..388d6977f60a08 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,9 @@ pip install pytorch-lightning
 [Copy and run this COLAB!](https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg)
 
 ## What is it?
-Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. It's more of a style-guide than a framework. By refactoring your code, we can automate most of the non-research code. Lightning guarantees tested, correct, modern best practices for the automated parts.
+Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. It's more of a style-guide than a framework.
+
+By refactoring your code, we can automate most of the non-research code. Lightning guarantees tested, correct, modern best practices for the automated parts.
 
 Here's an example of how to organize PyTorch code into the LightningModule.
 
@@ -69,7 +71,7 @@ This is how lightning separates the science (red) from the engineering (blue).
 ![Overview](docs/source/_static/images/pl_overview.gif)
 
 ## How much effort is it to convert?
-You're probably tired of switching frameworks at this point. But it is a very quick process to refactor into the Lightning format (ie: hours). [Check out this tutorial](https://towardsdatascience.com/how-to-refactor-your-pytorch-code-to-get-these-42-benefits-of-pytorch-lighting-6fdd0dc97538).
+You're probably tired of switching frameworks at this point. But it is a very quick process to refactor into the Lightning format (ie: hours). [Check out this tutorial](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09).
 
 ## What are the differences with PyTorch?
 If you're wondering what you gain out of refactoring your PyTorch code, [read this comparison!](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09)
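The README section changed above refers to "an example of how to organize PyTorch code into the LightningModule." For reference while reviewing, here is a minimal sketch of that organization; the `CoolModel` name, the MNIST setup, and the hook signatures are illustrative assumptions about the pytorch-lightning API of this period, not content from the patch:

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import pytorch_lightning as pl


class CoolModel(pl.LightningModule):
    """Hypothetical model: the 'science' (model, loss, optimizer, data)
    lives in these hooks; Lightning automates the engineering around them."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        dataset = datasets.MNIST('.', train=True, download=True,
                                 transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32)
```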
From d69455a466fbf412103b59af2594faa7247e2102 Mon Sep 17 00:00:00 2001
From: Sho Arora
Date: Sun, 1 Mar 2020 19:50:49 -0800
Subject: [PATCH 2/2] Use callable object for patching dataloaders (#971)

* Use callable object for patching dataloaders

* Add test for ddp with dataloaders passed to fit()

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec

Co-authored-by: Jirka Borovec
---
 pytorch_lightning/trainer/trainer.py | 30 +++++++++++++++++-----------
 tests/test_gpu_models.py             | 25 +++++++++++++++++++++++
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ff1860c0acdf0d..fa1d91e9af9a29 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1002,30 +1002,21 @@ def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
                 m = 'You called .fit() with a train_dataloader but did not define training_step()'
                 raise MisconfigurationException(m)
 
-            def patch_train_dataloader():
-                return train_dataloader
-
-            model.train_dataloader = patch_train_dataloader
+            model.train_dataloader = _PatchDataLoader(train_dataloader)
 
         if val_dataloaders is not None:
             if not self.is_overriden('validation_step', model):
                 m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
                 raise MisconfigurationException(m)
 
-            def patch_val_dataloader():
-                return val_dataloaders
-
-            model.val_dataloader = patch_val_dataloader
+            model.val_dataloader = _PatchDataLoader(val_dataloaders)
 
         if test_dataloaders is not None:
             if not self.is_overriden('test_step', model):
                 m = 'You called .fit() with a test_dataloaders but did not define test_step()'
                 raise MisconfigurationException(m)
 
-            def patch_test_dataloader():
-                return test_dataloaders
-
-            model.test_dataloader = patch_test_dataloader
+            model.test_dataloader = _PatchDataLoader(test_dataloaders)
 
     def init_optimizers(
         self,
@@ -1189,6 +1180,21 @@ def test(self, model: Optional[LightningModule] = None):
         self.run_evaluation(test_mode=True)
 
 
+class _PatchDataLoader(object):
+    r'''
+    Callable object for patching dataloaders passed into trainer.fit().
+    Use this class to override model.*_dataloader() and be pickle-compatible.
+
+    Args:
+        dataloader: Dataloader object to return when called.
+    '''
+    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
+        self.dataloader = dataloader
+
+    def __call__(self) -> Union[List[DataLoader], DataLoader]:
+        return self.dataloader
+
+
 def _set_dataloader(model, dataloader, attribute):
     r'''
     Check dataloaders passed to .fit() method if they are pytorch DataLoader
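The switch from local `patch_*_dataloader` closures to the module-level `_PatchDataLoader` class is about pickling: ddp spawns worker processes that pickle the model, and a function defined inside another function cannot be pickled, while an instance of a top-level class can. A self-contained sketch of the difference — the list is a stand-in for a real DataLoader, and `PatchDataLoader` here mirrors (but is not) the patched class:

```python
import pickle


class PatchDataLoader:
    """Module-level callable: instances pickle cleanly because the
    class is importable by its qualified name."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        return self.dataloader


def patch_with_closure(dataloader):
    # The old approach being removed: a function defined locally.
    def patched():
        return dataloader
    return patched


loader = ['batch0', 'batch1']  # stand-in for a real DataLoader

pickle.dumps(PatchDataLoader(loader))  # works

try:
    pickle.dumps(patch_with_closure(loader))
except (AttributeError, pickle.PicklingError) as err:
    # e.g. "Can't pickle local object 'patch_with_closure.<locals>.patched'"
    print(err)
```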
diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py
index a47b97f95ed919..5fd37870eee232 100644
--- a/tests/test_gpu_models.py
+++ b/tests/test_gpu_models.py
@@ -66,6 +66,31 @@ def test_multi_gpu_model_ddp(tmpdir):
     tutils.run_model_test(trainer_options, model)
 
 
+def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
+    """Make sure DDP works with dataloaders passed to fit()"""
+    if not tutils.can_run_gpu_test():
+        return
+
+    tutils.reset_seed()
+    tutils.set_random_master_port()
+
+    model, hparams = tutils.get_model()
+    trainer_options = dict(default_save_path=tmpdir,
+                           show_progress_bar=False,
+                           max_epochs=1,
+                           train_percent_check=0.4,
+                           val_percent_check=0.2,
+                           gpus=[0, 1],
+                           distributed_backend='ddp')
+
+    fit_options = dict(train_dataloader=model.train_dataloader(),
+                       val_dataloaders=model.val_dataloader())
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model, **fit_options)
+    assert result == 1, "DDP doesn't work with dataloaders passed to fit()."
+
+
 def test_optimizer_return_options():
     tutils.reset_seed()
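With both changes applied, dataloaders handed straight to `fit()` survive a ddp run, which is what the new test exercises. A hypothetical usage sketch — `CoolModel` stands in for any LightningModule, and the Trainer arguments mirror the test above:

```python
from pytorch_lightning import Trainer

model = CoolModel()  # any LightningModule; the name is illustrative

trainer = Trainer(gpus=[0, 1],
                  distributed_backend='ddp',
                  max_epochs=1)

# fit() now stores these loaders on the model as pickleable
# _PatchDataLoader instances instead of local closures, so the
# subprocesses spawned by ddp can pickle the model successfully.
trainer.fit(model,
            train_dataloader=model.train_dataloader(),
            val_dataloaders=model.val_dataloader())
```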