From 8dc681a77fea6df24f3bd427e50c7e50a2934cca Mon Sep 17 00:00:00 2001 From: nateraw Date: Wed, 29 Jul 2020 13:31:09 -0600 Subject: [PATCH 01/21] :sparkles: call dm hooks in trainer implicitly --- pytorch_lightning/core/datamodule.py | 26 ++++++++++++++++++++++++++ pytorch_lightning/trainer/trainer.py | 4 ++++ 2 files changed, 30 insertions(+) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 547f9dc87a605..d6b3054190661 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import inspect from abc import abstractmethod from argparse import ArgumentParser, Namespace @@ -28,17 +29,42 @@ def __call__(cls, *args, **kwargs): 1. Runs user defined subclass's __init__ 2. Assures prepare_data() runs on rank 0 + 3. Lets you check prepare_data and setup to see if they've been called """ # Wrap cls's prepare_data function with rank_zero_only cls.prepare_data = rank_zero_only(cls.prepare_data) + # prepare_data and setup wrapped w/ function to track if they've been called. + # Usage: your_dm.setup.has_been_called & your_dm.prepare_data.has_been_called + cls.prepare_data = track_func_calls(cls.prepare_data) + cls.setup = track_func_calls(cls.setup) + # Get instance of LightningDataModule by mocking its __init__ via __call__ obj = type.__call__(cls, *args, **kwargs) return obj +def track_func_calls(fn): + """A decorator that checks if a function has been called. + + Args: + fn (function): Function that will be tracked to see if it has been called. + + Returns: + callable: Your function with an added bool attr fn.has_been_called. + """ + @functools.wraps(fn) + def wrapped_fn(*args, **kwargs): + wrapped_fn.has_been_called = True + return fn(*args, **kwargs) + + wrapped_fn.has_been_called = False + + return wrapped_fn + + class LightningDataModule(object, metaclass=_DataModuleWrapper): # pragma: no cover """ A DataModule standardizes the training, val, test splits, data preparation and transforms. 
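The datamodule hunk above wraps prepare_data and setup with a `track_func_calls` decorator so that each hook carries a has_been_called flag the trainer can consult before invoking the hook itself. A minimal standalone sketch of that pattern (the plain MyDataModule class below is an illustration only, not the Lightning class) behaves like this:

    import functools

    def track_func_calls(fn):
        # Record on the wrapper itself whether fn has ever been invoked.
        @functools.wraps(fn)
        def wrapped_fn(*args, **kwargs):
            wrapped_fn.has_been_called = True
            return fn(*args, **kwargs)

        wrapped_fn.has_been_called = False
        return wrapped_fn

    class MyDataModule:
        def prepare_data(self):
            print("downloading...")

    # Mirror what _DataModuleWrapper.__call__ does at class-creation time
    MyDataModule.prepare_data = track_func_calls(MyDataModule.prepare_data)

    dm = MyDataModule()
    assert dm.prepare_data.has_been_called is False
    dm.prepare_data()  # prints "downloading..."
    assert dm.prepare_data.has_been_called is True

The trainer hunk that follows relies on exactly this flag: __attach_datamodule only calls datamodule.prepare_data() and datamodule.setup() implicitly when the corresponding has_been_called attribute is still False.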
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0a4c9c349fba4..5024d4b0769d5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1076,6 +1076,10 @@ def __attach_datamodule(self, model, datamodule=None): # If we have a datamodule, attach necessary hooks + dataloaders if datamodule: + if self.is_overridden('prepare_data', datamodule) and not datamodule.prepare_data.has_been_called: + datamodule.prepare_data() + if self.is_overridden('setup', datamodule) and not datamodule.setup.has_been_called: + datamodule.setup() if self.is_overridden('train_dataloader', datamodule): model.train_dataloader = datamodule.train_dataloader if self.is_overridden('val_dataloader', datamodule): From 5064b7df208dcc9aac2299cf0e380453910b5d6f Mon Sep 17 00:00:00 2001 From: nateraw Date: Wed, 29 Jul 2020 13:44:06 -0600 Subject: [PATCH 02/21] :white_check_mark: update tests --- tests/core/test_datamodules.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 76f62590f904a..8f88decfb7f7a 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -13,6 +13,20 @@ def test_base_datamodule(tmpdir): dm.setup() +def test_dm_has_been_called(tmpdir): + dm = TrialMNISTDataModule() + assert dm.prepare_data.has_been_called is False + assert dm.setup.has_been_called is False + + dm.prepare_data() + assert dm.prepare_data.has_been_called is True + assert dm.setup.has_been_called is False + + dm.setup() + assert dm.prepare_data.has_been_called is True + assert dm.setup.has_been_called is True + + def test_dm_add_argparse_args(tmpdir): parser = ArgumentParser() parser = TrialMNISTDataModule.add_argparse_args(parser) @@ -43,8 +57,6 @@ def test_dm_pickle_after_setup(tmpdir): def test_train_loop_only(tmpdir): dm = TrialMNISTDataModule(tmpdir) - dm.prepare_data() - dm.setup() model = EvalModelTemplate() model.validation_step = None @@ -69,8 +81,6 @@ def test_train_loop_only(tmpdir): def test_train_val_loop_only(tmpdir): dm = TrialMNISTDataModule(tmpdir) - dm.prepare_data() - dm.setup() model = EvalModelTemplate() model.validation_step = None @@ -87,13 +97,11 @@ def test_train_val_loop_only(tmpdir): # fit model result = trainer.fit(model) assert result == 1 - assert trainer.callback_metrics['loss'] < 0.50 + assert trainer.callback_metrics['loss'] < 0.65 def test_full_loop(tmpdir): dm = TrialMNISTDataModule(tmpdir) - dm.prepare_data() - dm.setup() model = EvalModelTemplate() @@ -117,8 +125,6 @@ def test_full_loop(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires multi-GPU machine") def test_full_loop_single_gpu(tmpdir): dm = TrialMNISTDataModule(tmpdir) - dm.prepare_data() - dm.setup() model = EvalModelTemplate() @@ -143,8 +149,6 @@ def test_full_loop_single_gpu(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_full_loop_dp(tmpdir): dm = TrialMNISTDataModule(tmpdir) - dm.prepare_data() - dm.setup() model = EvalModelTemplate() @@ -173,8 +177,6 @@ def test_full_loop_ddp_spawn(tmpdir): os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' dm = TrialMNISTDataModule(tmpdir) - dm.prepare_data() - dm.setup() model = EvalModelTemplate() From 0e43c0b8f324a0c896f28f91e5af87c87f7d8240 Mon Sep 17 00:00:00 2001 From: nateraw Date: Wed, 29 Jul 2020 13:46:00 -0600 Subject: [PATCH 03/21] :pencil: remove unused stage arg from dm docs --- 
docs/source/datamodules.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index 9a78158d947e8..5ffe5de763b69 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -19,7 +19,7 @@ matching transforms and data processing/downloads steps required. ... MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) ... MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) ... - ... def setup(self, stage): + ... def setup(self): ... mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor()) ... mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor()) ... # train/val split @@ -79,7 +79,7 @@ There are also data operations you might want to perform on every GPU. Use setup >>> import pytorch_lightning as pl >>> class MNISTDataModule(pl.LightningDataModule): - ... def setup(self, stage): + ... def setup(self): ... mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor()) ... mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor()) ... # train/val split From 8550fb309e1a109c26d2f58b701d21d94af891c0 Mon Sep 17 00:00:00 2001 From: nateraw Date: Wed, 29 Jul 2020 14:09:34 -0600 Subject: [PATCH 04/21] :white_check_mark: update tests --- tests/core/test_datamodules.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 8f88decfb7f7a..61e843a29a639 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -1,10 +1,13 @@ import pickle +from argparse import ArgumentParser + import torch import pytest + from pytorch_lightning import Trainer -from tests.base.datamodules import TrialMNISTDataModule from tests.base import EvalModelTemplate -from argparse import ArgumentParser +from tests.base.datamodules import TrialMNISTDataModule +from tests.base.develop_utils import reset_seed def test_base_datamodule(tmpdir): @@ -80,6 +83,8 @@ def test_train_loop_only(tmpdir): def test_train_val_loop_only(tmpdir): + reset_seed() + dm = TrialMNISTDataModule(tmpdir) model = EvalModelTemplate() @@ -101,6 +106,8 @@ def test_train_val_loop_only(tmpdir): def test_full_loop(tmpdir): + reset_seed() + dm = TrialMNISTDataModule(tmpdir) model = EvalModelTemplate() @@ -124,6 +131,8 @@ def test_full_loop(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires multi-GPU machine") def test_full_loop_single_gpu(tmpdir): + reset_seed() + dm = TrialMNISTDataModule(tmpdir) model = EvalModelTemplate() @@ -148,6 +157,8 @@ def test_full_loop_single_gpu(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_full_loop_dp(tmpdir): + reset_seed() + dm = TrialMNISTDataModule(tmpdir) model = EvalModelTemplate() @@ -176,6 +187,8 @@ def test_full_loop_ddp_spawn(tmpdir): import os os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' + reset_seed() + dm = TrialMNISTDataModule(tmpdir) model = EvalModelTemplate() From 94c1eb146a2f5db581835144fd54c4be585d54de Mon Sep 17 00:00:00 2001 From: nateraw Date: Wed, 29 Jul 2020 14:20:10 -0600 Subject: [PATCH 05/21] :white_check_mark: update tests --- tests/core/test_datamodules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 
61e843a29a639..c3bcc67f27b20 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -79,7 +79,7 @@ def test_train_loop_only(tmpdir): # fit model result = trainer.fit(model) assert result == 1 - assert trainer.callback_metrics['loss'] < 0.50 + assert trainer.callback_metrics['loss'] < 0.6 def test_train_val_loop_only(tmpdir): @@ -102,7 +102,7 @@ def test_train_val_loop_only(tmpdir): # fit model result = trainer.fit(model) assert result == 1 - assert trainer.callback_metrics['loss'] < 0.65 + assert trainer.callback_metrics['loss'] < 0.6 def test_full_loop(tmpdir): From 05a16d76e6e90957021aec94548ed929a4e4ae30 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 30 Jul 2020 01:49:32 -0600 Subject: [PATCH 06/21] :construction: include stage in datamodule.setup --- pytorch_lightning/core/datamodule.py | 85 ++++++++++++++++++++++------ pytorch_lightning/trainer/trainer.py | 24 ++++++-- tests/base/datamodules.py | 19 +++++-- tests/core/test_datamodules.py | 68 +++++++++++++++++++--- 4 files changed, 162 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index d6b3054190661..9e9f641409843 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -16,7 +16,7 @@ import inspect from abc import abstractmethod from argparse import ArgumentParser, Namespace -from typing import Any, List, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from torch.utils.data import DataLoader @@ -32,13 +32,10 @@ def __call__(cls, *args, **kwargs): 3. Lets you check prepare_data and setup to see if they've been called """ - # Wrap cls's prepare_data function with rank_zero_only - cls.prepare_data = rank_zero_only(cls.prepare_data) - - # prepare_data and setup wrapped w/ function to track if they've been called. - # Usage: your_dm.setup.has_been_called & your_dm.prepare_data.has_been_called - cls.prepare_data = track_func_calls(cls.prepare_data) - cls.setup = track_func_calls(cls.setup) + # Track prepare_data calls and make sure it runs on rank zero + cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) + # Track setup calls + cls.setup = track_data_hook_calls(cls.setup) # Get instance of LightningDataModule by mocking its __init__ via __call__ obj = type.__call__(cls, *args, **kwargs) @@ -46,21 +43,45 @@ def __call__(cls, *args, **kwargs): return obj -def track_func_calls(fn): - """A decorator that checks if a function has been called. +def track_data_hook_calls(fn): + """A decorator that checks if prepare_data/setup have been called. + + - When dm.prepare_data() is called, dm.has_prepared_data gets set to True + - When dm.setup('fit') is called, dm.has_setup_fit gets set to True + - When dm.setup('test') is called, dm.has_setup_test gets set to True + - When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True Args: fn (function): Function that will be tracked to see if it has been called. Returns: - callable: Your function with an added bool attr fn.has_been_called. + function: Decorated function that tracks its call status and saves it to private attrs in its obj instance. 
""" + @functools.wraps(fn) def wrapped_fn(*args, **kwargs): - wrapped_fn.has_been_called = True - return fn(*args, **kwargs) - wrapped_fn.has_been_called = False + # The object instance from which setup or prepare_data was called + obj = args[0] + + # If calling setup, we check the stage and assign stage-specific bool args + if fn.__name__ == 'setup': + + # Get stage either by grabbing from args or checking kwargs. + # If not provided, set call status of 'fit' and 'test' to True. + # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() + stage = args[1] if len(args) > 1 else kwargs.get('stage', None) + + if stage == 'fit' or stage is None: + obj._has_setup_fit = True + + if stage == 'test' or stage is None: + obj._has_setup_test = True + + if fn.__name__ == 'prepare_data': + obj._has_prepared_data = True + + return fn(*args, **kwargs) return wrapped_fn @@ -116,6 +137,11 @@ def __init__( self._test_transforms = test_transforms self.dims = () + # Private attrs to keep track of whether or not data hooks have been called yet + self._has_prepared_data = False + self._has_setup_fit = False + self._has_setup_test = False + @property def train_transforms(self): """ @@ -159,6 +185,33 @@ def size(self, dim=None) -> Union[Tuple, int]: return self.dims + @property + def has_prepared_data(self): + """Return bool letting you know if datamodule.prepare_data() has been called or not. + + Returns: + bool: True if datamodule.prepare_data() has been called. False by default. + """ + return self._has_prepared_data + + @property + def has_setup_fit(self): + """Return bool letting you know if datamodule.setup('fit') has been called or not. + + Returns: + bool: True if datamodule.setup('fit') has been called. False by default. + """ + return self._has_setup_fit + + @property + def has_setup_test(self): + """Return bool letting you know if datamodule.setup('test') has been called or not. + + Returns: + bool: True if datamodule.setup('test') has been called. False by default. + """ + return self._has_setup_test + @abstractmethod def prepare_data(self, *args, **kwargs): """ @@ -181,14 +234,14 @@ def prepare_data(self): """ @abstractmethod - def setup(self, *args, **kwargs): + def setup(self, stage: Optional[str] = None): """ Use this to load your data from file, split it, etc. You are safe to make state assignments here. This hook is called on every process when using DDP. Example:: - def setup(self): + def setup(self, stage): data = load_data(...) 
self.train_ds, self.val_ds, self.test_ds = split_data(data) """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5024d4b0769d5..19d22e82ee900 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,7 +941,7 @@ def fit( # set up the passed in dataloaders (if needed) self.__attach_dataloaders(model, train_dataloader, val_dataloaders) - self.__attach_datamodule(model, datamodule) + self.__attach_datamodule(model, datamodule, 'fit') # check that model is configured correctly self.config_validator.verify_loop_configurations(model) @@ -1069,17 +1069,29 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=Non if test_dataloaders is not None: model.test_dataloader = _PatchDataLoader(test_dataloaders) - def __attach_datamodule(self, model, datamodule=None): + def __attach_datamodule(self, model, datamodule, stage): # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) # If we have a datamodule, attach necessary hooks + dataloaders if datamodule: - if self.is_overridden('prepare_data', datamodule) and not datamodule.prepare_data.has_been_called: + + # If datamodule.prepare_data() has not been called yet, call it + if self.is_overridden('prepare_data', datamodule) and not datamodule.has_prepared_data: datamodule.prepare_data() - if self.is_overridden('setup', datamodule) and not datamodule.setup.has_been_called: - datamodule.setup() + + # If datamodule.setup('fit') has not been called yet, call it + if stage == 'fit': + if self.is_overridden('setup', datamodule) and not datamodule.has_setup_fit: + datamodule.setup('fit') + + # If datamodule.setup('test') has not been called yet, call it + if stage == 'test': + if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test: + datamodule.setup('test') + + # Override loader hooks if self.is_overridden('train_dataloader', datamodule): model.train_dataloader = datamodule.train_dataloader if self.is_overridden('val_dataloader', datamodule): @@ -1283,7 +1295,7 @@ def test( ) # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.__attach_datamodule(model or self.get_model(), datamodule) + self.__attach_datamodule(model or self.get_model(), datamodule, 'test') self.setup('test') diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index d863c85605af7..a55a9a718ea9d 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -5,19 +5,28 @@ class TrialMNISTDataModule(LightningDataModule): + def __init__(self, data_dir: str = './'): super().__init__() self.data_dir = data_dir + self.non_picklable = None def prepare_data(self): TrialMNIST(self.data_dir, train=True, download=True) TrialMNIST(self.data_dir, train=False, download=True) - def setup(self): - mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True) - self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64]) - self.dims = tuple(self.mnist_train[0][0].shape) - self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True) + def setup(self, stage: str = None): + + if stage == 'fit' or stage is None: + mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True) + self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64]) + self.dims = self.mnist_train[0][0].shape + + if stage == 'test' or stage is 
None: + self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True) + self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) + + self.non_picklable = lambda x: x**2 def train_dataloader(self): return DataLoader(self.mnist_train, batch_size=32) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index c3bcc67f27b20..929c5c3d2c805 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -16,18 +16,64 @@ def test_base_datamodule(tmpdir): dm.setup() -def test_dm_has_been_called(tmpdir): +def test_base_datamodule_with_verbose_setup(tmpdir): dm = TrialMNISTDataModule() - assert dm.prepare_data.has_been_called is False - assert dm.setup.has_been_called is False + dm.prepare_data() + dm.setup('fit') + dm.setup('test') + + +def test_data_hooks_called(tmpdir): + dm = TrialMNISTDataModule() + assert dm.has_prepared_data is False + assert dm.has_setup_fit is False + assert dm.has_setup_test is False dm.prepare_data() - assert dm.prepare_data.has_been_called is True - assert dm.setup.has_been_called is False + assert dm.has_prepared_data is True + assert dm.has_setup_fit is False + assert dm.has_setup_test is False dm.setup() - assert dm.prepare_data.has_been_called is True - assert dm.setup.has_been_called is True + assert dm.has_prepared_data is True + assert dm.has_setup_fit is True + assert dm.has_setup_test is True + + +def test_data_hooks_called_verbose(tmpdir): + dm = TrialMNISTDataModule() + assert dm.has_prepared_data is False + assert dm.has_setup_fit is False + assert dm.has_setup_test is False + + dm.prepare_data() + assert dm.has_prepared_data is True + assert dm.has_setup_fit is False + assert dm.has_setup_test is False + + dm.setup('fit') + assert dm.has_prepared_data is True + assert dm.has_setup_fit is True + assert dm.has_setup_test is False + + dm.setup('test') + assert dm.has_prepared_data is True + assert dm.has_setup_fit is True + assert dm.has_setup_test is True + + +def test_data_hooks_called_with_stage_kwarg(tmpdir): + dm = TrialMNISTDataModule() + dm.prepare_data() + assert dm.has_prepared_data is True + + dm.setup(stage='fit') + assert dm.has_setup_fit is True + assert dm.has_setup_test is False + + dm.setup(stage='test') + assert dm.has_setup_fit is True + assert dm.has_setup_test is True def test_dm_add_argparse_args(tmpdir): @@ -58,6 +104,14 @@ def test_dm_pickle_after_setup(tmpdir): pickle.dumps(dm) +def test_dm_pickle_after_setup_verbose(tmpdir): + dm = TrialMNISTDataModule() + dm.prepare_data() + dm.setup('fit') + dm.setup('test') + pickle.dumps(dm) + + def test_train_loop_only(tmpdir): dm = TrialMNISTDataModule(tmpdir) From d55bcd76f969106145df494ad82de6d13e0a8728 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 30 Jul 2020 03:58:50 -0600 Subject: [PATCH 07/21] :pencil: docs --- docs/source/datamodules.rst | 174 +++++++++++++++++++++--------------- 1 file changed, 104 insertions(+), 70 deletions(-) diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index 5ffe5de763b69..12b26b4a23857 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -11,33 +11,40 @@ Data preparation in PyTorch follows 5 steps: A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the matching transforms and data processing/downloads steps required. +.. 
code-block:: python + + class MNISTDataModule(LightningDataModule): + + def __init__(self, data_dir: str = './'): + super().__init__() + self.data_dir = data_dir + + def prepare_data(self): + # download + MNIST(self.data_dir, train=True, download=True) + MNIST(self.data_dir, train=False, download=True) + + def setup(self, stage): + + # Assign train/val datasets for use in dataloaders + if stage == 'fit': + mnist_full = MNIST(self.data_dir, train=True, download=True) + self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + + # Assign test dataset for use in dataloader(s) + if stage == 'test': + self.mnist_test = MNIST(self.data_dir, train=False, download=True) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=32) + + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=32) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=32) - >>> import pytorch_lightning as pl - >>> class MNISTDataModule(pl.LightningDataModule): - ... def prepare_data(self): - ... # download - ... MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) - ... MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) - ... - ... def setup(self): - ... mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor()) - ... mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor()) - ... # train/val split - ... mnist_train, mnist_val = random_split(mnist_train, [55000, 5000]) - ... - ... # assign to use in dataloaders - ... self.train_dataset = mnist_train - ... self.val_dataset = mnist_val - ... self.test_dataset = mnist_test - ... - ... def train_dataloader(self): - ... return DataLoader(self.train_dataset, batch_size=64) - ... - ... def val_dataloader(self): - ... return DataLoader(self.val_dataset, batch_size=64) - ... - ... def test_dataloader(self): - ... return DataLoader(self.test_dataset, batch_size=64) +.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``. --------------- @@ -60,11 +67,13 @@ settings. - tokenize - etc... - >>> class MNISTDataModule(pl.LightningDataModule): - ... def prepare_data(self): - ... # download - ... MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) - ... MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) +.. code-block:: python + + class MNISTDataModule(pl.LightningDataModule): + def prepare_data(self): + # download + MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) + MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) .. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`). @@ -77,33 +86,46 @@ There are also data operations you might want to perform on every GPU. Use setup - perform train/val/test splits - etc... - >>> import pytorch_lightning as pl - >>> class MNISTDataModule(pl.LightningDataModule): - ... def setup(self): - ... mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor()) - ... mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor()) - ... # train/val split - ... mnist_train, mnist_val = random_split(mnist_train, [55000, 5000]) - ... - ... # assign to use in dataloaders - ... self.train_dataset = mnist_train - ... self.val_dataset = mnist_val - ... 
self.test_dataset = mnist_test +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + + def setup(self, stage: Optional[str] = None): + + # Assign Train/val split(s) for use in Dataloaders + if stage == 'fit' or stage is None: + mnist_full = MNIST(self.data_dir, train=True, download=True) + self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + self.dims = self.mnist_train[0][0].shape + + # Assign Test split(s) for use in Dataloaders + if stage == 'test' or stage is None: + self.mnist_test = MNIST(self.data_dir, train=False, download=True) + self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) + .. warning:: `setup` is called from every GPU. Setting state here is okay. + train_dataloader ^^^^^^^^^^^^^^^^ Use this method to generate the train dataloader. This is also a good place to place default transformations. - >>> import pytorch_lightning as pl - >>> class MNISTDataModule(pl.LightningDataModule): - ... def train_dataloader(self): - ... transforms = transform_lib.Compose([ - ... transform_lib.ToTensor(), - ... transform_lib.Normalize(mean=(0.5,), std=(0.5,)), - ... ]) - ... return DataLoader(self.train_dataset, transform=transforms, batch_size=64) +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + def train_dataloader(self): + transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + transform_lib.Normalize(mean=(0.5,), std=(0.5,)), + ]) + return DataLoader(self.train_dataset, transform=transforms, batch_size=64) However, to decouple your data from transforms you can parametrize them via `__init__`. @@ -119,32 +141,41 @@ val_dataloader ^^^^^^^^^^^^^^ Use this method to generate the val dataloader. This is also a good place to place default transformations. - >>> import pytorch_lightning as pl - >>> class MNISTDataModule(pl.LightningDataModule): - ... def val_dataloader(self): - ... transforms = transform_lib.Compose([ - ... transform_lib.ToTensor(), - ... transform_lib.Normalize(mean=(0.5,), std=(0.5,)), - ... ]) - ... return DataLoader(self.val_dataset, transform=transforms, batch_size=64) +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + def val_dataloader(self): + transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + transform_lib.Normalize(mean=(0.5,), std=(0.5,)), + ]) + return DataLoader(self.val_dataset, transform=transforms, batch_size=64) test_dataloader ^^^^^^^^^^^^^^^ Use this method to generate the test dataloader. This is also a good place to place default transformations. - >>> import pytorch_lightning as pl - >>> class MNISTDataModule(pl.LightningDataModule): - ... def test_dataloader(self): - ... transforms = transform_lib.Compose([ - ... transform_lib.ToTensor(), - ... transform_lib.Normalize(mean=(0.5,), std=(0.5,)), - ... ]) - ... return DataLoader(self.test_dataset, transform=transforms, batch_size=64) +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + def test_dataloader(self): + transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + transform_lib.Normalize(mean=(0.5,), std=(0.5,)), + ]) + return DataLoader(self.test_dataset, transform=transforms, batch_size=64) ------------------ Using a DataModule ------------------ + The recommended way to use a DataModule is simply: .. 
code-block:: python @@ -162,12 +193,13 @@ still ensures the method runs on the correct devices) dm = MNISTDataModule() dm.prepare_data() - dm.setup() + dm.setup('fit') model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab) trainer.fit(model, dm) - trainer.test(model, datamodule=dm) + dm.setup('test') + trainer.test(datamodule=dm) ---------------- @@ -184,12 +216,14 @@ DataModules have a few key advantages: dm = MNISTDataModule() dm.prepare_data() - dm.setup() + dm.setup('fit') for batch in dm.train_dataloader(): ... for batch in dm.val_dataloader(): ... + + dm.setup('test') for batch in dm.test_dataloader(): ... From 981378c46382d5aac492621c05fe7ad1b5d19225 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 30 Jul 2020 15:26:14 -0600 Subject: [PATCH 08/21] :pencil: docs --- docs/source/datamodules.rst | 35 ++++++++++++++++++++++++++++------ tests/core/test_datamodules.py | 15 +++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index 12b26b4a23857..8bc7cfc56447b 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -13,27 +13,50 @@ matching transforms and data processing/downloads steps required. .. code-block:: python - class MNISTDataModule(LightningDataModule): + import pytorch_lightning as pl + from torch.utils.data import random_split, DataLoader + + # Note - you must have torchvision installed for this example + from torchvision.datasets import MNIST + from torchvision import transforms + + + class MNISTDataModule(pl.LightningDataModule): def __init__(self, data_dir: str = './'): super().__init__() self.data_dir = data_dir + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + + # self.dims is returned when you call dm.size() + # Setting default dims here because we know them. + # Could optionally be assigned dynamically in dm.setup() + self.dims = (1, 28, 28) def prepare_data(self): # download MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) - def setup(self, stage): + def setup(self, stage=None): # Assign train/val datasets for use in dataloaders - if stage == 'fit': - mnist_full = MNIST(self.data_dir, train=True, download=True) + if stage == 'fit' or stage is None: + mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + # Optionally... + # self.dims = tuple(self.mnist_train[0][0].shape) + # Assign test dataset for use in dataloader(s) - if stage == 'test': - self.mnist_test = MNIST(self.data_dir, train=False, download=True) + if stage == 'test' or stage is None: + self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) + + # Optionally... 
+ # self.dims = tuple(self.mnist_test[0][0].shape) def train_dataloader(self): return DataLoader(self.mnist_train, batch_size=32) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 929c5c3d2c805..b525a2b344c73 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -159,6 +159,21 @@ def test_train_val_loop_only(tmpdir): assert trainer.callback_metrics['loss'] < 0.6 +def test_test_loop_only(tmpdir): + reset_seed() + + dm = TrialMNISTDataModule(tmpdir) + + model = EvalModelTemplate() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + ) + trainer.test(model, datamodule=dm) + + def test_full_loop(tmpdir): reset_seed() From 9331b6095118437e3902623f6823f01147441c18 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 30 Jul 2020 18:38:30 -0400 Subject: [PATCH 09/21] added more dm tests --- tests/core/test_datamodules.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index b525a2b344c73..a4217ae8a5344 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -279,3 +279,34 @@ def test_full_loop_ddp_spawn(tmpdir): result = trainer.test(datamodule=dm) result = result[0] assert result['test_acc'] > 0.8 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_full_loop_ddp_spawn_non_picklable(tmpdir): + import os + os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' + + reset_seed() + + dm = TrialMNISTDataModule(tmpdir) + dm.non_pickle_thing = lambda x: x**2 + + model = EvalModelTemplate() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + distributed_backend='ddp_spawn', + gpus=[0, 1] + ) + trainer.fit(model, dm) + + # fit model + result = trainer.fit(model) + assert result == 1 + + # test + result = trainer.test(datamodule=dm) + result = result[0] + assert result['test_acc'] > 0.8 From 6be261b07a658a01b44aef6b43a30348ab6fd199 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 30 Jul 2020 18:40:44 -0400 Subject: [PATCH 10/21] added more dm tests --- tests/core/test_datamodules.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index a4217ae8a5344..b525a2b344c73 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -279,34 +279,3 @@ def test_full_loop_ddp_spawn(tmpdir): result = trainer.test(datamodule=dm) result = result[0] assert result['test_acc'] > 0.8 - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_full_loop_ddp_spawn_non_picklable(tmpdir): - import os - os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' - - reset_seed() - - dm = TrialMNISTDataModule(tmpdir) - dm.non_pickle_thing = lambda x: x**2 - - model = EvalModelTemplate() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=3, - weights_summary=None, - distributed_backend='ddp_spawn', - gpus=[0, 1] - ) - trainer.fit(model, dm) - - # fit model - result = trainer.fit(model) - assert result == 1 - - # test - result = trainer.test(datamodule=dm) - result = result[0] - assert result['test_acc'] > 0.8 From 5233ac7786f199705b6f3216ede9e09aade051fb Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 30 Jul 2020 18:06:02 -0600 Subject: [PATCH 11/21] :bug: call dm.setup everywhere --- .../accelerator_backends/cpu_backend.py | 6 +++++- 
.../accelerator_backends/ddp_spawn_backend.py | 5 ++++- .../accelerator_backends/dp_backend.py | 5 ++++- .../accelerator_backends/gpu_backend.py | 5 ++++- .../accelerator_backends/tpu_backend.py | 5 ++++- .../trainer/distrib_data_parallel.py | 4 +++- pytorch_lightning/trainer/trainer.py | 7 +------ tests/core/test_datamodules.py | 18 ++++++------------ 8 files changed, 31 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/accelerator_backends/cpu_backend.py b/pytorch_lightning/accelerator_backends/cpu_backend.py index 2446aab4ddc00..a150677325e13 100644 --- a/pytorch_lightning/accelerator_backends/cpu_backend.py +++ b/pytorch_lightning/accelerator_backends/cpu_backend.py @@ -20,14 +20,18 @@ class CPUBackend(object): def __init__(self, trainer): self.trainer = trainer - def setup(self, model): + def setup(self, model, datamodule=None): # run through amp wrapper if self.trainer.use_amp: raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option') # call setup after the ddp process has connected if not self.trainer.testing: + if datamodule is not None: + datamodule.setup('fit') + self.trainer.setup('fit') + model.setup('fit') # CHOOSE OPTIMIZER diff --git a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py index 6aee68f6634f2..4c5bdf76833fb 100644 --- a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py @@ -60,7 +60,7 @@ def teardown(self, model): self.trainer.model = model return results - def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0): + def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0, datamodule=None): """ Entry point for ddp @@ -107,7 +107,10 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # call setup after the ddp process has connected if not self.trainer.testing: + if datamodule is not None: + datamodule.setup('fit') self.trainer.setup('fit') + model.setup('fit') # on world_size=0 let everyone know training is starting diff --git a/pytorch_lightning/accelerator_backends/dp_backend.py b/pytorch_lightning/accelerator_backends/dp_backend.py index 2b0bfca93f7ac..c18828a82d28a 100644 --- a/pytorch_lightning/accelerator_backends/dp_backend.py +++ b/pytorch_lightning/accelerator_backends/dp_backend.py @@ -31,10 +31,13 @@ def __init__(self, trainer): self.trainer = trainer self.model_autocast_original_forward = None - def setup(self, model): + def setup(self, model, datamodule=None): # call setup after the ddp process has connected if not self.trainer.testing: + if datamodule is not None: + datamodule.setup('fit') self.trainer.setup('fit') + model.setup('fit') # put model on correct device diff --git a/pytorch_lightning/accelerator_backends/gpu_backend.py b/pytorch_lightning/accelerator_backends/gpu_backend.py index 3b5f37671d9e8..7f034cc1d5ed4 100644 --- a/pytorch_lightning/accelerator_backends/gpu_backend.py +++ b/pytorch_lightning/accelerator_backends/gpu_backend.py @@ -28,11 +28,14 @@ class GPUBackend(object): def __init__(self, trainer): self.trainer = trainer - def setup(self, model): + def setup(self, model, datamodule=None): # call setup if not self.trainer.testing: + if datamodule is not None: + datamodule.setup('fit') self.trainer.setup('fit') + model.setup('fit') model.cuda(self.trainer.root_gpu) diff --git a/pytorch_lightning/accelerator_backends/tpu_backend.py 
b/pytorch_lightning/accelerator_backends/tpu_backend.py index 8d1d1b271b7dc..4bc9c2b5e7ded 100644 --- a/pytorch_lightning/accelerator_backends/tpu_backend.py +++ b/pytorch_lightning/accelerator_backends/tpu_backend.py @@ -96,14 +96,17 @@ def __load_weights_on_main_process(self): self.trainer.model = model - def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None): + def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None, datamodule=None): """ Here we are inside each individual process """ if not trainer: trainer = self.trainer if not trainer.testing: + if datamodule is not None: + datamodule.setup('fit') trainer.setup('fit') + model.setup('fit') # setup TPU training diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 3ca5f6ffa68f3..6fa3aeb1b3916 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -488,7 +488,7 @@ def spawn_ddp_children(self, model): return results - def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0): + def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0, datamodule=None): """ Entry point for ddp @@ -531,6 +531,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # call setup after the ddp process has connected if not self.testing: + if datamodule is not None: + datamodule.setup('fit') self.setup('fit') model.setup('fit') diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 19d22e82ee900..04dbf4eb909f9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1033,7 +1033,7 @@ def fit( else: self.accelerator_backend = CPUBackend(self) - self.accelerator_backend.setup(model) + self.accelerator_backend.setup(model, datamodule=datamodule) results = self.accelerator_backend.train(model) # callbacks @@ -1081,11 +1081,6 @@ def __attach_datamodule(self, model, datamodule, stage): if self.is_overridden('prepare_data', datamodule) and not datamodule.has_prepared_data: datamodule.prepare_data() - # If datamodule.setup('fit') has not been called yet, call it - if stage == 'fit': - if self.is_overridden('setup', datamodule) and not datamodule.has_setup_fit: - datamodule.setup('fit') - # If datamodule.setup('test') has not been called yet, call it if stage == 'test': if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test: diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index b525a2b344c73..5f5b2900caf1e 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -128,10 +128,9 @@ def test_train_loop_only(tmpdir): max_epochs=3, weights_summary=None, ) - trainer.fit(model, dm) # fit model - result = trainer.fit(model) + result = trainer.fit(model, dm) assert result == 1 assert trainer.callback_metrics['loss'] < 0.6 @@ -151,10 +150,9 @@ def test_train_val_loop_only(tmpdir): max_epochs=3, weights_summary=None, ) - trainer.fit(model, dm) # fit model - result = trainer.fit(model) + result = trainer.fit(model, dm) assert result == 1 assert trainer.callback_metrics['loss'] < 0.6 @@ -186,10 +184,9 @@ def test_full_loop(tmpdir): max_epochs=3, weights_summary=None, ) - trainer.fit(model, dm) # fit model - result = trainer.fit(model) + result = trainer.fit(model, dm) assert result == 1 # test @@ -212,10 +209,9 @@ def 
test_full_loop_single_gpu(tmpdir): weights_summary=None, gpus=1 ) - trainer.fit(model, dm) # fit model - result = trainer.fit(model) + result = trainer.fit(model, dm) assert result == 1 # test @@ -239,10 +235,9 @@ def test_full_loop_dp(tmpdir): distributed_backend='dp', gpus=2 ) - trainer.fit(model, dm) # fit model - result = trainer.fit(model) + result = trainer.fit(model, dm) assert result == 1 # test @@ -269,10 +264,9 @@ def test_full_loop_ddp_spawn(tmpdir): distributed_backend='ddp_spawn', gpus=[0, 1] ) - trainer.fit(model, dm) # fit model - result = trainer.fit(model) + result = trainer.fit(model, dm) assert result == 1 # test From cb1b8482ecab3f2372a15e8f3963835275a62628 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 30 Jul 2020 18:07:42 -0600 Subject: [PATCH 12/21] :fire: pickle tests now implied by accelerator tests --- tests/core/test_datamodules.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 5f5b2900caf1e..5fc34f63f17a8 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -97,21 +97,6 @@ def test_dm_pickle_after_init(tmpdir): pickle.dumps(dm) -def test_dm_pickle_after_setup(tmpdir): - dm = TrialMNISTDataModule() - dm.prepare_data() - dm.setup() - pickle.dumps(dm) - - -def test_dm_pickle_after_setup_verbose(tmpdir): - dm = TrialMNISTDataModule() - dm.prepare_data() - dm.setup('fit') - dm.setup('test') - pickle.dumps(dm) - - def test_train_loop_only(tmpdir): dm = TrialMNISTDataModule(tmpdir) From 1b77442a127efa2f5f05cb188194877a3576dc1c Mon Sep 17 00:00:00 2001 From: nateraw Date: Fri, 31 Jul 2020 11:29:47 -0600 Subject: [PATCH 13/21] :art: set dm as attr of trainer --- pytorch_lightning/accelerator_backends/cpu_backend.py | 6 +++--- pytorch_lightning/accelerator_backends/ddp_spawn_backend.py | 6 +++--- pytorch_lightning/accelerator_backends/dp_backend.py | 6 +++--- pytorch_lightning/accelerator_backends/gpu_backend.py | 6 +++--- pytorch_lightning/accelerator_backends/tpu_backend.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 4 +++- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/accelerator_backends/cpu_backend.py b/pytorch_lightning/accelerator_backends/cpu_backend.py index a150677325e13..886e81f625fb8 100644 --- a/pytorch_lightning/accelerator_backends/cpu_backend.py +++ b/pytorch_lightning/accelerator_backends/cpu_backend.py @@ -20,15 +20,15 @@ class CPUBackend(object): def __init__(self, trainer): self.trainer = trainer - def setup(self, model, datamodule=None): + def setup(self, model): # run through amp wrapper if self.trainer.use_amp: raise MisconfigurationException('amp + cpu is not supported. 
Please use a GPU option') # call setup after the ddp process has connected if not self.trainer.testing: - if datamodule is not None: - datamodule.setup('fit') + if self.trainer.datamodule is not None: + self.trainer.datamodule.setup('fit') self.trainer.setup('fit') diff --git a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py index 4c5bdf76833fb..f03957bfebc9d 100644 --- a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py @@ -60,7 +60,7 @@ def teardown(self, model): self.trainer.model = model return results - def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0, datamodule=None): + def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0): """ Entry point for ddp @@ -107,8 +107,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # call setup after the ddp process has connected if not self.trainer.testing: - if datamodule is not None: - datamodule.setup('fit') + if self.trainer.datamodule is not None: + self.trainer.datamodule.setup('fit') self.trainer.setup('fit') model.setup('fit') diff --git a/pytorch_lightning/accelerator_backends/dp_backend.py b/pytorch_lightning/accelerator_backends/dp_backend.py index c18828a82d28a..5a6c11c4a9d5d 100644 --- a/pytorch_lightning/accelerator_backends/dp_backend.py +++ b/pytorch_lightning/accelerator_backends/dp_backend.py @@ -31,11 +31,11 @@ def __init__(self, trainer): self.trainer = trainer self.model_autocast_original_forward = None - def setup(self, model, datamodule=None): + def setup(self, model): # call setup after the ddp process has connected if not self.trainer.testing: - if datamodule is not None: - datamodule.setup('fit') + if self.trainer.datamodule is not None: + self.trainer.datamodule.setup('fit') self.trainer.setup('fit') model.setup('fit') diff --git a/pytorch_lightning/accelerator_backends/gpu_backend.py b/pytorch_lightning/accelerator_backends/gpu_backend.py index 7f034cc1d5ed4..622d1dc769822 100644 --- a/pytorch_lightning/accelerator_backends/gpu_backend.py +++ b/pytorch_lightning/accelerator_backends/gpu_backend.py @@ -28,12 +28,12 @@ class GPUBackend(object): def __init__(self, trainer): self.trainer = trainer - def setup(self, model, datamodule=None): + def setup(self, model): # call setup if not self.trainer.testing: - if datamodule is not None: - datamodule.setup('fit') + if self.trainer.datamodule is not None: + self.trainer.datamodule.setup('fit') self.trainer.setup('fit') model.setup('fit') diff --git a/pytorch_lightning/accelerator_backends/tpu_backend.py b/pytorch_lightning/accelerator_backends/tpu_backend.py index 4bc9c2b5e7ded..7addc4d931408 100644 --- a/pytorch_lightning/accelerator_backends/tpu_backend.py +++ b/pytorch_lightning/accelerator_backends/tpu_backend.py @@ -96,15 +96,15 @@ def __load_weights_on_main_process(self): self.trainer.model = model - def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None, datamodule=None): + def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None): """ Here we are inside each individual process """ if not trainer: trainer = self.trainer if not trainer.testing: - if datamodule is not None: - datamodule.setup('fit') + if self.trainer.datamodule is not None: + self.trainer.datamodule.setup('fit') trainer.setup('fit') model.setup('fit') diff --git 
a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 04dbf4eb909f9..a911cfcef94f4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1033,7 +1033,7 @@ def fit( else: self.accelerator_backend = CPUBackend(self) - self.accelerator_backend.setup(model, datamodule=datamodule) + self.accelerator_backend.setup(model) results = self.accelerator_backend.train(model) # callbacks @@ -1094,6 +1094,8 @@ def __attach_datamodule(self, model, datamodule, stage): if self.is_overridden('test_dataloader', datamodule): model.test_dataloader = datamodule.test_dataloader + self.datamodule = datamodule + def run_pretrain_routine(self, model: LightningModule): """Sanity check a few things before starting actual training. From f059cc43bf20eb16058a02b7e1bc4848e5cd1efa Mon Sep 17 00:00:00 2001 From: nateraw Date: Fri, 31 Jul 2020 11:51:18 -0600 Subject: [PATCH 14/21] :bug: . --- pytorch_lightning/trainer/distrib_data_parallel.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 6fa3aeb1b3916..d2769c3e8e25c 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -488,7 +488,7 @@ def spawn_ddp_children(self, model): return results - def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0, datamodule=None): + def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0): """ Entry point for ddp @@ -531,8 +531,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # call setup after the ddp process has connected if not self.testing: - if datamodule is not None: - datamodule.setup('fit') + if self.datamodule is not None: + self.datamodule.setup('fit') self.setup('fit') model.setup('fit') diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a911cfcef94f4..ca3ddb669c0e6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -378,6 +378,7 @@ def __init__( # training state self.model = None + self.datamodule = None self.testing = False self.prepare_data_per_node = prepare_data_per_node self.lr_schedulers = [] From a3be9e713967a20cca8f582bda6448eab815d789 Mon Sep 17 00:00:00 2001 From: nateraw Date: Fri, 31 Jul 2020 13:28:20 -0600 Subject: [PATCH 15/21] :construction: wip --- pytorch_lightning/trainer/distrib_data_parallel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index d2769c3e8e25c..6ba0ff8678b21 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -145,6 +145,7 @@ def train_fx(trial_hparams, cluster_manager, _): from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule @@ -204,6 +205,7 @@ class TrainerDDPMixin(ABC): node_rank: int tpu_cores: int testing: bool + datamodule: Optional[LightningDataModule] @property @abstractmethod From 
ecc2875531610071b292740feb59cd5eaa908e28 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 1 Aug 2020 19:20:59 -0400 Subject: [PATCH 16/21] add can prepare test --- pytorch_lightning/trainer/trainer.py | 18 ++++++--- tests/core/test_datamodules.py | 55 ++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ca3ddb669c0e6..8391e48cf0546 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -958,6 +958,10 @@ def fit( model.prepare_data() self._is_data_prepared = True + # If datamodule.prepare_data() has not been called yet, call it + if dm_prepare_data_called: + datamodule.prepare_data() + # Run auto batch size scaling if self.auto_scale_batch_size: if isinstance(self.auto_scale_batch_size, bool): @@ -1053,10 +1057,14 @@ def fit( return results or 1 def can_prepare_data(self): + should_call_dm_prepare_data = True + if self.datamodule is not None and self.is_overridden('prepare_data', self.datamodule): + should_call_dm_prepare_data = not self.datamodule.has_prepared_data + if self.prepare_data_per_node: - return self.local_rank == 0 + return self.local_rank == 0 and should_call_dm_prepare_data else: - return self.node_rank == 0 and self.local_rank == 0 + return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): # when dataloader is passed via fit, patch the train_dataloader @@ -1078,9 +1086,9 @@ def __attach_datamodule(self, model, datamodule, stage): # If we have a datamodule, attach necessary hooks + dataloaders if datamodule: - # If datamodule.prepare_data() has not been called yet, call it - if self.is_overridden('prepare_data', datamodule) and not datamodule.has_prepared_data: - datamodule.prepare_data() + # # If datamodule.prepare_data() has not been called yet, call it + # if self.is_overridden('prepare_data', datamodule) and not datamodule.has_prepared_data: + # datamodule.prepare_data() # If datamodule.setup('test') has not been called yet, call it if stage == 'test': diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 5fc34f63f17a8..ec66afb71ca22 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -10,6 +10,61 @@ from tests.base.develop_utils import reset_seed +def test_can_prepare_data(tmpdir): + + dm = TrialMNISTDataModule() + trainer = Trainer() + trainer.datamodule = dm + + # 1 no DM + # prepare_data_per_node = True + # local rank = 0 (True) + trainer.prepare_data_per_node = True + trainer.local_rank = 0 + assert trainer.can_prepare_data() + + # local rank = 1 (False) + trainer.local_rank = 1 + assert not trainer.can_prepare_data() + + # prepare_data_per_node = False (prepare across all nodes) + # global rank = 0 (True) + trainer.prepare_data_per_node = False + trainer.node_rank = 0 + trainer.local_rank = 0 + assert trainer.can_prepare_data() + + # global rank = 1 (False) + trainer.node_rank = 1 + trainer.local_rank = 0 + assert not trainer.can_prepare_data() + trainer.node_rank = 0 + trainer.local_rank = 1 + assert not trainer.can_prepare_data() + + # 2 dm + # prepar per node = True + # local rank = 0 (True) + trainer.prepare_data_per_node = True + trainer.local_rank = 0 + + # is_overridden prepare data = True + # has been called + # False + dm._has_prepared_data = True + assert not trainer.can_prepare_data() + + # has 
not been called + # True + dm._has_prepared_data = False + assert trainer.can_prepare_data() + + # is_overridden prepare data = False + # True + dm.prepare_data = None + assert trainer.can_prepare_data() + + def test_base_datamodule(tmpdir): dm = TrialMNISTDataModule() dm.prepare_data() From c54ac9ddabb36b8952863b7a2397705b278ffd85 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 1 Aug 2020 19:21:50 -0400 Subject: [PATCH 17/21] add can prepare test --- pytorch_lightning/trainer/trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8391e48cf0546..dbede45a0dd70 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -955,13 +955,11 @@ def fit( # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 if self.can_prepare_data(): + if datamodule is not None: + datamodule.prepare_data() model.prepare_data() self._is_data_prepared = True - # If datamodule.prepare_data() has not been called yet, call it - if dm_prepare_data_called: - datamodule.prepare_data() - # Run auto batch size scaling if self.auto_scale_batch_size: if isinstance(self.auto_scale_batch_size, bool): From 2bdb10e615e8823e9adb748ca0ca5330d923e077 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 1 Aug 2020 19:27:13 -0400 Subject: [PATCH 18/21] verified setup in fit --- pytorch_lightning/accelerator_backends/cpu_backend.py | 1 - pytorch_lightning/accelerator_backends/dp_backend.py | 1 - pytorch_lightning/accelerator_backends/tpu_backend.py | 1 - pytorch_lightning/trainer/distrib_parts.py | 3 +++ pytorch_lightning/trainer/trainer.py | 10 +++------- 5 files changed, 6 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/accelerator_backends/cpu_backend.py b/pytorch_lightning/accelerator_backends/cpu_backend.py index 886e81f625fb8..a81214f819883 100644 --- a/pytorch_lightning/accelerator_backends/cpu_backend.py +++ b/pytorch_lightning/accelerator_backends/cpu_backend.py @@ -31,7 +31,6 @@ def setup(self, model): self.trainer.datamodule.setup('fit') self.trainer.setup('fit') - model.setup('fit') # CHOOSE OPTIMIZER diff --git a/pytorch_lightning/accelerator_backends/dp_backend.py b/pytorch_lightning/accelerator_backends/dp_backend.py index 5a6c11c4a9d5d..4518001a0f3d7 100644 --- a/pytorch_lightning/accelerator_backends/dp_backend.py +++ b/pytorch_lightning/accelerator_backends/dp_backend.py @@ -37,7 +37,6 @@ def setup(self, model): if self.trainer.datamodule is not None: self.trainer.datamodule.setup('fit') self.trainer.setup('fit') - model.setup('fit') # put model on correct device diff --git a/pytorch_lightning/accelerator_backends/tpu_backend.py b/pytorch_lightning/accelerator_backends/tpu_backend.py index 7addc4d931408..809ea200d52b1 100644 --- a/pytorch_lightning/accelerator_backends/tpu_backend.py +++ b/pytorch_lightning/accelerator_backends/tpu_backend.py @@ -106,7 +106,6 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine if self.trainer.datamodule is not None: self.trainer.datamodule.setup('fit') trainer.setup('fit') - model.setup('fit') # setup TPU training diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 63db623d91f0b..5e746c600cbf3 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py 
From 2bdb10e615e8823e9adb748ca0ca5330d923e077 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 1 Aug 2020 19:27:13 -0400
Subject: [PATCH 18/21] verified setup in fit

---
 pytorch_lightning/accelerator_backends/cpu_backend.py |  1 -
 pytorch_lightning/accelerator_backends/dp_backend.py  |  1 -
 pytorch_lightning/accelerator_backends/tpu_backend.py |  1 -
 pytorch_lightning/trainer/distrib_parts.py            |  3 +++
 pytorch_lightning/trainer/trainer.py                  | 10 +++-------
 5 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/accelerator_backends/cpu_backend.py b/pytorch_lightning/accelerator_backends/cpu_backend.py
index 886e81f625fb8..a81214f819883 100644
--- a/pytorch_lightning/accelerator_backends/cpu_backend.py
+++ b/pytorch_lightning/accelerator_backends/cpu_backend.py
@@ -31,7 +31,6 @@ def setup(self, model):
                 self.trainer.datamodule.setup('fit')

             self.trainer.setup('fit')
-            model.setup('fit')

         # CHOOSE OPTIMIZER
diff --git a/pytorch_lightning/accelerator_backends/dp_backend.py b/pytorch_lightning/accelerator_backends/dp_backend.py
index 5a6c11c4a9d5d..4518001a0f3d7 100644
--- a/pytorch_lightning/accelerator_backends/dp_backend.py
+++ b/pytorch_lightning/accelerator_backends/dp_backend.py
@@ -37,7 +37,6 @@ def setup(self, model):
             if self.trainer.datamodule is not None:
                 self.trainer.datamodule.setup('fit')
             self.trainer.setup('fit')
-            model.setup('fit')

         # put model on correct device
diff --git a/pytorch_lightning/accelerator_backends/tpu_backend.py b/pytorch_lightning/accelerator_backends/tpu_backend.py
index 7addc4d931408..809ea200d52b1 100644
--- a/pytorch_lightning/accelerator_backends/tpu_backend.py
+++ b/pytorch_lightning/accelerator_backends/tpu_backend.py
@@ -106,7 +106,6 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
             if self.trainer.datamodule is not None:
                 self.trainer.datamodule.setup('fit')
             trainer.setup('fit')
-            model.setup('fit')

         # setup TPU training
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 63db623d91f0b..5e746c600cbf3 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -82,6 +82,7 @@ class TrainerDPMixin(ABC):
     on_colab_kaggle: str
     save_spawn_weights: Callable
     logger: ...
+    datamodule: ...

     @property
     @abstractmethod
@@ -181,6 +182,8 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):

     def horovod_train(self, model):
         # call setup after the ddp process has connected
         if not self.testing:
+            if self.datamodule is not None:
+                self.datamodule.setup('fit')
             self.setup('fit')
             model.setup('fit')

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index dbede45a0dd70..462df83c36e23 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1084,14 +1084,10 @@ def __attach_datamodule(self, model, datamodule, stage):

         # If we have a datamodule, attach necessary hooks + dataloaders
         if datamodule:
-            # # If datamodule.prepare_data() has not been called yet, call it
-            # if self.is_overridden('prepare_data', datamodule) and not datamodule.has_prepared_data:
-            #     datamodule.prepare_data()
-
             # If datamodule.setup('test') has not been called yet, call it
-            if stage == 'test':
-                if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test:
-                    datamodule.setup('test')
+            # if stage == 'test':
+            #     if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test:
+            #         datamodule.setup('test')

             # Override loader hooks
             if self.is_overridden('train_dataloader', datamodule):
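PATCH 18 also threads the datamodule through TrainerDPMixin and has horovod_train call datamodule.setup('fit') ahead of the trainer and model setup hooks, but only while fitting and only when a datamodule is attached. A stand-alone sketch of that guard, using stand-in classes rather than the real trainer (pre_train_setup is a made-up name for the inlined guard):

    class _SketchDataModule:
        def setup(self, stage):
            print(f"datamodule.setup({stage!r})")

    class _SketchModel:
        def setup(self, stage):
            print(f"model.setup({stage!r})")

    class _SketchTrainer:
        def __init__(self, datamodule=None, testing=False):
            self.datamodule = datamodule
            self.testing = testing

        def setup(self, stage):
            print(f"trainer.setup({stage!r})")

        def pre_train_setup(self, model):
            # mirrors the guard horovod_train uses after this patch
            if not self.testing:
                if self.datamodule is not None:
                    self.datamodule.setup('fit')
                self.setup('fit')
                model.setup('fit')

    _SketchTrainer(datamodule=_SketchDataModule()).pre_train_setup(_SketchModel())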
From 8107ef97f45f72152b79de44a8a81a4ac2459471 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 1 Aug 2020 19:35:26 -0400
Subject: [PATCH 19/21] fixed setup call

---
 .../accelerator_backends/cpu_backend.py       |  7 +------
 .../accelerator_backends/ddp_spawn_backend.py |  7 +------
 .../accelerator_backends/dp_backend.py        |  6 +-----
 .../accelerator_backends/gpu_backend.py       |  7 +------
 .../accelerator_backends/tpu_backend.py       |  7 ++-----
 .../trainer/distrib_data_parallel.py          | 10 +++++-----
 pytorch_lightning/trainer/distrib_parts.py    | 11 +++++------
 pytorch_lightning/trainer/trainer.py          | 15 ++++++++++-----
 8 files changed, 26 insertions(+), 44 deletions(-)

diff --git a/pytorch_lightning/accelerator_backends/cpu_backend.py b/pytorch_lightning/accelerator_backends/cpu_backend.py
index a81214f819883..d6c2dacb1a01c 100644
--- a/pytorch_lightning/accelerator_backends/cpu_backend.py
+++ b/pytorch_lightning/accelerator_backends/cpu_backend.py
@@ -26,12 +26,7 @@ def setup(self, model):
             raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

         # call setup after the ddp process has connected
-        if not self.trainer.testing:
-            if self.trainer.datamodule is not None:
-                self.trainer.datamodule.setup('fit')
-
-            self.trainer.setup('fit')
-            model.setup('fit')
+        self.trainer.call_setup_hook()

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
diff --git a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
index f03957bfebc9d..6c00d15f164c1 100644
--- a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
@@ -106,12 +106,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         )

         # call setup after the ddp process has connected
-        if not self.trainer.testing:
-            if self.trainer.datamodule is not None:
-                self.trainer.datamodule.setup('fit')
-            self.trainer.setup('fit')
-
-            model.setup('fit')
+        self.trainer.call_setup_hook()

         # on world_size=0 let everyone know training is starting
         if self.trainer.is_global_zero:
diff --git a/pytorch_lightning/accelerator_backends/dp_backend.py b/pytorch_lightning/accelerator_backends/dp_backend.py
index 4518001a0f3d7..df922a000696c 100644
--- a/pytorch_lightning/accelerator_backends/dp_backend.py
+++ b/pytorch_lightning/accelerator_backends/dp_backend.py
@@ -33,11 +33,7 @@ def __init__(self, trainer):

     def setup(self, model):
         # call setup after the ddp process has connected
-        if not self.trainer.testing:
-            if self.trainer.datamodule is not None:
-                self.trainer.datamodule.setup('fit')
-            self.trainer.setup('fit')
-            model.setup('fit')
+        self.trainer.call_setup_hook()

         # put model on correct device
         model.cuda(self.trainer.root_gpu)
diff --git a/pytorch_lightning/accelerator_backends/gpu_backend.py b/pytorch_lightning/accelerator_backends/gpu_backend.py
index 622d1dc769822..b649ceb864cac 100644
--- a/pytorch_lightning/accelerator_backends/gpu_backend.py
+++ b/pytorch_lightning/accelerator_backends/gpu_backend.py
@@ -31,12 +31,7 @@ def __init__(self, trainer):

     def setup(self, model):

         # call setup
-        if not self.trainer.testing:
-            if self.trainer.datamodule is not None:
-                self.trainer.datamodule.setup('fit')
-            self.trainer.setup('fit')
-
-            model.setup('fit')
+        self.trainer.call_setup_hook()

         model.cuda(self.trainer.root_gpu)
diff --git a/pytorch_lightning/accelerator_backends/tpu_backend.py b/pytorch_lightning/accelerator_backends/tpu_backend.py
index 809ea200d52b1..221dfa87aef52 100644
--- a/pytorch_lightning/accelerator_backends/tpu_backend.py
+++ b/pytorch_lightning/accelerator_backends/tpu_backend.py
@@ -102,11 +102,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         """
         if not trainer:
             trainer = self.trainer
-        if not trainer.testing:
-            if self.trainer.datamodule is not None:
-                self.trainer.datamodule.setup('fit')
-            trainer.setup('fit')
-            model.setup('fit')
+
+        trainer.call_setup_hook()

         # setup TPU training
         self.__setup_tpu_training(model, trainer)
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 6ba0ff8678b21..ccaa97486d9fb 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -212,6 +212,10 @@ class TrainerDDPMixin(ABC):
     def is_global_zero(self) -> bool:
         """Warning: this is just empty shell for code implemented in other class."""

+    @abstractmethod
+    def call_setup_hook(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     @property
     @abstractmethod
     def num_gpus(self) -> int:
@@ -532,11 +536,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)

         # call setup after the ddp process has connected
-        if not self.testing:
-            if self.datamodule is not None:
-                self.datamodule.setup('fit')
-            self.setup('fit')
-            model.setup('fit')
+        self.call_setup_hook()

         # on world_size=0 let everyone know training is starting
         if self.is_global_zero:
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 5e746c600cbf3..22f016d29dfa3 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -82,13 +82,16 @@ class TrainerDPMixin(ABC):
     on_colab_kaggle: str
     save_spawn_weights: Callable
     logger: ...
-    datamodule: ...

     @property
     @abstractmethod
     def use_amp(self) -> bool:
         """Warning: this is just empty shell for code implemented in other class."""

+    @abstractmethod
+    def call_setup_hook(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     @abstractmethod
     def run_pretrain_routine(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -181,11 +184,7 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):

     def horovod_train(self, model):
         # call setup after the ddp process has connected
-        if not self.testing:
-            if self.datamodule is not None:
-                self.datamodule.setup('fit')
-            self.setup('fit')
-            model.setup('fit')
+        self.call_setup_hook()

         if torch.cuda.is_available() and self.on_gpu:
             # Horovod: pin GPU to local rank
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 462df83c36e23..4ca2bbdd6c76a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1297,8 +1297,6 @@ def test(
         # Attach datamodule to get setup/prepare_data added to model before the call to it below
         self.__attach_datamodule(model or self.get_model(), datamodule, 'test')

-        self.setup('test')
-
         if model is not None:
             results = self.__test_given_model(model, test_dataloaders)
         else:
@@ -1310,7 +1308,6 @@ def test(

     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         model = self.get_model()
-        model.setup('test')

         # if user requests the best checkpoint but we don't have it, error
         if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
@@ -1356,8 +1353,6 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         return results

     def __test_given_model(self, model, test_dataloaders):
-        # setup hook
-        model.setup('test')

         # attach data
         if test_dataloaders is not None:
@@ -1386,6 +1381,16 @@ def barrier(self, name):
             # wait for all processes to catch up
             torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')

+    def call_setup_hook(self):
+        # call setup after the ddp process has connected
+        stage_name = 'test' if self.testing else 'fit'
+        if self.datamodule is not None:
+            self.datamodule.setup(stage_name)
+        self.setup(stage_name)
+
+        model = self.get_model()
+        model.setup(stage_name)
+

 class _PatchDataLoader(object):
     r"""
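PATCH 19 collapses those repeated inline guards into a single Trainer.call_setup_hook(): the stage comes from self.testing, and the datamodule, trainer, and model setup hooks then run in that order. A compact, runnable sketch of the ordering it centralizes (SimpleNamespace objects stand in for the real classes):

    from types import SimpleNamespace

    def call_setup_hook_sketch(trainer, datamodule, model):
        # same fan-out as the new Trainer.call_setup_hook()
        stage_name = 'test' if trainer.testing else 'fit'
        if datamodule is not None:
            datamodule.setup(stage_name)
        trainer.setup(stage_name)
        model.setup(stage_name)

    trace = []
    stub = lambda who: SimpleNamespace(testing=False, setup=lambda stage: trace.append((who, stage)))
    call_setup_hook_sketch(stub('trainer'), stub('datamodule'), stub('model'))
    assert trace == [('datamodule', 'fit'), ('trainer', 'fit'), ('model', 'fit')]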
From 57e942a790e3222735c6903939b965c8cba40ac7 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 1 Aug 2020 19:40:43 -0400
Subject: [PATCH 20/21] fixed setup call

---
 pytorch_lightning/trainer/trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4ca2bbdd6c76a..58356e233be6a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1385,7 +1385,9 @@ def call_setup_hook(self):
         # call setup after the ddp process has connected
         stage_name = 'test' if self.testing else 'fit'
         if self.datamodule is not None:
-            self.datamodule.setup(stage_name)
+            called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
+            if not called:
+                self.datamodule.setup(stage_name)
         self.setup(stage_name)

         model = self.get_model()
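PATCH 20 makes the datamodule half of call_setup_hook() idempotent: setup only runs if the matching has_setup_fit / has_setup_test flag is still False. A simplified sketch of that guard (here the flag is flipped inside setup itself, which simplifies how the real datamodule tracks its hook calls):

    class _SketchDataModule:
        def __init__(self):
            self.has_setup_fit = False
            self.has_setup_test = False

        def setup(self, stage):
            print(f"datamodule.setup({stage!r})")
            if stage == 'fit':
                self.has_setup_fit = True
            else:
                self.has_setup_test = True

    def maybe_setup(dm, testing):
        # mirrors the guard added to call_setup_hook()
        stage_name = 'test' if testing else 'fit'
        called = dm.has_setup_test if testing else dm.has_setup_fit
        if not called:
            dm.setup(stage_name)

    dm = _SketchDataModule()
    maybe_setup(dm, testing=False)   # runs setup('fit')
    maybe_setup(dm, testing=False)   # skipped: 'fit' setup already ran
    maybe_setup(dm, testing=True)    # runs setup('test')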
From c138110884f6366d97399efaaa695fb694a6d0db Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 1 Aug 2020 20:00:47 -0400
Subject: [PATCH 21/21] fixed setup call

---
 pytorch_lightning/accelerator_backends/cpu_backend.py       | 2 +-
 pytorch_lightning/accelerator_backends/ddp_spawn_backend.py | 2 +-
 pytorch_lightning/accelerator_backends/dp_backend.py        | 2 +-
 pytorch_lightning/accelerator_backends/gpu_backend.py       | 2 +-
 pytorch_lightning/accelerator_backends/tpu_backend.py       | 2 +-
 pytorch_lightning/trainer/distrib_data_parallel.py          | 2 +-
 pytorch_lightning/trainer/distrib_parts.py                  | 2 +-
 pytorch_lightning/trainer/trainer.py                        | 4 +---
 8 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/accelerator_backends/cpu_backend.py b/pytorch_lightning/accelerator_backends/cpu_backend.py
index d6c2dacb1a01c..7760442a206c5 100644
--- a/pytorch_lightning/accelerator_backends/cpu_backend.py
+++ b/pytorch_lightning/accelerator_backends/cpu_backend.py
@@ -26,7 +26,7 @@ def setup(self, model):
             raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

         # call setup after the ddp process has connected
-        self.trainer.call_setup_hook()
+        self.trainer.call_setup_hook(model)

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
diff --git a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
index 6c00d15f164c1..122355856eaf1 100644
--- a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
@@ -106,7 +106,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         )

         # call setup after the ddp process has connected
-        self.trainer.call_setup_hook()
+        self.trainer.call_setup_hook(model)

         # on world_size=0 let everyone know training is starting
         if self.trainer.is_global_zero:
diff --git a/pytorch_lightning/accelerator_backends/dp_backend.py b/pytorch_lightning/accelerator_backends/dp_backend.py
index df922a000696c..efb683ff4eaa9 100644
--- a/pytorch_lightning/accelerator_backends/dp_backend.py
+++ b/pytorch_lightning/accelerator_backends/dp_backend.py
@@ -33,7 +33,7 @@ def __init__(self, trainer):

     def setup(self, model):
         # call setup after the ddp process has connected
-        self.trainer.call_setup_hook()
+        self.trainer.call_setup_hook(model)

         # put model on correct device
         model.cuda(self.trainer.root_gpu)
diff --git a/pytorch_lightning/accelerator_backends/gpu_backend.py b/pytorch_lightning/accelerator_backends/gpu_backend.py
index b649ceb864cac..7f15d3c25f410 100644
--- a/pytorch_lightning/accelerator_backends/gpu_backend.py
+++ b/pytorch_lightning/accelerator_backends/gpu_backend.py
@@ -31,7 +31,7 @@ def __init__(self, trainer):

     def setup(self, model):

         # call setup
-        self.trainer.call_setup_hook()
+        self.trainer.call_setup_hook(model)

         model.cuda(self.trainer.root_gpu)
diff --git a/pytorch_lightning/accelerator_backends/tpu_backend.py b/pytorch_lightning/accelerator_backends/tpu_backend.py
index 221dfa87aef52..2c0b172b9e211 100644
--- a/pytorch_lightning/accelerator_backends/tpu_backend.py
+++ b/pytorch_lightning/accelerator_backends/tpu_backend.py
@@ -103,7 +103,7 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         if not trainer:
             trainer = self.trainer

-        trainer.call_setup_hook()
+        trainer.call_setup_hook(model)

         # setup TPU training
         self.__setup_tpu_training(model, trainer)
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index ccaa97486d9fb..c8efc55b49e37 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -536,7 +536,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)

         # call setup after the ddp process has connected
-        self.call_setup_hook()
+        self.call_setup_hook(model)

         # on world_size=0 let everyone know training is starting
         if self.is_global_zero:
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 22f016d29dfa3..7d5a00523ef9e 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -184,7 +184,7 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):

     def horovod_train(self, model):
         # call setup after the ddp process has connected
-        self.call_setup_hook()
+        self.call_setup_hook(model)

         if torch.cuda.is_available() and self.on_gpu:
             # Horovod: pin GPU to local rank
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 58356e233be6a..59a33dad7e5dd 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1381,7 +1381,7 @@ def barrier(self, name):
             # wait for all processes to catch up
             torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')

-    def call_setup_hook(self):
+    def call_setup_hook(self, model):
         # call setup after the ddp process has connected
         stage_name = 'test' if self.testing else 'fit'
         if self.datamodule is not None:
@@ -1389,8 +1389,6 @@ def call_setup_hook(self):
             if not called:
                 self.datamodule.setup(stage_name)
         self.setup(stage_name)
-
-        model = self.get_model()
         model.setup(stage_name)
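With PATCH 21 the hook takes the model explicitly, so each backend passes the exact model instance it is about to configure instead of the trainer fetching it through get_model(). Taken together, the series ends up with roughly the flow sketched below; the _Sketch* classes are illustrative stand-ins, and the real Trainer additionally handles ranks, is_overridden checks, and dataloader attachment:

    class _SketchDataModule:
        has_setup_fit = False
        has_setup_test = False

        def prepare_data(self):
            print("datamodule.prepare_data()")

        def setup(self, stage):
            print(f"datamodule.setup({stage!r})")
            setattr(self, 'has_setup_' + stage, True)

    class _SketchModel:
        def prepare_data(self):
            print("model.prepare_data()")

        def setup(self, stage):
            print(f"model.setup({stage!r})")

    class _SketchTrainer:
        def __init__(self, datamodule=None, testing=False):
            self.datamodule = datamodule
            self.testing = testing
            self._is_data_prepared = False

        def setup(self, stage):
            print(f"trainer.setup({stage!r})")

        def can_prepare_data(self):
            return not self._is_data_prepared

        def call_setup_hook(self, model):
            stage_name = 'test' if self.testing else 'fit'
            if self.datamodule is not None:
                called = (self.datamodule.has_setup_test if self.testing
                          else self.datamodule.has_setup_fit)
                if not called:
                    self.datamodule.setup(stage_name)
            self.setup(stage_name)
            model.setup(stage_name)

        def fit(self, model):
            # prepare_data: datamodule first, then the model, at most once
            if self.can_prepare_data():
                if self.datamodule is not None:
                    self.datamodule.prepare_data()
                model.prepare_data()
                self._is_data_prepared = True
            # setup: datamodule (once per stage) -> trainer -> model
            self.call_setup_hook(model)

    _SketchTrainer(datamodule=_SketchDataModule()).fit(_SketchModel())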