From 83934cfc38ae23293dc0d039e84997caa06d5d5b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 18:09:06 -0400 Subject: [PATCH 1/9] allow regression metrics to import --- pytorch_lightning/trainer/trainer.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f421ff9c3b3a5..915afbbbcecad 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -843,6 +843,10 @@ def fit( if self.is_function_implemented('on_fit_start'): model.on_fit_start() + self.setup('fit') + if self.is_function_implemented('setup'): + model.setup('fit') + # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 if self.can_prepare_data(): @@ -945,6 +949,10 @@ def fit( if self.is_function_implemented('on_fit_end'): model.on_fit_end() + self.teardown('fit') + if self.is_function_implemented('teardown'): + model.teardown('fit') + # return 1 when finished # used for testing or when we need to know that training succeeded return 1 @@ -1128,6 +1136,11 @@ def test( trainer = Trainer() trainer.test(model, test_dataloaders=test) """ + + self.setup('test') + if self.is_function_implemented('setup'): + model.setup('test') + if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0: raise MisconfigurationException( 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.') @@ -1167,6 +1180,10 @@ def test( self.testing = False + self.teardown('test') + if self.is_function_implemented('teardown'): + model.teardown('test') + def check_model_configuration(self, model: LightningModule): r""" Checks that the model is configured correctly before training or testing is started. From 26645d91fe94635e130a3303b7686dde27e82615 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 18:13:16 -0400 Subject: [PATCH 2/9] allow regression metrics to import --- pytorch_lightning/core/hooks.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 7afce9f37b316..672f3c09ce4a3 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -17,6 +17,22 @@ class ModelHooks(Module): + def setup(self, step: str): + """ + Called at the beginning of fit and test. + + Args: + step: either 'fit' or 'test' + """ + + def teardown(self, step: str): + """ + Called at the end of fit and test. + + Args: + step: either 'fit' or 'test' + """ + def on_fit_start(self): """ Called at the very beginning of fit. From 26d357925ff1e7461683170e2f988f505c1ea3e1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 18:17:00 -0400 Subject: [PATCH 3/9] allow regression metrics to import --- pytorch_lightning/callbacks/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 63a1c148cb773..ec59314888fc3 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -14,6 +14,14 @@ class Callback(abc.ABC): Abstract base class used to build new callbacks. 
""" + def setup(self, trainer, step: str): + """Called when fit or test begins""" + pass + + def teardown(self, trainer, step: str): + """Called when fit or test ends""" + pass + def on_init_start(self, trainer): """Called when the trainer initialization begins, model has not yet been set.""" pass From 17de8082ae879c08cc3f9f28768b66b86402c03d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 18:23:26 -0400 Subject: [PATCH 4/9] allow regression metrics to import --- tests/callbacks/test_callbacks.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index a4e844b98cfd8..16ea2038c1a26 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -46,6 +46,8 @@ def _check_args(trainer, pl_module): class TestCallback(Callback): def __init__(self): super().__init__() + self.setup_called = False + self.teardown_called = False self.on_init_start_called = False self.on_init_end_called = False self.on_fit_start_called = False @@ -67,6 +69,14 @@ def __init__(self): self.on_test_start_called = False self.on_test_end_called = False + def setup(self, trainer, step: str): + assert isinstance(trainer, Trainer) + self.setup_called = True + + def teardown(self, trainer, step: str): + assert isinstance(trainer, Trainer) + self.teardown_called = True + def on_init_start(self, trainer): assert isinstance(trainer, Trainer) self.on_init_start_called = True @@ -157,6 +167,8 @@ def on_test_end(self, trainer, pl_module): progress_bar_refresh_rate=0, ) + assert not test_callback.setup_called + assert not test_callback.teardown_called assert not test_callback.on_init_start_called assert not test_callback.on_init_end_called assert not test_callback.on_fit_start_called @@ -184,6 +196,8 @@ def on_test_end(self, trainer, pl_module): assert trainer.callbacks[0] == test_callback assert test_callback.on_init_start_called assert test_callback.on_init_end_called + assert not test_callback.setup_called + assert not test_callback.teardown_called assert not test_callback.on_fit_start_called assert not test_callback.on_fit_end_called assert not test_callback.on_sanity_check_start_called @@ -205,6 +219,8 @@ def on_test_end(self, trainer, pl_module): trainer.fit(model) + assert test_callback.setup_called + assert test_callback.teardown_called assert test_callback.on_init_start_called assert test_callback.on_init_end_called assert test_callback.on_fit_start_called @@ -226,11 +242,17 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_test_start_called assert not test_callback.on_test_end_called + # reset setup teardown callback + test_callback.teardown_called = False + test_callback.setup_called = False + test_callback = TestCallback() trainer_options.update(callbacks=[test_callback]) trainer = Trainer(**trainer_options) trainer.test(model) + assert test_callback.setup_called + assert test_callback.teardown_called assert test_callback.on_test_batch_start_called assert test_callback.on_test_batch_end_called assert test_callback.on_test_start_called From 21bcaf981da1bbc71117ee350142dee2603ea02a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 18:26:07 -0400 Subject: [PATCH 5/9] allow regression metrics to import --- docs/source/hooks.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/hooks.rst b/docs/source/hooks.rst index 18bfb028d4406..86a659e148a27 100644 --- a/docs/source/hooks.rst +++ b/docs/source/hooks.rst @@ -12,6 +12,7 @@ To 
enable a hook, simply override the method in your LightningModule and the tra 3. Add it in the correct place in :mod:`pytorch_lightning.trainer` where it should be called. +--- Hooks lifecycle --------------- @@ -71,7 +72,10 @@ Test loop - ``torch.set_grad_enabled(True)`` - :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check` +--- +General hooks +------------- .. automodule:: pytorch_lightning.core.hooks :noindex: \ No newline at end of file From 57870572d117052b6aae78271190b66cbbd45194 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 18:28:58 -0400 Subject: [PATCH 6/9] allow regression metrics to import --- pytorch_lightning/trainer/training_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index da751b7425d23..54b8c0271582d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -714,7 +714,6 @@ def run_training_teardown(self): # summarize profile results self.profiler.describe() - self._teardown_already_run = True def training_forward(self, batch, batch_idx, opt_idx, hiddens): From 103744bedf16d9c1fd36dd87749f1ffcdff6abf4 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 18:38:19 -0400 Subject: [PATCH 7/9] allow regression metrics to import --- pytorch_lightning/trainer/callback_hook.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5f4eff0f2eefd..b3d8f31b1084f 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -11,6 +11,16 @@ class TrainerCallbackHookMixin(ABC): callbacks: List[Callback] = [] get_model: Callable = ... 
+ def setup(self, step: str): + """Called in the beginning of fit and test""" + for callback in self.callbacks: + callback.setup(self, step) + + def teardown(self, step: str): + """Called at the end of fit and test""" + for callback in self.callbacks: + callback.teardown(self, step) + def on_init_start(self): """Called when the trainer initialization begins, model has not yet been set.""" for callback in self.callbacks: From 11bee2f66ea5749f9851f4c7da2881f9f95fbf60 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 18:47:28 -0400 Subject: [PATCH 8/9] allow regression metrics to import --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 915afbbbcecad..7d5c555579ef5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1136,9 +1136,9 @@ def test( trainer = Trainer() trainer.test(model, test_dataloaders=test) """ - self.setup('test') if self.is_function_implemented('setup'): + model = model if model is not None else self.model model.setup('test') if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0: @@ -1182,7 +1182,7 @@ def test( self.teardown('test') if self.is_function_implemented('teardown'): - model.teardown('test') + self.model.teardown('test') def check_model_configuration(self, model: LightningModule): r""" From 7d49683dfa9ec62e4e5d1f278c64e5fd69656a35 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 19:02:31 -0400 Subject: [PATCH 9/9] allow regression metrics to import --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7d5c555579ef5..d41f5ff3fd6b9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1138,8 +1138,8 @@ def test( """ self.setup('test') if self.is_function_implemented('setup'): - model = model if model is not None else self.model - model.setup('test') + model_ref = self.model if model is None else model + model_ref.setup('test') if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0: raise MisconfigurationException(
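
The patches above wire a generic ``setup``/``teardown`` pair through the Trainer (PATCH 1/9, 8/9, 9/9), the LightningModule hooks (PATCH 2/9), the Callback base class (PATCH 3/9), and the callback dispatch mixin (PATCH 7/9). Below is a minimal usage sketch, not part of the diff, assuming the signatures introduced here: ``setup(step)``/``teardown(step)`` on the LightningModule and ``setup(trainer, step)``/``teardown(trainer, step)`` on a Callback, with ``step`` being either ``'fit'`` or ``'test'``. The model, dataset, and callback names are illustrative only::

    # Sketch only: demonstrates the hooks added by this patch series.
    # ToyModel / PrintStageCallback are made up for illustration.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl
    from pytorch_lightning import Callback, Trainer


    class PrintStageCallback(Callback):
        def setup(self, trainer, step: str):
            # invoked by Trainer.setup('fit') / Trainer.setup('test') for every callback
            print(f"callback setup, step={step}")

        def teardown(self, trainer, step: str):
            print(f"callback teardown, step={step}")


    class ToyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.dataset = None

        def setup(self, step: str):
            # called once at the very beginning of fit/test, e.g. to build datasets
            self.dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

        def teardown(self, step: str):
            # called once at the very end of fit/test, e.g. to release resources
            self.dataset = None

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            return {'loss': loss}

        def train_dataloader(self):
            return DataLoader(self.dataset, batch_size=16)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.02)


    if __name__ == '__main__':
        model = ToyModel()
        trainer = Trainer(max_epochs=1, callbacks=[PrintStageCallback()])
        # fit() triggers setup('fit') ... teardown('fit') on both the model and the callback
        trainer.fit(model)

Because PATCH 1/9 calls ``setup('fit')`` before the data-preparation block and before any dataloaders are requested, building the dataset inside ``setup`` as sketched above appears to be the intended use of the new hook; the same applies to ``setup('test')`` in ``Trainer.test``.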