
Add ckpt_path option to LightningModule.test() (#2190)
* Add ckpt_path option to LightningModule.test()

If ckpt_path is "best" (default), it loads the best weights saved by ModelCheckpoint for the test loop.
If ckpt_path is a path to a checkpoint file, it loads the weights from the file for the test loop.
If ckpt_path is None, it uses the weights from the end of training for the test loop.
If the model parameter is set, ckpt_path is ignored (see the usage sketch below).
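A minimal usage sketch of the new option, mirroring the docs changes below; the model, the Trainer setup, and the checkpoint path are illustrative:

    # assumes a Trainer with the default ModelCheckpoint callback and a fitted model
    trainer.fit(model)

    trainer.test()                                # 'best' (default): best checkpoint saved by ModelCheckpoint
    trainer.test(ckpt_path=None)                  # weights from the end of training, no checkpoint loaded
    trainer.test(ckpt_path='/path/to/my.ckpt')    # weights from this specific checkpoint file (illustrative path)
    trainer.test(model)                           # explicit model given: ckpt_path is ignored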

* Update test_set.rst

Co-authored-by: William Falcon <waf2107@columbia.edu>
yukw777 and williamFalcon committed Jun 15, 2020
1 parent 48a76a7 commit 37e7582
Showing 6 changed files with 155 additions and 25 deletions.
27 changes: 25 additions & 2 deletions docs/source/test_set.rst
@@ -5,16 +5,39 @@ Lightning forces the user to run the test set separately to make sure it isn't evaluated by mistake

Test after fit
--------------
To run the test set after training completes, use this method
To run the test set after training completes, use this method.

.. code-block:: python

    # run full training
    trainer.fit(model)

    # run test set
    # (1) load the best checkpoint automatically (lightning tracks this for you)
    trainer.test()

    # (2) don't load a checkpoint, instead use the model with the latest weights
    trainer.test(ckpt_path=None)

    # (3) test using a specific checkpoint
    trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')

    # (4) test with an explicit model (will use this model and not load a checkpoint)
    trainer.test(model)

Test multiple models
--------------------
You can run the test set on multiple models using the same trainer instance.

.. code-block:: python

    model1 = LitModel()
    model2 = GANModel()

    trainer = Trainer()
    trainer.test(model1)
    trainer.test(model2)

Test pre-trained model
----------------------
39 changes: 36 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -1007,7 +1007,8 @@ def run_pretrain_routine(self, model: LightningModule):
def test(
self,
model: Optional[LightningModule] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = 'best'
):
r"""
@@ -1019,10 +1020,13 @@ def test(
test_dataloaders: Either a single
PyTorch DataLoader or a list of them, specifying test samples.
ckpt_path: Either ``best`` or the path to the checkpoint you wish to test.
If ``None``, use the weights from the end of training. Defaults to ``best``.
Example::
# Option 1
# run test after fitting
# run test with the best checkpoint from ``ModelCheckpoint`` after fitting.
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()
@@ -1031,12 +1035,41 @@
trainer.test(test_dataloaders=test)
# Option 2
# run test from a loaded model
# run test with the specified checkpoint after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model)
trainer.test(test_dataloaders=test, ckpt_path='path/to/checkpoint.ckpt')
# Option 3
# run test with the weights from the end of training after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model)
trainer.test(test_dataloaders=test, ckpt_path=None)
# Option 4
# run test from a loaded model. ``ckpt_path`` is ignored in this case.
test = DataLoader(...)
model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
trainer = Trainer()
trainer.test(model, test_dataloaders=test)
"""
if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
raise MisconfigurationException(
'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')

# if model is not given (None) and ckpt_path is given,
# load the given checkpoint for testing
if model is None and ckpt_path is not None:
# ckpt_path is 'best' so load the best model
if ckpt_path == 'best':
ckpt_path = self.checkpoint_callback.best_model_path
model = self.get_model().load_from_checkpoint(ckpt_path)

self.testing = True

7 changes: 6 additions & 1 deletion tests/base/model_valid_epoch_ends.py
@@ -7,6 +7,7 @@ class ValidationEpochEndVariations(ABC):
"""
Houses all variations of validation_epoch_end steps
"""

def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
@@ -50,5 +51,9 @@ def _mean(res, key):
pbar[key] = metric_out
logs[key] = metric_out

results = {'progress_bar': pbar, 'log': logs}
results = {
'val_loss': torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean(),
'progress_bar': pbar,
'log': logs
}
return results
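For context on the test-helper change above: the new 'val_loss' entry simply averages every progress-bar metric whose key starts with 'val_loss'. A self-contained sketch of that aggregation, with made-up metric values:

    import torch

    # illustrative progress-bar metrics, as produced by validation_epoch_end
    pbar = {
        'val_loss_0': torch.tensor(0.25),
        'val_loss_1': torch.tensor(0.35),
        'val_acc': torch.tensor(0.90),
    }

    # stack all 'val_loss*' entries and average them, as the diff above does
    val_loss = torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean()
    print(val_loss)  # tensor(0.3000)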
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -119,7 +119,7 @@ def test_load_model_from_checkpoint(tmpdir):
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
trainer.test()
trainer.test(ckpt_path=None)

# correct result and ok accuracy
assert result == 1, 'training failed to complete'
58 changes: 40 additions & 18 deletions tests/trainer/test_dataloaders.py
@@ -97,12 +97,20 @@ def test_multiple_val_dataloader(tmpdir):
tutils.run_prediction(dataloader, trainer.model)


def test_multiple_test_dataloader(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_multiple_test_dataloader(tmpdir, ckpt_path):
"""Verify multiple test_dataloader."""

model = EvalModelTemplate()
model.test_dataloader = model.test_dataloader__multiple
model.test_step = model.test_step__multiple_dataloaders
model_template = EvalModelTemplate()

class MultipleTestDataloaderModel(EvalModelTemplate):
def test_dataloader(self):
return model_template.test_dataloader__multiple()

def test_step(self, batch, batch_idx, *args, **kwargs):
return model_template.test_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs)

model = MultipleTestDataloaderModel()

# fit model
trainer = Trainer(
@@ -112,7 +120,9 @@ def test_multiple_test_dataloader(tmpdir):
train_percent_check=0.2
)
trainer.fit(model)
trainer.test()
if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
trainer.test(ckpt_path=ckpt_path)

# verify there are 2 test loaders
assert len(trainer.test_dataloaders) == 2, \
@@ -123,7 +133,7 @@
tutils.run_prediction(dataloader, trainer.model)

# run the test method
trainer.test()
trainer.test(ckpt_path=ckpt_path)


def test_train_dataloader_passed_to_fit(tmpdir):
@@ -163,7 +173,8 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'


def test_all_dataloaders_passed_to_fit(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):
"""Verify train, val & test dataloader(s) can be passed to fit and test method"""

model = EvalModelTemplate()
@@ -177,9 +188,12 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=model.dataloader(train=False))
test_options = dict(test_dataloaders=model.dataloader(train=False))

result = trainer.fit(model, **fit_options)

if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
test_options = dict(test_dataloaders=model.dataloader(train=False),
ckpt_path=ckpt_path)
trainer.test(**test_options)

assert result == 1
@@ -189,7 +203,8 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_multiple_dataloaders_passed_to_fit(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
"""Verify that multiple val & test dataloaders can be passed to fit."""

model = EvalModelTemplate()
@@ -207,10 +222,12 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=[model.dataloader(train=False),
model.dataloader(train=False)])
test_options = dict(test_dataloaders=[model.dataloader(train=False),
model.dataloader(train=False)])

trainer.fit(model, **fit_options)
if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
test_options = dict(test_dataloaders=[model.dataloader(train=False),
model.dataloader(train=False)],
ckpt_path=ckpt_path)
trainer.test(**test_options)

assert len(trainer.val_dataloaders) == 2, \
@@ -219,7 +236,8 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_mixing_of_dataloader_options(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
"""Verify that dataloaders can be passed to fit"""

model = EvalModelTemplate()
@@ -240,7 +258,9 @@ def test_mixing_of_dataloader_options(tmpdir):
trainer = Trainer(**trainer_options)
results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
assert results
trainer.test(test_dataloaders=model.dataloader(train=False))
if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)

assert len(trainer.val_dataloaders) == 1, \
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@@ -341,7 +361,8 @@ def test_error_on_zero_len_dataloader(tmpdir):


@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
def test_warning_with_few_workers(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_warning_with_few_workers(tmpdir, ckpt_path):
""" Test that error is raised if dataloader with only a few workers is used """

model = EvalModelTemplate()
@@ -365,8 +386,6 @@ def test_warning_with_few_workers(tmpdir):

fit_options = dict(train_dataloader=train_dl,
val_dataloaders=val_dl)
test_options = dict(test_dataloaders=train_dl)

trainer = Trainer(**trainer_options)

# fit model
@@ -376,6 +395,9 @@ def test_warning_with_few_workers(tmpdir):
with pytest.warns(UserWarning, match='val'):
trainer.fit(model, **fit_options)

if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
test_options = dict(test_dataloaders=train_dl, ckpt_path=ckpt_path)
with pytest.warns(UserWarning, match='test'):
trainer.test(**test_options)

47 changes: 47 additions & 0 deletions tests/trainer/test_trainer.py
@@ -5,6 +5,7 @@
import types
import sys
from argparse import Namespace
from pathlib import Path

import cloudpickle
import pytest
@@ -540,6 +541,52 @@ def test_testpass_overrides(tmpdir):
Trainer().test(model)


@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
hparams = EvalModelTemplate.get_default_hparams()

loaded_checkpoint_path = ''

class TestBestModel(EvalModelTemplate):
@classmethod
def load_from_checkpoint(cls, checkpoint_path, *args, **kwargs):
nonlocal loaded_checkpoint_path
loaded_checkpoint_path = checkpoint_path
return super().load_from_checkpoint(checkpoint_path, *args, **kwargs)

model = TestBestModel(**hparams)
trainer = Trainer(
max_epochs=2,
progress_bar_refresh_rate=0,
default_root_dir=tmpdir,
checkpoint_callback=ModelCheckpoint(save_top_k=save_top_k),
)
trainer.fit(model)
if ckpt_path == 'best':
# ckpt_path is 'best', meaning we load the best weights
if save_top_k <= 0:
with pytest.raises(MisconfigurationException, match='.*is not configured to save the best.*'):
trainer.test(ckpt_path=ckpt_path)
else:
trainer.test(ckpt_path=ckpt_path)
assert loaded_checkpoint_path == trainer.checkpoint_callback.best_model_path
elif ckpt_path is None:
# ckpt_path is None, meaning we don't load any checkpoints and
# use the weights from the end of training
trainer.test(ckpt_path=ckpt_path)
assert loaded_checkpoint_path == ''
else:
# specific checkpoint, pick one from saved ones
if save_top_k == 0:
with pytest.raises(FileNotFoundError):
trainer.test(ckpt_path='random.ckpt')
else:
ckpt_path = str(list((Path(tmpdir) / 'lightning_logs/version_0/checkpoints').iterdir())[0])
trainer.test(ckpt_path=ckpt_path)
assert loaded_checkpoint_path == ckpt_path


def test_disabled_validation():
"""Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""

