From 76a68e72230c1dc423e1837ed3a9c2080e857afb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 14 Nov 2020 10:38:53 -0800 Subject: [PATCH 01/16] [data] Support teardown hook on DataModule --- pytorch_lightning/core/datamodule.py | 49 +++++++++++++++++++++++++++- pytorch_lightning/trainer/trainer.py | 17 +++++----- tests/base/boring_model.py | 33 ++++++++++++++++++- tests/core/test_datamodules.py | 1 + 4 files changed, 90 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index fe81d641c86d6..46d9b0d50af76 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -58,6 +58,9 @@ def track_data_hook_calls(fn): - When dm.setup('fit') is called, dm.has_setup_fit gets set to True - When dm.setup('test') is called, dm.has_setup_test gets set to True - When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True + - When dm.teardown('fit') is called, dm.has_teardown_fit gets set to True + - When dm.teardown('test') is called, dm.has_teardown_fit gets set to True + - When dm.teardown() is called without stage arg, both dm.has_teardown_fit and dm.has_teardown_test get set to True Args: fn (function): Function that will be tracked to see if it has been called. @@ -86,6 +89,21 @@ def wrapped_fn(*args, **kwargs): if stage == "test" or stage is None: obj._has_setup_test = True + # If calling teardown, we check the stage and assign stage-specific bool args + if fn.__name__ == "teardown": + + # Get stage either by grabbing from args or checking kwargs. + # If not provided, set call status of 'fit' and 'test' to True. + # We do this so __attach_datamodule in trainer.py doesn't mistakenly call teardown('test') on trainer.test() + stage = args[1] if len(args) > 1 else kwargs.get("stage", None) + + if stage == "fit" or stage is None: + obj._has_teardown_fit = True + + if stage == "test" or stage is None: + obj._has_teardown_test = True + + if fn.__name__ == "prepare_data": obj._has_prepared_data = True @@ -119,14 +137,18 @@ def val_dataloader(self): def test_dataloader(self): test_split = Dataset(...) return DataLoader(test_split) + def teardown(self): + # clean up after fit or test + # called on every process in DDP - A DataModule implements 5 key methods: + A DataModule implements 6 key methods: * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode). * **setup** (things to do on every accelerator in distributed mode). * **train_dataloader** the training dataloader. * **val_dataloader** the val dataloader(s). * **test_dataloader** the test dataloader(s). + * **teardown** (things to do on every accelerator in distributed mode after fit/test) This allows you to share a full dataset without explaining how to download, @@ -156,6 +178,8 @@ def __init__( self._has_prepared_data = False self._has_setup_fit = False self._has_setup_test = False + self._has_teardown_fit = False + self._has_teardown_test = False @property def train_transforms(self): @@ -239,6 +263,25 @@ def has_setup_test(self): """ return self._has_setup_test + + @property + def has_teardown_fit(self): + """Return bool letting you know if datamodule.teardown('fit') has been called or not. + + Returns: + bool: True if datamodule.teardown('fit') has been called. False by default. + """ + return self._has_teardown_fit + + @property + def has_teardown_test(self): + """Return bool letting you know if datamodule.teardown('test') has been called or not. 
+ + Returns: + bool: True if datamodule.teardown('test') has been called. False by default. + """ + return self._has_teardown_test + @abstractmethod def prepare_data(self, *args, **kwargs): pass @@ -259,6 +302,10 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]] def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: pass + @abstractmethod + def teardown(self, stage: Optional[str] = None): + pass + @abstractmethod def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: pass diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 46e4abbe584ae..b75ce9a748eeb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -477,6 +477,9 @@ def fit( # hook self.teardown('fit') + if self.datamodule is not None: + if not self.datamodule.has_teardown_fit: + self.datamodule.teardown('fit') if self.is_function_implemented('teardown'): model.teardown('fit') @@ -759,7 +762,13 @@ def test( else: results = self.__test_using_best_weights(ckpt_path, test_dataloaders) + # teardown self.teardown('test') + if self.datamodule is not None: + if not self.datamodule.has_teardown_test: + self.datamodule.teardown('test') + if self.is_function_implemented('teardown'): + model_ref = self.get_model() return results @@ -803,11 +812,6 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): self.testing = False del os.environ['PL_TESTING_MODE'] - # teardown - if self.is_function_implemented('teardown'): - model_ref = self.get_model() - model_ref.teardown('test') - return results def __test_given_model(self, model, test_dataloaders): @@ -823,9 +827,6 @@ def __test_given_model(self, model, test_dataloaders): results = self.fit(model) self.testing = False - # teardown - if self.is_function_implemented('teardown'): - model.teardown('test') return results diff --git a/tests/base/boring_model.py b/tests/base/boring_model.py index 6ceffe8562372..d00ec2f43792c 100644 --- a/tests/base/boring_model.py +++ b/tests/base/boring_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch -from pytorch_lightning import LightningModule +from pytorch_lightning import LightningDataModule, LightningModule from torch.utils.data import Dataset @@ -129,3 +129,34 @@ def val_dataloader(self): def test_dataloader(self): return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + +class BoringDataModule(LightningDataModule): + + def __init__(self): + """ + Testing PL DataModule + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestDM(BoringDataModule): + def train_dataloader(...): + # do your own thing + + or: + + model = TestDM() + model.setup = None + """ + super().__init__() + + def train_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 3e683025e8867..80c9dd94c4537 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -22,6 +22,7 @@ from pytorch_lightning import LightningDataModule, Trainer, seed_everything from tests.base import EvalModelTemplate +from tests.base.boring_model import BoringDataModule from tests.base.datasets import TrialMNIST from tests.base.datamodules import TrialMNISTDataModule from tests.base.develop_utils import reset_seed From 38e33f53f5878fa569e89f7be715d3c01d05f3e9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 14 Nov 2020 10:40:35 -0800 Subject: [PATCH 02/16] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b75ce9a748eeb..ff43a26bc1fb8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -769,6 +769,7 @@ def test( self.datamodule.teardown('test') if self.is_function_implemented('teardown'): model_ref = self.get_model() + model_ref.teardown('test') return results From 12a85b3f74d300fe673ca70090fc6e629ae6bd3a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 01:54:23 +0100 Subject: [PATCH 03/16] Pass {fit,validate,test,predict} to setup() --- pytorch_lightning/callbacks/base.py | 4 +- pytorch_lightning/core/datamodule.py | 58 ++++++++----- pytorch_lightning/core/hooks.py | 8 +- pytorch_lightning/trainer/callback_hook.py | 24 +++--- pytorch_lightning/trainer/model_hooks.py | 6 -- pytorch_lightning/trainer/states.py | 16 ++-- pytorch_lightning/trainer/trainer.py | 57 ++++++------- tests/callbacks/test_callbacks.py | 35 ++++---- tests/core/test_datamodules.py | 97 ++++++++++++---------- tests/helpers/boring_model.py | 4 +- 10 files changed, 168 insertions(+), 141 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index d53acf0f7030d..494d94cf446de 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -34,11 +34,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul pass def setup(self, trainer, pl_module: LightningModule, stage: str) -> None: - """Called when fit or test begins""" + """Called when fit, validate, test, predict, or tune begins""" pass def teardown(self, trainer, pl_module: LightningModule, stage: str) -> None: - """Called when fit or test ends""" + """Called when fit, validate, test, predict, or tune ends""" pass def on_init_start(self, trainer) -> None: 
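The bookkeeping behind the `_has_setup_*` / `_has_teardown_*` flags (introduced for the 'fit'/'test' stages in PATCH 01 and generalised to more stages in the datamodule.py diff that follows) boils down to a small decorator pattern. The sketch below is simplified illustration code, not the library source: the real decorator is attached through the datamodule's metaclass, also wraps `prepare_data` with `rank_zero_only`, and pulls `stage` out of *args/**kwargs, while this toy version only models the 'fit' and 'test' stages from PATCH 01:

    import functools


    def track_data_hook_calls(fn):
        """Record which stage a setup()/teardown() call was made with."""

        @functools.wraps(fn)
        def wrapped_fn(obj, stage=None, *args, **kwargs):
            # stage is None when the user calls dm.setup() / dm.teardown() with no
            # argument, in which case every tracked stage is marked as called.
            stages = ("fit", "test") if stage is None else (stage,)
            for s in stages:
                setattr(obj, f"_has_{fn.__name__}_{s}", True)
            return fn(obj, stage, *args, **kwargs)

        return wrapped_fn


    class TinyDataModule:
        """Stand-in for LightningDataModule, just enough to show the bookkeeping."""

        def __init__(self):
            self._has_setup_fit = self._has_setup_test = False
            self._has_teardown_fit = self._has_teardown_test = False

        @track_data_hook_calls
        def setup(self, stage=None):
            pass

        @track_data_hook_calls
        def teardown(self, stage=None):
            pass


    dm = TinyDataModule()
    dm.setup("fit")   # marks only the 'fit' stage
    dm.teardown()     # no stage given -> marks both 'fit' and 'test'
    assert dm._has_setup_fit and not dm._has_setup_test
    assert dm._has_teardown_fit and dm._has_teardown_test

These flags are what lets the trainer decide whether it still needs to call a given datamodule hook, as the trainer.py hunks in this series show.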
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 29b93abe3e6a1..31c05e3bcc4c4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -55,10 +55,10 @@ def __call__(cls, *args, **kwargs): def track_data_hook_calls(fn): """A decorator that checks if prepare_data/setup have been called. - - When dm.prepare_data() is called, dm.has_prepared_data gets set to True - - When dm.setup('fit') is called, dm.has_setup_fit gets set to True - - When dm.setup('test') is called, dm.has_setup_test gets set to True - - When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True + - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True + - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}`` + it's corresponding `dm_has_setup_{stage}` gets set to True Args: fn (function): Function that will be tracked to see if it has been called. @@ -77,15 +77,15 @@ def wrapped_fn(*args, **kwargs): if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit' and 'test' to True. + # If not provided, set call status of 'fit', 'validate', and 'test' to True. # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() stage = args[1] if len(args) > 1 else kwargs.get("stage", None) - if stage == "fit" or stage is None: - obj._has_setup_fit = True - - if stage == "test" or stage is None: - obj._has_setup_test = True + if stage is None: + for s in ("fit", "validate", "test"): + setattr(obj, f"_has_setup_{s}", True) + else: + setattr(obj, f"_has_setup_{stage}", True) if fn.__name__ == "prepare_data": obj._has_prepared_data = True @@ -156,7 +156,9 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False self._has_setup_fit = False + self._has_setup_validate = False self._has_setup_test = False + self._has_setup_predict = False @property def train_transforms(self): @@ -214,32 +216,50 @@ def size(self, dim=None) -> Union[Tuple, int]: return self.dims @property - def has_prepared_data(self): - """Return bool letting you know if datamodule.prepare_data() has been called or not. + def has_prepared_data(self) -> bool: + """Return bool letting you know if ``datamodule.prepare_data()`` has been called or not. Returns: - bool: True if datamodule.prepare_data() has been called. False by default. + bool: True if ``datamodule.prepare_data()`` has been called. False by default. """ return self._has_prepared_data @property - def has_setup_fit(self): - """Return bool letting you know if datamodule.setup('fit') has been called or not. + def has_setup_fit(self) -> bool: + """Return bool letting you know if ``datamodule.setup('fit')`` has been called or not. Returns: - bool: True if datamodule.setup('fit') has been called. False by default. + bool: True ``if datamodule.setup('fit')`` has been called. False by default. """ return self._has_setup_fit @property - def has_setup_test(self): - """Return bool letting you know if datamodule.setup('test') has been called or not. + def has_setup_validate(self) -> bool: + """Return bool letting you know if ``datamodule.setup('validate')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup('validate')`` has been called. False by default. 
+ """ + return self._has_setup_validate + + @property + def has_setup_test(self) -> bool: + """Return bool letting you know if ``datamodule.setup('test')`` has been called or not. Returns: - bool: True if datamodule.setup('test') has been called. False by default. + bool: True if ``datamodule.setup('test')`` has been called. False by default. """ return self._has_setup_test + @property + def has_setup_predict(self) -> bool: + """Return bool letting you know if ``datamodule.setup('predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup('predict')`` has been called. False by default. + """ + return self._has_setup_predict + @abstractmethod def prepare_data(self, *args, **kwargs): pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 604803365298c..a6567e3d52f0f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -27,12 +27,12 @@ class ModelHooks: def setup(self, stage: str) -> None: """ - Called at the beginning of fit and test. + Called at the beginning of fit (train + validate), validate, test, predict, or tune. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either 'fit' or 'test' + stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` Example:: @@ -55,10 +55,10 @@ def setup(stage): def teardown(self, stage: str) -> None: """ - Called at the end of fit and test. + Called at the end of fit (train + validate), validate, test, predict, or tune. Args: - stage: either 'fit' or 'test' + stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` """ def on_fit_start(self) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8f9fc3ad930b0..71433429f7c03 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -29,18 +29,18 @@ class TrainerCallbackHookMixin(ABC): callbacks: List[Callback] = [] lightning_module: LightningModule - def on_before_accelerator_backend_setup(self, model): - """Called in the beginning of fit and test""" + def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def setup(self, model, stage: str): - """Called in the beginning of fit and test""" + def setup(self, model: LightningModule, stage: str) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.setup(self, model, stage) - def teardown(self, stage: str): - """Called at the end of fit and test""" + def teardown(self, stage: str) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) @@ -124,15 +124,15 @@ def on_train_end(self): for callback in self.callbacks: callback.on_train_end(self, self.lightning_module) - def on_pretrain_routine_start(self, model): - """Called when the train begins.""" + def on_pretrain_routine_start(self) -> None: + """Called when the pre-train routine begins.""" for callback in self.callbacks: - callback.on_pretrain_routine_start(self, model) + callback.on_pretrain_routine_start(self, 
self.lightning_module) - def on_pretrain_routine_end(self, model): - """Called when the train ends.""" + def on_pretrain_routine_end(self) -> None: + """Called when the pre-train routine ends.""" for callback in self.callbacks: - callback.on_pretrain_routine_end(self, model) + callback.on_pretrain_routine_end(self, self.lightning_module) def on_batch_start(self): """Called when the training batch begins.""" diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 7e3d6cc78320c..e98ebf088a8dc 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -22,12 +22,6 @@ class TrainerModelHooksMixin(ABC): lightning_module: LightningModule - def is_function_implemented(self, f_name, model=None): - if model is None: - model = self.lightning_module - f_op = getattr(model, f_name, None) - return callable(f_op) - def has_arg(self, f_name, arg_name): model = self.lightning_module f_op = getattr(model, f_name, None) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index d0c2ded659f67..2688fb6754977 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -27,14 +27,14 @@ class TrainerState(LightningEnum): >>> TrainerState.FINISHED == 'finished' True """ - INITIALIZING = 'INITIALIZING' # trainer creation - FITTING = 'FITTING' # trainer.fit() - VALIDATING = 'VALIDATING' # trainer.validate() - TESTING = 'TESTING' # trainer.test() - PREDICTING = 'PREDICTING' # trainer.predict() - TUNING = 'TUNING' # trainer.tune() - FINISHED = 'FINISHED' - INTERRUPTED = 'INTERRUPTED' + INITIALIZING = 'initializing' # trainer creation + FITTING = 'fit' # trainer.fit() + VALIDATING = 'validate' # trainer.validate() + TESTING = 'test' # trainer.test() + PREDICTING = 'predict' # trainer.predict() + TUNING = 'tune' # trainer.tune() + FINISHED = 'finished' + INTERRUPTED = 'interrupted' @property def stopped(self) -> bool: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cc1964f07039b..7cd666b17ca7b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -443,7 +443,7 @@ def fit( # ---------------------------- self.call_setup_hook(model) self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator.setup(self, model) + self.accelerator.setup(self, model) # note: this sets up self.lightning_module self.setup_trainer(model) # ---------------------------- @@ -473,7 +473,8 @@ def fit( # TRAIN # ---------------------------- # hook - self.call_hook("on_fit_start") + if self.state == TrainerState.FITTING: + self.call_hook("on_fit_start") # plugin will setup fitting (e.g. 
ddp will launch child processes) self.pre_dispatch() @@ -488,12 +489,11 @@ def fit( # POST-Training CLEAN UP # ---------------------------- # hook - self.call_hook('on_fit_end') + if self.state == TrainerState.FITTING: + self.call_hook('on_fit_end') - # hook - self.teardown('fit') - if self.is_function_implemented('teardown'): - model.teardown('fit') + # teardown + self.call_teardown_hook(model) if self.state != TrainerState.INTERRUPTED: self.state = TrainerState.FINISHED @@ -541,9 +541,8 @@ def _pre_training_routine(self): # on pretrain routine start ref_model = self.lightning_module - self.on_pretrain_routine_start(ref_model) - if self.is_function_implemented("on_pretrain_routine_start"): - ref_model.on_pretrain_routine_start() + self.on_pretrain_routine_start() + ref_model.on_pretrain_routine_start() # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: @@ -556,9 +555,8 @@ def _pre_training_routine(self): self.checkpoint_connector.restore_weights() # on pretrain routine end - self.on_pretrain_routine_end(ref_model) - if self.is_function_implemented("on_pretrain_routine_end"): - ref_model.on_pretrain_routine_end() + self.on_pretrain_routine_end() + ref_model.on_pretrain_routine_end() def run_train(self) -> None: @@ -880,8 +878,6 @@ def test( self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) ) - self.teardown('test') - assert self.state.stopped self.testing = False @@ -929,10 +925,6 @@ def __evaluate_using_weights( # run test results = self.fit(model) - # teardown - if self.is_function_implemented('teardown', model=model): - model.teardown('test') - return results def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None): @@ -944,10 +936,6 @@ def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, # sets up testing so we short circuit to eval results = self.fit(model) - # teardown - if self.is_function_implemented('teardown', model=model): - model.teardown('test') - return results def predict( @@ -1035,17 +1023,26 @@ def tune( assert self.state.stopped self.tuning = False - def call_setup_hook(self, model): - # call setup after the ddp process has connected - stage_name = 'test' if self.evaluating else 'fit' + def call_setup_hook(self, model: LightningModule) -> None: + assert self.state.running, f"TrainerState: {self.state}" + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value if self.datamodule is not None: - called = getattr(self.datamodule, f'has_setup_{stage_name}') + called = getattr(self.datamodule, f'has_setup_{state}') if not called: - self.datamodule.setup(stage_name) + self.datamodule.setup(state) + + self.setup(model, state) + model.setup(state) + + def call_teardown_hook(self, model: LightningModule) -> None: + assert self.state.running, f"TrainerState: {self.state}" + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value - self.setup(model, stage_name) - model.setup(stage_name) + self.teardown(state) + model.teardown(state) def _reset_result_and_set_hook_fx_name(self, hook_name): # on_before_zero_grad is called within training_step diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 8a25ecc9f983b..2426348f770bf 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -19,29 +19,20 @@ @mock.patch("torch.save") # need to mock torch.save or we get 
pickle error -def test_trainer_callback_system(_, tmpdir): - """Test the callback system.""" +def test_trainer_callback_system_fit(_, tmpdir): + """Test the callback system for fit.""" model = BoringModel() - callback_mock = MagicMock() - - trainer_options = dict( + trainer = Trainer( default_root_dir=tmpdir, callbacks=[callback_mock], max_epochs=1, limit_val_batches=1, limit_train_batches=3, - limit_test_batches=2, progress_bar_refresh_rate=0, ) - # no call yet - callback_mock.assert_not_called() - - # fit model - trainer = Trainer(**trainer_options) - # check that only the to calls exists assert trainer.callbacks[0] == callback_mock assert callback_mock.method_calls == [ @@ -49,6 +40,7 @@ def test_trainer_callback_system(_, tmpdir): call.on_init_end(trainer), ] + # fit model trainer.fit(model) assert callback_mock.method_calls == [ @@ -104,8 +96,20 @@ def test_trainer_callback_system(_, tmpdir): call.teardown(trainer, model, 'fit'), ] - callback_mock.reset_mock() - trainer = Trainer(**trainer_options) + +def test_trainer_callback_system_test(tmpdir): + """Test the callback system for test.""" + + model = BoringModel() + callback_mock = MagicMock() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[callback_mock], + max_epochs=1, + limit_test_batches=2, + progress_bar_refresh_rate=0, + ) + trainer.test(model) assert callback_mock.method_calls == [ @@ -113,7 +117,6 @@ def test_trainer_callback_system(_, tmpdir): call.on_init_end(trainer), call.setup(trainer, model, 'test'), call.on_before_accelerator_backend_setup(trainer, model), - call.on_fit_start(trainer, model), call.on_test_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), @@ -123,8 +126,6 @@ def test_trainer_callback_system(_, tmpdir): call.on_test_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), - call.on_fit_end(trainer, model), - call.teardown(trainer, model, 'fit'), call.teardown(trainer, model, 'test'), ] diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 866bffcdd7441..e1b4301842ecd 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -17,6 +17,7 @@ from unittest import mock from unittest.mock import PropertyMock +import pytest import torch import torch.nn.functional as F @@ -108,13 +109,13 @@ def prepare_data(self, *args, **kwargs): dm.prepare_data() -def test_base_datamodule(tmpdir): +def test_helper_boringdatamodule(tmpdir): dm = BoringDataModule() dm.prepare_data() dm.setup() -def test_base_datamodule_with_verbose_setup(tmpdir): +def test_helper_boringdatamodule_with_verbose_setup(tmpdir): dm = BoringDataModule() dm.prepare_data() dm.setup('fit') @@ -123,55 +124,67 @@ def test_base_datamodule_with_verbose_setup(tmpdir): def test_data_hooks_called(tmpdir): dm = BoringDataModule() - assert dm.has_prepared_data is False - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert not dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict dm.prepare_data() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict dm.setup() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test 
is True + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict -def test_data_hooks_called_verbose(tmpdir): +@pytest.mark.parametrize("use_kwarg", (False, True)) +def test_data_hooks_called_verbose(tmpdir, use_kwarg): dm = BoringDataModule() - assert dm.has_prepared_data is False - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert not dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test dm.prepare_data() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is False - assert dm.has_setup_test is False - - dm.setup('fit') - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test is False - - dm.setup('test') - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test is True - - -def test_data_hooks_called_with_stage_kwarg(tmpdir): - dm = BoringDataModule() - dm.prepare_data() - assert dm.has_prepared_data is True - - dm.setup(stage='fit') - assert dm.has_setup_fit is True - assert dm.has_setup_test is False - - dm.setup(stage='test') - assert dm.has_setup_fit is True - assert dm.has_setup_test is True + assert dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='fit') if use_kwarg else dm.setup('fit') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert not dm.has_setup_validate + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='validate') if use_kwarg else dm.setup('validate') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='test') if use_kwarg else dm.setup('test') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='predict') if use_kwarg else dm.setup('predict') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert dm.has_setup_predict def test_dm_add_argparse_args(tmpdir): diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index ea26310a45315..6ef2518bbef11 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -151,9 +151,11 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None): if stage == "fit" or stage is None: self.random_train = Subset(self.random_full, indices=range(64)) - self.random_val = Subset(self.random_full, indices=range(64, 128)) self.dims = self.random_train[0].shape + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 128)) + if stage == "test" or stage is None: self.random_test = Subset(self.random_full, indices=range(128, 192)) self.dims = getattr(self, "dims", self.random_test[0].shape) From d49ccd1b9fb0afadaa28c4bead0e0cb7e5b1fc91 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 02:43:12 +0100 Subject: [PATCH 04/16] Fix doctest --- pytorch_lightning/trainer/states.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 2688fb6754977..33a2326c518d5 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -21,10 +21,10 @@ 
class TrainerState(LightningEnum): functions such as `trainer.fit()` and `trainer.test(). >>> # you can compare the type with a string - >>> TrainerState.FITTING == 'FITTING' + >>> TrainerState.FITTING == 'fit' True >>> # which is case insensitive - >>> TrainerState.FINISHED == 'finished' + >>> TrainerState.FINISHED == 'FINISHED' True """ INITIALIZING = 'initializing' # trainer creation From 23db13507878a60cb17844f6133c5b7adf9fa9ca Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:24:29 +0100 Subject: [PATCH 05/16] stage: Optional[str] = None --- pytorch_lightning/callbacks/base.py | 6 +++--- pytorch_lightning/core/hooks.py | 8 ++++---- pytorch_lightning/trainer/callback_hook.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 18 ++++++++++-------- tests/models/test_hooks.py | 16 ++++++---------- 5 files changed, 26 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 494d94cf446de..0ba1fd4ff7785 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,7 +17,7 @@ """ import abc -from typing import Any, Dict +from typing import Any, Dict, Optional from pytorch_lightning.core.lightning import LightningModule @@ -33,11 +33,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul """Called before accelerator is being setup""" pass - def setup(self, trainer, pl_module: LightningModule, stage: str) -> None: + def setup(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: """Called when fit, validate, test, predict, or tune begins""" pass - def teardown(self, trainer, pl_module: LightningModule, stage: str) -> None: + def teardown(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: """Called when fit, validate, test, predict, or tune ends""" pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index a6567e3d52f0f..9826f9d44ac2c 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -25,14 +25,14 @@ class ModelHooks: """Hooks to be used in LightningModule.""" - def setup(self, stage: str) -> None: + def setup(self, stage: Optional[str] = None) -> None: """ Called at the beginning of fit (train + validate), validate, test, predict, or tune. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` Example:: @@ -53,12 +53,12 @@ def setup(stage): """ - def teardown(self, stage: str) -> None: + def teardown(self, stage: Optional[str] = None) -> None: """ Called at the end of fit (train + validate), validate, test, predict, or tune. 
Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` """ def on_fit_start(self) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 71433429f7c03..f174cd725bd36 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Type +from typing import Any, Callable, Dict, List, Type, Optional from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule @@ -34,12 +34,12 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def setup(self, model: LightningModule, stage: str) -> None: + def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.setup(self, model, stage) - def teardown(self, stage: str) -> None: + def teardown(self, stage: Optional[str] = None) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7cd666b17ca7b..d58de7d803146 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1031,18 +1031,20 @@ def call_setup_hook(self, model: LightningModule) -> None: if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') if not called: - self.datamodule.setup(state) + self.datamodule.setup(stage=state) - self.setup(model, state) - model.setup(state) + self.setup(model, stage=state) + model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - assert self.state.running, f"TrainerState: {self.state}" - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value + if self.state.running: + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value + else: + state = None - self.teardown(state) - model.teardown(state) + self.teardown(stage=state) + model.teardown(stage=state) def _reset_result_and_set_hook_fx_name(self, hook_name): # on_before_zero_grad is called within training_step diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 1a7803800b384..7c53925bd7cc4 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -404,7 +404,7 @@ def on_test_end(self): self.called.append(inspect.currentframe().f_code.co_name) super().on_test_end() - def teardown(self, stage: str): + def teardown(self, stage=None): self.called.append(inspect.currentframe().f_code.co_name) super().teardown(stage) @@ -420,12 +420,12 @@ def teardown(self, stage: str): limit_train_batches=2, limit_test_batches=1, progress_bar_refresh_rate=0, + weights_summary=None, ) assert model.called == [] trainer.fit(model) - expected = [ 'on_fit_start', 'on_pretrain_routine_start', @@ -469,11 +469,10 @@ def teardown(self, stage: str): assert model.called == expected - model2 = HookedModel() - trainer.test(model2) + model = 
HookedModel() + trainer.test(model, verbose=False) expected = [ - 'on_fit_start', 'on_test_model_eval', 'on_test_start', 'on_test_epoch_start', @@ -483,9 +482,6 @@ def teardown(self, stage: str): 'on_epoch_end', 'on_test_end', 'on_test_model_train', - 'on_fit_end', - 'teardown', # for 'fit' - 'teardown', # for 'test' + 'teardown', ] - - assert model2.called == expected + assert model.called == expected From 84f5fdb8e6b6b4254a0f635281c5356756d00ba5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:26:48 +0100 Subject: [PATCH 06/16] Trailing whitespace --- tests/core/test_datamodules.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index e1b4301842ecd..ab51a87329e2f 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -131,17 +131,17 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_predict dm.prepare_data() - assert dm.has_prepared_data + assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict dm.setup() - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_test - assert dm.has_setup_validate + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate assert not dm.has_setup_predict @@ -153,21 +153,21 @@ def test_data_hooks_called_verbose(tmpdir, use_kwarg): assert not dm.has_setup_test dm.prepare_data() - assert dm.has_prepared_data + assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='fit') if use_kwarg else dm.setup('fit') assert dm.has_prepared_data - assert dm.has_setup_fit + assert dm.has_setup_fit assert not dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='validate') if use_kwarg else dm.setup('validate') - assert dm.has_prepared_data - assert dm.has_setup_fit + assert dm.has_prepared_data + assert dm.has_setup_fit assert dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict @@ -180,11 +180,11 @@ def test_data_hooks_called_verbose(tmpdir, use_kwarg): assert not dm.has_setup_predict dm.setup(stage='predict') if use_kwarg else dm.setup('predict') - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_validate - assert dm.has_setup_test - assert dm.has_setup_predict + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert dm.has_setup_predict def test_dm_add_argparse_args(tmpdir): From 188b9feae8b386114d792113758afb51fa9c5931 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:50:52 +0100 Subject: [PATCH 07/16] Update docs and CHANGELOG --- CHANGELOG.md | 3 +++ docs/source/extensions/datamodules.rst | 12 ++++++------ docs/source/starter/introduction_guide.rst | 8 ++++---- docs/source/starter/new-project.rst | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f8f7a08b089b..f6ef0d56b3792 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Changed `setup()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + ### Deprecated diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index a6c083dc61fcf..85134fda06fa2 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -80,7 +80,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa self.data_dir = data_dir self.batch_size = batch_size - def setup(self, stage=None): + def setup(self, stage: Optional[str] = None): self.mnist_test = MNIST(self.data_dir, train=False) mnist_full = MNIST(self.data_dir, train=True) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) @@ -138,7 +138,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) - def setup(self, stage=None): + def setup(self, stage: Optional[str] = None): # Assign train/val datasets for use in dataloaders if stage == 'fit' or stage is None: @@ -382,12 +382,12 @@ still ensures the method runs on the correct devices) dm = MNISTDataModule() dm.prepare_data() - dm.setup('fit') + dm.setup(stage='fit') model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab) trainer.fit(model, dm) - dm.setup('test') + dm.setup(stage='test') trainer.test(datamodule=dm) ---------------- @@ -403,7 +403,7 @@ You can of course use DataModules in plain PyTorch code as well. dm.prepare_data() # splits/transforms - dm.setup('fit') + dm.setup(stage='fit') # use data for batch in dm.train_dataloader(): @@ -412,7 +412,7 @@ You can of course use DataModules in plain PyTorch code as well. ... # lazy load test data - dm.setup('test') + dm.setup(stage='test') for batch in dm.test_dataloader(): ... diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index 2ee31304299e0..c65894367a39e 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -240,7 +240,7 @@ In this case, it's better to group the full definition of a dataset into a `Data tokenize() build_vocab() - def setup(self): + def setup(self, stage: Optional[str] = None): # called on every GPU vocab = load_vocab() self.vocab_size = len(vocab) @@ -310,8 +310,8 @@ An alternative to using a DataModule is to defer initialization of the models mo download_data() tokenize() - def setup(self, step): - # step is either 'fit' or 'test' 90% of the time not relevant + def setup(self, stage: Optional[str] = None): + # step is either 'fit', 'validate', 'test', or 'predict'. 
90% of the time not relevant data = load_data() num_classes = data.classes self.l1 = nn.Linear(..., num_classes) @@ -598,7 +598,7 @@ In this method we do all the preparation we need to do once (instead of on every MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) - def setup(self, stage): + def setup(self, stage: Optional[str] = None): # transform transform=transforms.Compose([transforms.ToTensor()]) mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform) diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index 0f1362616a9b1..23f91914063d9 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -651,7 +651,7 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning. MNIST(os.getcwd(), train=False, download=True) # OPTIONAL, called for every GPU/machine (assigning state is OK) - def setup(self, stage): + def setup(self, stage: Optional[str] = None): # transforms transform=transforms.Compose([ transforms.ToTensor(), From 37473f0c549590c5c342b53c564af7affb9fb05b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:52:14 +0100 Subject: [PATCH 08/16] Mention teardown --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f6ef0d56b3792..327f923a79ff1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) -- Changed `setup()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) +- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) ### Deprecated From 0a30abf931ec5a5f1127bdf98514df6f489cb735 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:59:49 +0100 Subject: [PATCH 09/16] Self-review --- pytorch_lightning/core/datamodule.py | 20 ++++++++++---------- pytorch_lightning/trainer/callback_hook.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 31c05e3bcc4c4..1b6852c071fe1 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -57,8 +57,8 @@ def track_data_hook_calls(fn): - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}`` - it's corresponding `dm_has_setup_{stage}` gets set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. + Its corresponding `dm_has_setup_{stage}` attribute gets set to True Args: fn (function): Function that will be tracked to see if it has been called. @@ -226,37 +226,37 @@ def has_prepared_data(self) -> bool: @property def has_setup_fit(self) -> bool: - """Return bool letting you know if ``datamodule.setup('fit')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='fit')`` has been called or not. 
Returns: - bool: True ``if datamodule.setup('fit')`` has been called. False by default. + bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default. """ return self._has_setup_fit @property def has_setup_validate(self) -> bool: - """Return bool letting you know if ``datamodule.setup('validate')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='validate')`` has been called or not. Returns: - bool: True if ``datamodule.setup('validate')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default. """ return self._has_setup_validate @property def has_setup_test(self) -> bool: - """Return bool letting you know if ``datamodule.setup('test')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='test')`` has been called or not. Returns: - bool: True if ``datamodule.setup('test')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='test')`` has been called. False by default. """ return self._has_setup_test @property def has_setup_predict(self) -> bool: - """Return bool letting you know if ``datamodule.setup('predict')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='predict')`` has been called or not. Returns: - bool: True if ``datamodule.setup('predict')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default. """ return self._has_setup_predict diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index f174cd725bd36..5aa9f1a44276b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -40,7 +40,7 @@ def setup(self, model: LightningModule, stage: Optional[str]) -> None: callback.setup(self, model, stage) def teardown(self, stage: Optional[str] = None) -> None: - """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + """Called at the end of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) From 0e9d69c35356824d5cc1b8e986c850ad71de50af Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 14:39:30 +0100 Subject: [PATCH 10/16] Address Borda's comments --- docs/source/conf.py | 1 + pytorch_lightning/trainer/model_hooks.py | 10 +++++++++- pytorch_lightning/trainer/trainer.py | 3 +-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 813d5ee978821..ccf824bb37d9b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -371,6 +371,7 @@ def package_list_from_file(file): doctest_global_setup = """ import importlib import os +from typing import Optional import torch from torch import nn import pytorch_lightning as pl diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index e98ebf088a8dc..b924675d8505c 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -14,6 +14,7 @@ import inspect from abc import ABC +from typing import Optional from pytorch_lightning.core.lightning import LightningModule @@ -22,7 +23,14 @@ class TrainerModelHooksMixin(ABC): lightning_module: LightningModule - def has_arg(self, f_name, arg_name): + def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) 
-> bool: + # note: currently unused - kept as it is public + if model is None: + model = self.lightning_module + f_op = getattr(model, f_name, None) + return callable(f_op) + + def has_arg(self, f_name: str, arg_name: str) -> bool: model = self.lightning_module f_op = getattr(model, f_name, None) return arg_name in inspect.signature(f_op).parameters diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d58de7d803146..45fc40731b545 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1025,8 +1025,8 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1039,7 +1039,6 @@ def call_setup_hook(self, model: LightningModule) -> None: def call_teardown_hook(self, model: LightningModule) -> None: if self.state.running: state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value else: state = None From 60a479e6de9170be356c11cc85f64006ea020681 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 18:11:33 +0100 Subject: [PATCH 11/16] Update CHANGELOG --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 327f923a79ff1..b787b35dbaace 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -101,6 +101,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) +- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + +- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + - Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260)) From d72b3fc1da8d2b706eacd2dcccc805230a86012f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 03:04:25 +0100 Subject: [PATCH 12/16] Add DataModule.teardown --- docs/source/extensions/datamodules.rst | 13 +- pytorch_lightning/core/datamodule.py | 66 +++++-- pytorch_lightning/core/hooks.py | 72 ++++---- pytorch_lightning/trainer/trainer.py | 5 + tests/core/test_datamodules.py | 61 ++++++- tests/models/test_hooks.py | 231 ++++++++++++++++++------- 6 files changed, 324 insertions(+), 124 deletions(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 85134fda06fa2..0cb40e836b639 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -94,6 +94,10 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=self.batch_size) + def teardown(self, stage: Optional[str] = None): + # Used to clean-up when the run is finished + ... + But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can let Lightning handle those details for you while making this dataset reusable so you can share with colleagues or use in different projects. 
@@ -243,7 +247,10 @@ There are also data operations you might want to perform on every GPU. Use setup self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) -.. warning:: `setup` is called from every process. Setting state here is okay. +.. warning:: ``setup`` is called from every process. Setting state here is okay. + + +.. note:: ``teardown`` can be used to clean up the state. It is also called from every process train_dataloader @@ -411,10 +418,14 @@ You can of course use DataModules in plain PyTorch code as well. for batch in dm.val_dataloader(): ... + dm.teardown(stage='fit') + # lazy load test data dm.setup(stage='test') for batch in dm.test_dataloader(): ... + dm.teardown(stage='test') + But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure. diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 1b6852c071fe1..86e8654476fb0 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -45,6 +45,8 @@ def __call__(cls, *args, **kwargs): cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) # Track setup calls cls.setup = track_data_hook_calls(cls.setup) + # Track teardown calls + cls.teardown = track_data_hook_calls(cls.teardown) # Get instance of LightningDataModule by mocking its __init__ via __call__ obj = type.__call__(cls, *args, **kwargs) @@ -53,12 +55,13 @@ def __call__(cls, *args, **kwargs): def track_data_hook_calls(fn): - """A decorator that checks if prepare_data/setup have been called. + """A decorator that checks if prepare_data/setup/teardown has been called. - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. Its corresponding `dm_has_setup_{stage}` attribute gets set to True + - ``dm.teardown()`` and ``dm.teardown(stage)`` acts exactly like ``dm.setup`` Args: fn (function): Function that will be tracked to see if it has been called. @@ -72,9 +75,10 @@ def wrapped_fn(*args, **kwargs): # The object instance from which setup or prepare_data was called obj = args[0] + name = fn.__name__ # If calling setup, we check the stage and assign stage-specific bool args - if fn.__name__ == "setup": + if name in ("setup", "teardown"): # Get stage either by grabbing from args or checking kwargs. # If not provided, set call status of 'fit', 'validate', and 'test' to True. @@ -83,11 +87,11 @@ def wrapped_fn(*args, **kwargs): if stage is None: for s in ("fit", "validate", "test"): - setattr(obj, f"_has_setup_{s}", True) + setattr(obj, f"_has_{name}_{s}", True) else: - setattr(obj, f"_has_setup_{stage}", True) + setattr(obj, f"_has_{name}_{stage}", True) - if fn.__name__ == "prepare_data": + elif name == "prepare_data": obj._has_prepared_data = True return fn(*args, **kwargs) @@ -120,14 +124,18 @@ def val_dataloader(self): def test_dataloader(self): test_split = Dataset(...) return DataLoader(test_split) + def teardown(self): + # clean up after fit or test + # called on every process in DDP - A DataModule implements 5 key methods: + A DataModule implements 6 key methods: * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode). * **setup** (things to do on every accelerator in distributed mode). * **train_dataloader** the training dataloader. * **val_dataloader** the val dataloader(s). 
* **test_dataloader** the test dataloader(s). + * **teardown** (things to do on every accelerator in distributed mode when finished) This allows you to share a full dataset without explaining how to download, @@ -155,11 +163,17 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False + self._has_setup_fit = False self._has_setup_validate = False self._has_setup_test = False self._has_setup_predict = False + self._has_teardown_fit = False + self._has_teardown_validate = False + self._has_teardown_test = False + self._has_teardown_predict = False + @property def train_transforms(self): """ @@ -260,13 +274,41 @@ def has_setup_predict(self) -> bool: """ return self._has_setup_predict - @abstractmethod - def prepare_data(self, *args, **kwargs): - pass + @property + def has_teardown_fit(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not. - @abstractmethod - def setup(self, stage: Optional[str] = None): - pass + Returns: + bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. + """ + return self._has_teardown_fit + + @property + def has_teardown_validate(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. + """ + return self._has_teardown_validate + + @property + def has_teardown_test(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. + """ + return self._has_teardown_test + + @property + def has_teardown_predict(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. + """ + return self._has_teardown_predict @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9826f9d44ac2c..40788324ade73 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -25,42 +25,6 @@ class ModelHooks: """Hooks to be used in LightningModule.""" - def setup(self, stage: Optional[str] = None) -> None: - """ - Called at the beginning of fit (train + validate), validate, test, predict, or tune. - This is a good hook when you need to build models dynamically or adjust something about them. - This hook is called on every process when using DDP. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - - Example:: - - class LitModel(...): - def __init__(self): - self.l1 = None - - def prepare_data(self): - download_data() - tokenize() - - # don't do this - self.something = else - - def setup(stage): - data = Load_data(...) - self.l1 = nn.Linear(28, data.num_classes) - - """ - - def teardown(self, stage: Optional[str] = None) -> None: - """ - Called at the end of fit (train + validate), validate, test, predict, or tune. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - """ - def on_fit_start(self) -> None: """ Called at the very beginning of fit. 
@@ -383,6 +347,42 @@ def prepare_data(self): model.test_dataloader() """ + def setup(self, stage: Optional[str] = None) -> None: + """ + Called at the beginning of fit (train + validate), validate, test, predict, or tune. + This is a good hook when you need to build models dynamically or adjust something about them. + This hook is called on every process when using DDP. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + + Example:: + + class LitModel(...): + def __init__(self): + self.l1 = None + + def prepare_data(self): + download_data() + tokenize() + + # don't do this + self.something = else + + def setup(stage): + data = Load_data(...) + self.l1 = nn.Linear(28, data.num_classes) + + """ + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Called at the end of fit (train + validate), validate, test, predict, or tune. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + """ + def train_dataloader(self) -> Any: """ Implement one or more PyTorch DataLoaders for training. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 45fc40731b545..970fcb6a2fd80 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1042,6 +1042,11 @@ def call_teardown_hook(self, model: LightningModule) -> None: else: state = None + if self.datamodule is not None: + called = getattr(self.datamodule, f'has_teardown_{state}') + if not called: + self.datamodule.teardown(stage=state) + self.teardown(stage=state) model.teardown(stage=state) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index ab51a87329e2f..9e79cc5097e43 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -129,6 +129,10 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.prepare_data() assert dm.has_prepared_data @@ -136,6 +140,10 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.setup() assert dm.has_prepared_data @@ -143,49 +151,84 @@ def test_data_hooks_called(tmpdir): assert dm.has_setup_test assert dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict + + dm.teardown() + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict + assert dm.has_teardown_fit + assert dm.has_teardown_test + assert dm.has_teardown_validate + assert not dm.has_teardown_predict @pytest.mark.parametrize("use_kwarg", (False, True)) def test_data_hooks_called_verbose(tmpdir, use_kwarg): dm = BoringDataModule() - assert not dm.has_prepared_data - assert not dm.has_setup_fit - assert not dm.has_setup_test - dm.prepare_data() - assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test + assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + 
assert not dm.has_teardown_predict dm.setup(stage='fit') if use_kwarg else dm.setup('fit') - assert dm.has_prepared_data assert dm.has_setup_fit assert not dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='validate') if use_kwarg else dm.setup('validate') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='test') if use_kwarg else dm.setup('test') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='predict') if use_kwarg else dm.setup('predict') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert dm.has_setup_predict + dm.teardown(stage='fit') if use_kwarg else dm.teardown('fit') + assert dm.has_teardown_fit + assert not dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='validate') if use_kwarg else dm.teardown('validate') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='test') if use_kwarg else dm.teardown('test') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='predict') if use_kwarg else dm.teardown('predict') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert dm.has_teardown_predict + def test_dm_add_argparse_args(tmpdir): parser = ArgumentParser() diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 7c53925bd7cc4..b6400ec5063f2 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
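Two behaviours from the hunks above are worth restating outside the diff. First, ``Trainer.call_teardown_hook`` only invokes the datamodule hook when the user has not already called it for that stage; a condensed restatement of that guard, where the helper name is invented for illustration::

    def _teardown_datamodule_if_needed(trainer, stage):
        # only call the hook if the user has not already torn this stage down manually
        dm = trainer.datamodule
        if dm is not None and not getattr(dm, f"has_teardown_{stage}"):
            dm.teardown(stage=stage)

Second, per ``test_data_hooks_called`` above, a bare ``dm.teardown()`` marks ``fit``, ``validate`` and ``test`` but leaves ``predict`` untouched; the snippet below assumes the ``BoringDataModule`` helper from the tests::

    from tests.helpers import BoringDataModule

    dm = BoringDataModule()
    dm.teardown()  # no stage given: fit, validate and test are marked, predict is not
    assert dm.has_teardown_fit and dm.has_teardown_validate and dm.has_teardown_test
    assert not dm.has_teardown_predict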
-import inspect from unittest import mock from unittest.mock import PropertyMock @@ -20,7 +19,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.trainer.states import TrainerState -from tests.helpers import BoringModel, RandomDataset +from tests.helpers import BoringModel, RandomDataset, BoringDataModule from tests.helpers.runif import RunIf @@ -260,7 +259,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): def test_trainer_model_hook_system(tmpdir): - """Test the hooks system.""" + """Test the LightningModule hook system.""" class HookedModel(BoringModel): @@ -269,149 +268,151 @@ def __init__(self): self.called = [] def on_after_backward(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_after_backward") super().on_after_backward() - def on_before_zero_grad(self, optimizer): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_before_zero_grad(optimizer) + def on_before_zero_grad(self, *args, **kwargs): + self.called.append("on_before_zero_grad") + super().on_before_zero_grad(*args, **kwargs) def on_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_epoch_start") super().on_epoch_start() def on_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_epoch_end") super().on_epoch_end() def on_fit_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_fit_start") super().on_fit_start() def on_fit_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_fit_end") super().on_fit_end() - def on_hpc_load(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_hpc_load(checkpoint) + def on_hpc_load(self, *args, **kwargs): + self.called.append("on_hpc_load") + super().on_hpc_load(*args, **kwargs) - def on_hpc_save(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_hpc_save(checkpoint) + def on_hpc_save(self, *args, **kwargs): + self.called.append("on_hpc_save") + super().on_hpc_save(*args, **kwargs) - def on_load_checkpoint(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_load_checkpoint(checkpoint) + def on_load_checkpoint(self, *args, **kwargs): + self.called.append("on_load_checkpoint") + super().on_load_checkpoint(*args, **kwargs) - def on_save_checkpoint(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_save_checkpoint(checkpoint) + def on_save_checkpoint(self, *args, **kwargs): + self.called.append("on_save_checkpoint") + super().on_save_checkpoint(*args, **kwargs) def on_pretrain_routine_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_pretrain_routine_start") super().on_pretrain_routine_start() def on_pretrain_routine_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_pretrain_routine_end") super().on_pretrain_routine_end() def on_train_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_start") super().on_train_start() def on_train_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_end") super().on_train_end() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - 
super().on_train_batch_start(batch, batch_idx, dataloader_idx) + def on_train_batch_start(self, *args, **kwargs): + self.called.append("on_train_batch_start") + super().on_train_batch_start(*args, **kwargs) - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_train_batch_end(self, *args, **kwargs): + self.called.append("on_train_batch_end") + super().on_train_batch_end(*args, **kwargs) def on_train_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_epoch_start") super().on_train_epoch_start() def on_train_epoch_end(self, outputs): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_epoch_end") super().on_train_epoch_end(outputs) def on_validation_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_start") super().on_validation_start() def on_validation_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_end") super().on_validation_end() - def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_batch_start(batch, batch_idx, dataloader_idx) + def on_validation_batch_start(self, *args, **kwargs): + self.called.append("on_validation_batch_start") + super().on_validation_batch_start(*args, **kwargs) - def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_validation_batch_end(self, *args, **kwargs): + self.called.append("on_validation_batch_end") + super().on_validation_batch_end(*args, **kwargs) def on_validation_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_epoch_start") super().on_validation_epoch_start() def on_validation_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_epoch_end") super().on_validation_epoch_end() def on_test_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_start") super().on_test_start() - def on_test_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_batch_start(batch, batch_idx, dataloader_idx) + def on_test_batch_start(self, *args, **kwargs): + self.called.append("on_test_batch_start") + super().on_test_batch_start(*args, **kwargs) - def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_test_batch_end(self, *args, **kwargs): + self.called.append("on_test_batch_end") + super().on_test_batch_end(*args, **kwargs) def on_test_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_epoch_start") super().on_test_epoch_start() def on_test_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_epoch_end") super().on_test_epoch_end() def on_validation_model_eval(self): - self.called.append(inspect.currentframe().f_code.co_name) + 
self.called.append("on_validation_model_eval") super().on_validation_model_eval() def on_validation_model_train(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_model_train") super().on_validation_model_train() def on_test_model_eval(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_model_eval") super().on_test_model_eval() def on_test_model_train(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_model_train") super().on_test_model_train() def on_test_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_end") super().on_test_end() + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + def teardown(self, stage=None): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append(f"teardown_{stage}") super().teardown(stage) model = HookedModel() - assert model.called == [] - # fit model trainer = Trainer( default_root_dir=tmpdir, @@ -427,6 +428,7 @@ def teardown(self, stage=None): trainer.fit(model) expected = [ + 'setup_fit', 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -464,15 +466,15 @@ def teardown(self, stage=None): 'on_validation_model_train', 'on_train_end', 'on_fit_end', - 'teardown', + 'teardown_fit', ] - assert model.called == expected model = HookedModel() - trainer.test(model, verbose=False) + expected = [ + 'setup_test', 'on_test_model_eval', 'on_test_start', 'on_test_epoch_start', @@ -482,6 +484,103 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_test_end', 'on_test_model_train', - 'teardown', + 'teardown_test', ] assert model.called == expected + + +def test_trainer_datamodule_hook_system(tmpdir): + """Test the LightningDataModule hook system.""" + + class HookedDataModule(BoringDataModule): + def __init__(self): + super().__init__() + self.called = [] + + def prepare_data(self): + self.called.append("prepare_data") + super().prepare_data() + + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + + def teardown(self, stage=None): + self.called.append(f"teardown_{stage}") + super().teardown(stage=stage) + + def train_dataloader(self): + self.called.append("train_dataloader") + return super().train_dataloader() + + def test_dataloader(self): + self.called.append("test_dataloader") + return super().test_dataloader() + + def val_dataloader(self): + self.called.append("val_dataloader") + return super().val_dataloader() + + def predict_dataloader(self): + self.called.append("predict_dataloader") + + def transfer_batch_to_device(self, *args, **kwargs): + self.called.append("transfer_batch_to_device") + return super().transfer_batch_to_device(*args, **kwargs) + + def on_before_batch_transfer(self, *args, **kwargs): + self.called.append("on_before_batch_transfer") + return super().on_before_batch_transfer(*args, **kwargs) + + def on_after_batch_transfer(self, *args, **kwargs): + self.called.append("on_after_batch_transfer") + return super().on_after_batch_transfer(*args, **kwargs) + + model = BoringModel() + dm = HookedDataModule() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + trainer.fit(model, datamodule=dm) + + expected = [ + 'prepare_data', + 'setup_fit', + 'val_dataloader', + 
'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'train_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_fit' + ] + assert dm.called == expected + + dm = HookedDataModule() + trainer.test(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_test', + 'test_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_test' + ] + assert dm.called == expected From c2c116a3775e55d3c5646ad1981d86c31e7e09ee Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 03:13:13 +0100 Subject: [PATCH 13/16] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b787b35dbaace..e06d6c48dd9d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) +- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673)) + + - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) From 7ad0688582f79e831bc4162bfcf756d4858a35eb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 03:23:25 +0100 Subject: [PATCH 14/16] Fix docs --- docs/source/common/lightning_module.rst | 4 ++-- docs/source/extensions/datamodules.rst | 2 +- pytorch_lightning/core/datamodule.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index c02f23ac60d09..cd544be3e42d0 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1248,7 +1248,7 @@ prepare_data setup ~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup +.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup :noindex: tbptt_split_batch @@ -1260,7 +1260,7 @@ tbptt_split_batch teardown ~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown +.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown :noindex: train_dataloader diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 0cb40e836b639..881febe21316d 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -418,7 +418,7 @@ You can of course use DataModules in plain PyTorch code as well. for batch in dm.val_dataloader(): ... - dm.teardown(stage='fit') + dm.teardown(stage='fit') # lazy load test data dm.setup(stage='test') diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 86e8654476fb0..8be4974bf1a29 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -61,7 +61,7 @@ def track_data_hook_calls(fn): - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. 
Its corresponding `dm_has_setup_{stage}` attribute gets set to True - - ``dm.teardown()`` and ``dm.teardown(stage)`` acts exactly like ``dm.setup`` + - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` Args: fn (function): Function that will be tracked to see if it has been called. From 61dff83318b37d5a73aad523df02c0bc9b0b5194 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 15:48:40 +0100 Subject: [PATCH 15/16] flake8 --- pytorch_lightning/core/datamodule.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 8be4974bf1a29..4514b140c46ad 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -15,7 +15,6 @@ import functools import inspect -from abc import abstractmethod from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union From 3a8d8d1f2753877de00d8359ea2d583121f71be3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 11 Mar 2021 04:04:38 +0100 Subject: [PATCH 16/16] .validate() --- tests/models/test_hooks.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index cd1ebac11ddc9..c81e7eb323e6d 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -474,6 +474,7 @@ def teardown(self, stage=None): trainer.validate(model, verbose=False) expected = [ + 'setup_validate', 'on_validation_model_eval', 'on_validation_start', 'on_validation_epoch_start', @@ -483,7 +484,7 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_validation_end', 'on_validation_model_train', - 'teardown', + 'teardown_validate', ] assert model.called == expected @@ -564,6 +565,7 @@ def on_after_batch_transfer(self, *args, **kwargs): limit_test_batches=1, progress_bar_refresh_rate=0, weights_summary=None, + reload_dataloaders_every_epoch=True, ) trainer.fit(model, datamodule=dm) @@ -581,6 +583,7 @@ def on_after_batch_transfer(self, *args, **kwargs): 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', + 'val_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', @@ -588,6 +591,20 @@ def on_after_batch_transfer(self, *args, **kwargs): ] assert dm.called == expected + dm = HookedDataModule() + trainer.validate(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_validate', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_validate' + ] + assert dm.called == expected + dm = HookedDataModule() trainer.test(model, datamodule=dm, verbose=False)
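For a user-level view of what the final commit exercises, the sketch below runs the new validate path end to end. It assumes the ``BoringModel``/``BoringDataModule`` helpers used in the tests above, mirrors the Trainer flags from ``test_trainer_datamodule_hook_system`` (``default_root_dir`` omitted for brevity), and restates, rather than extends, that test's expectations::

    from pytorch_lightning import Trainer
    from tests.helpers import BoringDataModule, BoringModel

    model = BoringModel()
    dm = BoringDataModule()
    trainer = Trainer(limit_val_batches=1, progress_bar_refresh_rate=0, weights_summary=None)
    trainer.validate(model, datamodule=dm, verbose=False)

    # the trainer drove the whole validate-stage lifecycle on the datamodule:
    # prepare_data -> setup('validate') -> val_dataloader -> ... -> teardown('validate')
    assert dm.has_setup_validate
    assert dm.has_teardown_validate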