From 37e75824861765281bef9fcb628b91cb7bb24e0a Mon Sep 17 00:00:00 2001
From: Peter Yu <2057325+yukw777@users.noreply.github.com>
Date: Mon, 15 Jun 2020 08:02:37 -0400
Subject: [PATCH] Add ckpt_path option to LightningModule.test() (#2190)

* Add ckpt_path option to LightningModule.test()

If ckpt_path is "best" (default), it loads the best weights saved by ModelCheckpoint for the test loop.
If ckpt_path is a path to a checkpoint file, it loads the weights from the file for the test loop.
If ckpt_path is None, it uses the weights from the end of training for the test loop.
If the model parameter is set, ckpt_path is ignored.

* Update test_set.rst

Co-authored-by: William Falcon
---
 docs/source/test_set.rst             | 27 ++++++++++++-
 pytorch_lightning/trainer/trainer.py | 39 +++++++++++++++++--
 tests/base/model_valid_epoch_ends.py |  7 +++-
 tests/models/test_restore.py         |  2 +-
 tests/trainer/test_dataloaders.py    | 58 +++++++++++++++++++---------
 tests/trainer/test_trainer.py        | 47 ++++++++++++++++++++++
 6 files changed, 155 insertions(+), 25 deletions(-)

diff --git a/docs/source/test_set.rst b/docs/source/test_set.rst
index 7873f765a5092..e88e205e4f890 100644
--- a/docs/source/test_set.rst
+++ b/docs/source/test_set.rst
@@ -5,16 +5,39 @@ Lightning forces the user to run the test set separately to make sure it isn't evaluated by mistake.
 
 Test after fit
 --------------
-To run the test set after training completes, use this method
+To run the test set after training completes, use this method.
 
 .. code-block:: python
 
     # run full training
    trainer.fit(model)
 
-    # run test set
+    # (1) load the best checkpoint automatically (lightning tracks this for you)
     trainer.test()
 
+    # (2) don't load a checkpoint, instead use the model with the latest weights
+    trainer.test(ckpt_path=None)
+
+    # (3) test using a specific checkpoint
+    trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')
+
+    # (4) test with an explicit model (will use this model and not load a checkpoint)
+    trainer.test(model)
+
+
+Test multiple models
+--------------------
+You can run the test set on multiple models using the same trainer instance.
+
+.. code-block:: python
+
+    model1 = LitModel()
+    model2 = GANModel()
+
+    trainer = Trainer()
+    trainer.test(model1)
+    trainer.test(model2)
+
 
 Test pre-trained model
 ----------------------
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index c6b22bb196535..36035bbf9766c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1007,7 +1007,8 @@ def run_pretrain_routine(self, model: LightningModule):
     def test(
             self,
             model: Optional[LightningModule] = None,
-            test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
+            test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+            ckpt_path: Optional[str] = 'best'
     ):
         r"""
@@ -1019,10 +1020,13 @@ def test(
             test_dataloaders: Either a single PyTorch DataLoader or a list of them,
                 specifying test samples.
 
+            ckpt_path: Either ``best`` or a path to the checkpoint you wish to test.
+                If ``None``, use the weights from the last epoch to test. Defaults to ``best``.
+
         Example::
 
             # Option 1
-            # run test after fitting
+            # run test with the best checkpoint from ``ModelCheckpoint`` after fitting
             test = DataLoader(...)
             trainer = Trainer()
             model = LightningModule()
@@ -1031,12 +1035,41 @@ def test(
             trainer.test(test_dataloaders=test)
 
             # Option 2
-            # run test from a loaded model
+            # run test with the specified checkpoint after fitting
+            test = DataLoader(...)
+            trainer = Trainer()
+            model = LightningModule()
+
+            trainer.fit(model)
+            trainer.test(test_dataloaders=test, ckpt_path='path/to/checkpoint.ckpt')
+
+            # Option 3
+            # run test with the weights from the end of training after fitting
+            test = DataLoader(...)
+            trainer = Trainer()
+            model = LightningModule()
+
+            trainer.fit(model)
+            trainer.test(test_dataloaders=test, ckpt_path=None)
+
+            # Option 4
+            # run test from a loaded model. ``ckpt_path`` is ignored in this case.
             test = DataLoader(...)
             model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
             trainer = Trainer()
             trainer.test(model, test_dataloaders=test)
         """
+        if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
+            raise MisconfigurationException(
+                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')
+
+        # if model is not given (None) and ckpt_path is given,
+        # load the given checkpoint for testing
+        if model is None and ckpt_path is not None:
+            # ckpt_path is 'best' so load the best model
+            if ckpt_path == 'best':
+                ckpt_path = self.checkpoint_callback.best_model_path
+            model = self.get_model().load_from_checkpoint(ckpt_path)
 
         self.testing = True
diff --git a/tests/base/model_valid_epoch_ends.py b/tests/base/model_valid_epoch_ends.py
index 6c4844d3e5c9e..b8c02d7cc6194 100644
--- a/tests/base/model_valid_epoch_ends.py
+++ b/tests/base/model_valid_epoch_ends.py
@@ -7,6 +7,7 @@
 class ValidationEpochEndVariations(ABC):
     """
     Houses all variations of validation_epoch_end steps
     """
+
     def validation_epoch_end(self, outputs):
         """
         Called at the end of validation to aggregate outputs
@@ -50,5 +51,9 @@ def _mean(res, key):
             pbar[key] = metric_out
             logs[key] = metric_out
 
-        results = {'progress_bar': pbar, 'log': logs}
+        results = {
+            'val_loss': torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean(),
+            'progress_bar': pbar,
+            'log': logs
+        }
         return results
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 8650b0fd89a04..99919ce402600 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -119,7 +119,7 @@ def test_load_model_from_checkpoint(tmpdir):
     # fit model
     trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
-    trainer.test()
+    trainer.test(ckpt_path=None)
 
     # correct result and ok accuracy
     assert result == 1, 'training failed to complete'
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 4dde52536e985..b700197b9cc73 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -97,12 +97,20 @@ def test_multiple_val_dataloader(tmpdir):
         tutils.run_prediction(dataloader, trainer.model)
 
 
-def test_multiple_test_dataloader(tmpdir):
+@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
+def test_multiple_test_dataloader(tmpdir, ckpt_path):
     """Verify multiple test_dataloader."""
 
-    model = EvalModelTemplate()
-    model.test_dataloader = model.test_dataloader__multiple
-    model.test_step = model.test_step__multiple_dataloaders
+    model_template = EvalModelTemplate()
+
+    class MultipleTestDataloaderModel(EvalModelTemplate):
+        def test_dataloader(self):
+            return model_template.test_dataloader__multiple()
+
+        def test_step(self, batch, batch_idx, *args, **kwargs):
+            return model_template.test_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs)
+
+    model = MultipleTestDataloaderModel()
 
     # fit model
     trainer = Trainer(
@@ -112,7 +120,9 @@ def test_multiple_test_dataloader(tmpdir):
         train_percent_check=0.2
     )
     trainer.fit(model)
-    trainer.test()
+    if ckpt_path == 'specific':
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+    trainer.test(ckpt_path=ckpt_path)
 
     # verify there are 2 test loaders
     assert len(trainer.test_dataloaders) == 2, \
         'multiple test loaders not initiated properly'
 
     # make sure predictions are good for each test set
     for dataloader in trainer.test_dataloaders:
         tutils.run_prediction(dataloader, trainer.model)
 
     # run the test method
-    trainer.test()
+    trainer.test(ckpt_path=ckpt_path)
 
 
 def test_train_dataloader_passed_to_fit(tmpdir):
@@ -163,7 +173,8 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
         f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
 
 
-def test_all_dataloaders_passed_to_fit(tmpdir):
+@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
+def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):
     """Verify train, val & test dataloader(s) can be passed to fit and test method"""
 
     model = EvalModelTemplate()
@@ -177,9 +188,12 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
     )
     fit_options = dict(train_dataloader=model.dataloader(train=True),
                        val_dataloaders=model.dataloader(train=False))
-    test_options = dict(test_dataloaders=model.dataloader(train=False))
-
     result = trainer.fit(model, **fit_options)
+
+    if ckpt_path == 'specific':
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+    test_options = dict(test_dataloaders=model.dataloader(train=False),
+                        ckpt_path=ckpt_path)
     trainer.test(**test_options)
 
     assert result == 1
@@ -189,7 +203,8 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
         f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
 
 
-def test_multiple_dataloaders_passed_to_fit(tmpdir):
+@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
+def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
     """Verify that multiple val & test dataloaders can be passed to fit."""
 
     model = EvalModelTemplate()
@@ -207,10 +222,12 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
     fit_options = dict(train_dataloader=model.dataloader(train=True),
                        val_dataloaders=[model.dataloader(train=False),
                                         model.dataloader(train=False)])
-    test_options = dict(test_dataloaders=[model.dataloader(train=False),
-                                          model.dataloader(train=False)])
-
     trainer.fit(model, **fit_options)
+    if ckpt_path == 'specific':
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+    test_options = dict(test_dataloaders=[model.dataloader(train=False),
+                                          model.dataloader(train=False)],
+                        ckpt_path=ckpt_path)
     trainer.test(**test_options)
 
     assert len(trainer.val_dataloaders) == 2, \
@@ -219,7 +236,8 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
         f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
 
 
-def test_mixing_of_dataloader_options(tmpdir):
+@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
+def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
     """Verify that dataloaders can be passed to fit"""
 
     model = EvalModelTemplate()
@@ -240,7 +258,9 @@ def test_mixing_of_dataloader_options(tmpdir):
     trainer = Trainer(**trainer_options)
     results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
     assert results
-    trainer.test(test_dataloaders=model.dataloader(train=False))
+    if ckpt_path == 'specific':
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+    trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)
 
     assert len(trainer.val_dataloaders) == 1, \
         f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@@ -341,7 +361,8 @@ def test_error_on_zero_len_dataloader(tmpdir):
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
-def test_warning_with_few_workers(tmpdir):
+@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
+def test_warning_with_few_workers(tmpdir, ckpt_path):
     """ Test that a warning is raised if a dataloader with only a few workers is used """
     model = EvalModelTemplate()
@@ -365,8 +386,6 @@ def test_warning_with_few_workers(tmpdir):
     fit_options = dict(train_dataloader=train_dl,
                        val_dataloaders=val_dl)
-    test_options = dict(test_dataloaders=train_dl)
-
     trainer = Trainer(**trainer_options)
 
     # fit model
@@ -376,6 +395,9 @@ def test_warning_with_few_workers(tmpdir):
     with pytest.warns(UserWarning, match='val'):
         trainer.fit(model, **fit_options)
 
+    if ckpt_path == 'specific':
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+    test_options = dict(test_dataloaders=train_dl, ckpt_path=ckpt_path)
     with pytest.warns(UserWarning, match='test'):
         trainer.test(**test_options)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index c5965da3c0b16..e397f6f132d52 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -5,6 +5,7 @@
 import types
 import sys
 from argparse import Namespace
+from pathlib import Path
 
 import cloudpickle
 import pytest
@@ -540,6 +541,52 @@ def test_testpass_overrides(tmpdir):
     Trainer().test(model)
 
 
+@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
+@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
+def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
+    hparams = EvalModelTemplate.get_default_hparams()
+
+    loaded_checkpoint_path = ''
+
+    class TestBestModel(EvalModelTemplate):
+        @classmethod
+        def load_from_checkpoint(cls, checkpoint_path, *args, **kwargs):
+            nonlocal loaded_checkpoint_path
+            loaded_checkpoint_path = checkpoint_path
+            return super().load_from_checkpoint(checkpoint_path, *args, **kwargs)
+
+    model = TestBestModel(**hparams)
+    trainer = Trainer(
+        max_epochs=2,
+        progress_bar_refresh_rate=0,
+        default_root_dir=tmpdir,
+        checkpoint_callback=ModelCheckpoint(save_top_k=save_top_k),
+    )
+    trainer.fit(model)
+    if ckpt_path == 'best':
+        # ckpt_path is 'best', meaning we load the best weights
+        if save_top_k <= 0:
+            with pytest.raises(MisconfigurationException, match='.*is not configured to save the best.*'):
+                trainer.test(ckpt_path=ckpt_path)
+        else:
+            trainer.test(ckpt_path=ckpt_path)
+            assert loaded_checkpoint_path == trainer.checkpoint_callback.best_model_path
+    elif ckpt_path is None:
+        # ckpt_path is None, meaning we don't load any checkpoints and
+        # use the weights from the end of training
+        trainer.test(ckpt_path=ckpt_path)
+        assert loaded_checkpoint_path == ''
+    else:
+        # specific checkpoint, pick one from the saved ones
+        if save_top_k == 0:
+            with pytest.raises(FileNotFoundError):
+                trainer.test(ckpt_path='random.ckpt')
+        else:
+            ckpt_path = str(list((Path(tmpdir) / 'lightning_logs/version_0/checkpoints').iterdir())[0])
+            trainer.test(ckpt_path=ckpt_path)
+            assert loaded_checkpoint_path == ckpt_path
+
+
 def test_disabled_validation():
     """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
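
For readers of this patch, the checkpoint-resolution rule added to ``Trainer.test()`` can be read in isolation. Below is a minimal standalone sketch of that rule. ``resolve_test_model`` is a hypothetical helper written only for illustration, not part of the patch, and it assumes the ``trainer.get_model()`` / ``trainer.checkpoint_callback`` attributes and the ``MisconfigurationException`` import path used in the diff above::

    from typing import Optional

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.exceptions import MisconfigurationException


    def resolve_test_model(trainer,
                           model: Optional[LightningModule] = None,
                           ckpt_path: Optional[str] = 'best') -> LightningModule:
        """Pick the weights to test, mirroring the logic added to Trainer.test()."""
        if model is not None:
            # Option 4: an explicit model wins and ckpt_path is ignored
            return model
        if ckpt_path is None:
            # Option 3: keep the in-memory weights from the end of training
            return trainer.get_model()
        if ckpt_path == 'best':
            # Option 1: fall back to the best checkpoint tracked by ModelCheckpoint,
            # which only exists if the callback actually saved checkpoints
            if trainer.checkpoint_callback.save_top_k <= 0:
                raise MisconfigurationException(
                    'ckpt_path is "best", but ModelCheckpoint is not configured '
                    'to save the best model.')
            ckpt_path = trainer.checkpoint_callback.best_model_path
        # Option 2: load the weights from a concrete checkpoint file
        return trainer.get_model().load_from_checkpoint(ckpt_path)

The four options in the ``test()`` docstring map one-to-one onto these branches, which is also what the parametrized tests above exercise.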