Add ckpt_path option to LightningModule.test() #2190

Merged · 2 commits · Jun 15, 2020
docs/source/test_set.rst (27 changes: 25 additions & 2 deletions)
@@ -5,16 +5,39 @@ Lightning forces the user to run the test set separately to make sure it isn't evaluated by mistake.

Test after fit
--------------
To run the test set after training completes, use this method
To run the test set after training completes, use this method.

.. code-block:: python

    # run full training
    trainer.fit(model)

    # run test set
    # (1) load the best checkpoint automatically (lightning tracks this for you)
    trainer.test()

    # (2) don't load a checkpoint, instead use the model with the latest weights
    trainer.test(ckpt_path=None)

    # (3) test using a specific checkpoint
    trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')

    # (4) test with an explicit model (will use this model and not load a checkpoint)
    trainer.test(model)
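
Taken together, the four modes compose as in this minimal end-to-end sketch (`LitModel` is a hypothetical LightningModule defined elsewhere, and the signatures follow the 0.8-era API this PR targets):

```python
from pytorch_lightning import Trainer

model = LitModel()  # hypothetical module; any LightningModule with a test_step works
trainer = Trainer(max_epochs=2)
trainer.fit(model)

trainer.test()                                          # (1) best checkpoint tracked by ModelCheckpoint
trainer.test(ckpt_path=None)                            # (2) in-memory weights from the end of fit
trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')   # (3) an explicit checkpoint file
trainer.test(model)                                     # (4) this exact model; ckpt_path is ignored
```
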
Test multiple models
--------------------
You can run the test set on multiple models using the same trainer instance.

.. code-block:: python

    model1 = LitModel()
    model2 = GANModel()

    trainer = Trainer()
    trainer.test(model1)
    trainer.test(model2)

Test pre-trained model
----------------------
pytorch_lightning/trainer/trainer.py (39 changes: 36 additions & 3 deletions)
@@ -1007,7 +1007,8 @@ def run_pretrain_routine(self, model: LightningModule):
def test(
self,
model: Optional[LightningModule] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = 'best'
):
r"""

@@ -1019,10 +1020,13 @@ def test(
test_dataloaders: Either a single
PyTorch DataLoader or a list of them, specifying test samples.

ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None``, use the weights from the last epoch to test. Defaults to ``best``.

Example::

# Option 1
# run test after fitting
# run test with the best checkpoint from ``ModelCheckpoint`` after fitting.
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()
@@ -1031,12 +1035,41 @@ def test(
trainer.test(test_dataloaders=test)

# Option 2
# run test from a loaded model
# run test with the specified checkpoint after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test, ckpt_path='path/to/checkpoint.ckpt')

# Option 3
# run test with the weights from the end of training after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test, ckpt_path=None)

# Option 4
# run test from a loaded model. ``ckpt_path`` is ignored in this case.
test = DataLoader(...)
model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
trainer = Trainer()
trainer.test(model, test_dataloaders=test)
"""
if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
raise MisconfigurationException(
'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')

# if no model is given (None) and ckpt_path is given,
# load the given checkpoint for testing
if model is None and ckpt_path is not None:
# ckpt_path is 'best' so load the best model
if ckpt_path == 'best':
ckpt_path = self.checkpoint_callback.best_model_path
model = self.get_model().load_from_checkpoint(ckpt_path)

self.testing = True

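One consequence of the guard at the top of `test()`: with `save_top_k=0`, the `ModelCheckpoint` callback never records a best model, so the default `ckpt_path='best'` has nothing to load. A minimal sketch of that failure mode (assumes a `model` defined as in the docstring examples; import paths reflect the 0.8-era layout):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException

trainer = Trainer(checkpoint_callback=ModelCheckpoint(save_top_k=0))
trainer.fit(model)  # `model` assumed defined elsewhere

try:
    trainer.test()  # ckpt_path defaults to 'best'
except MisconfigurationException as err:
    print(err)  # 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
```
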
tests/base/model_valid_epoch_ends.py (7 changes: 6 additions & 1 deletion)
@@ -7,6 +7,7 @@ class ValidationEpochEndVariations(ABC):
"""
Houses all variations of validation_epoch_end steps
"""

def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
@@ -50,5 +51,9 @@ def _mean(res, key):
pbar[key] = metric_out
logs[key] = metric_out

results = {'progress_bar': pbar, 'log': logs}
results = {
'val_loss': torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean(),
'progress_bar': pbar,
'log': logs
}
return results
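
Context for this change: `ModelCheckpoint` monitors `val_loss` by default, so the test template now has to surface an aggregated `val_loss` for the best-checkpoint bookkeeping to rank. The aggregation in isolation (hypothetical metric dict):

```python
import torch

# hypothetical per-dataloader metrics, as collected in validation_epoch_end
pbar = {'val_loss': torch.tensor(0.5), 'val_loss/dataloader_idx_1': torch.tensor(0.3)}

# stack every entry whose key starts with 'val_loss' and average them
val_loss = torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean()
print(val_loss)  # tensor(0.4000)
```
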
tests/models/test_restore.py (2 changes: 1 addition & 1 deletion)
@@ -119,7 +119,7 @@ def test_load_model_from_checkpoint(tmpdir):
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
trainer.test()
trainer.test(ckpt_path=None)

# correct result and ok accuracy
assert result == 1, 'training failed to complete'
tests/trainer/test_dataloaders.py (58 changes: 40 additions & 18 deletions)
@@ -97,12 +97,20 @@ def test_multiple_val_dataloader(tmpdir):
tutils.run_prediction(dataloader, trainer.model)


def test_multiple_test_dataloader(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_multiple_test_dataloader(tmpdir, ckpt_path):
"""Verify multiple test_dataloader."""

model = EvalModelTemplate()
model.test_dataloader = model.test_dataloader__multiple
model.test_step = model.test_step__multiple_dataloaders
model_template = EvalModelTemplate()

class MultipleTestDataloaderModel(EvalModelTemplate):
def test_dataloader(self):
return model_template.test_dataloader__multiple()

def test_step(self, batch, batch_idx, *args, **kwargs):
return model_template.test_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs)

model = MultipleTestDataloaderModel()

# fit model
trainer = Trainer(
@@ -112,7 +120,9 @@ def test_multiple_test_dataloader(tmpdir):
train_percent_check=0.2
)
trainer.fit(model)
trainer.test()
if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
trainer.test(ckpt_path=ckpt_path)

# verify there are 2 test loaders
assert len(trainer.test_dataloaders) == 2, \
@@ -123,7 +133,7 @@ def test_multiple_test_dataloader(tmpdir):
tutils.run_prediction(dataloader, trainer.model)

# run the test method
trainer.test()
trainer.test(ckpt_path=ckpt_path)


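The same three-way parametrization recurs throughout this file. `'specific'` cannot be a literal path at collection time because no checkpoint exists until `fit()` has run, so each test resolves it late; annotated, the recurring pattern is:

```python
# parametrized value -> actual argument for trainer.test()
if ckpt_path == 'specific':
    # substitute a real .ckpt written during fit(); best_model_path is simply
    # a file guaranteed to exist at this point
    ckpt_path = trainer.checkpoint_callback.best_model_path
trainer.test(ckpt_path=ckpt_path)  # None, 'best', or a concrete path
```
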
def test_train_dataloader_passed_to_fit(tmpdir):
@@ -163,7 +173,8 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'


def test_all_dataloaders_passed_to_fit(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):
"""Verify train, val & test dataloader(s) can be passed to fit and test method"""

model = EvalModelTemplate()
@@ -177,9 +188,12 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=model.dataloader(train=False))
test_options = dict(test_dataloaders=model.dataloader(train=False))

result = trainer.fit(model, **fit_options)

if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
test_options = dict(test_dataloaders=model.dataloader(train=False),
ckpt_path=ckpt_path)
trainer.test(**test_options)

assert result == 1
@@ -189,7 +203,8 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_multiple_dataloaders_passed_to_fit(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
"""Verify that multiple val & test dataloaders can be passed to fit."""

model = EvalModelTemplate()
@@ -207,10 +222,12 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=[model.dataloader(train=False),
model.dataloader(train=False)])
test_options = dict(test_dataloaders=[model.dataloader(train=False),
model.dataloader(train=False)])

trainer.fit(model, **fit_options)
if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
test_options = dict(test_dataloaders=[model.dataloader(train=False),
model.dataloader(train=False)],
ckpt_path=ckpt_path)
trainer.test(**test_options)

assert len(trainer.val_dataloaders) == 2, \
@@ -219,7 +236,8 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_mixing_of_dataloader_options(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
"""Verify that dataloaders can be passed to fit"""

model = EvalModelTemplate()
@@ -240,7 +258,9 @@ def test_mixing_of_dataloader_options(tmpdir):
trainer = Trainer(**trainer_options)
results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
assert results
trainer.test(test_dataloaders=model.dataloader(train=False))
if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)

assert len(trainer.val_dataloaders) == 1, \
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@@ -341,7 +361,8 @@ def test_error_on_zero_len_dataloader(tmpdir):


@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
def test_warning_with_few_workers(tmpdir):
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_warning_with_few_workers(tmpdir, ckpt_path):
""" Test that error is raised if dataloader with only a few workers is used """

model = EvalModelTemplate()
@@ -365,8 +386,6 @@ def test_warning_with_few_workers(tmpdir):

fit_options = dict(train_dataloader=train_dl,
val_dataloaders=val_dl)
test_options = dict(test_dataloaders=train_dl)

trainer = Trainer(**trainer_options)

# fit model
@@ -376,6 +395,9 @@ def test_warning_with_few_workers(tmpdir):
with pytest.warns(UserWarning, match='val'):
trainer.fit(model, **fit_options)

if ckpt_path == 'specific':
ckpt_path = trainer.checkpoint_callback.best_model_path
test_options = dict(test_dataloaders=train_dl, ckpt_path=ckpt_path)
with pytest.warns(UserWarning, match='test'):
trainer.test(**test_options)

tests/trainer/test_trainer.py (47 changes: 47 additions & 0 deletions)
@@ -4,6 +4,7 @@
import pickle
import types
from argparse import Namespace
from pathlib import Path

import cloudpickle
import pytest
@@ -539,6 +540,52 @@ def test_testpass_overrides(tmpdir):
Trainer().test(model)


@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
hparams = EvalModelTemplate.get_default_hparams()

loaded_checkpoint_path = ''

class TestBestModel(EvalModelTemplate):
@classmethod
def load_from_checkpoint(cls, checkpoint_path, *args, **kwargs):
nonlocal loaded_checkpoint_path
loaded_checkpoint_path = checkpoint_path
return super().load_from_checkpoint(checkpoint_path, *args, **kwargs)

model = TestBestModel(**hparams)
trainer = Trainer(
max_epochs=2,
progress_bar_refresh_rate=0,
default_root_dir=tmpdir,
checkpoint_callback=ModelCheckpoint(save_top_k=save_top_k),
)
trainer.fit(model)
if ckpt_path == 'best':
# ckpt_path is 'best', meaning we load the best weights
if save_top_k <= 0:
with pytest.raises(MisconfigurationException, match='.*is not configured to save the best.*'):
trainer.test(ckpt_path=ckpt_path)
else:
trainer.test(ckpt_path=ckpt_path)
assert loaded_checkpoint_path == trainer.checkpoint_callback.best_model_path
elif ckpt_path is None:
# ckpt_path is None, meaning we don't load any checkpoints and
# use the weights from the end of training
trainer.test(ckpt_path=ckpt_path)
assert loaded_checkpoint_path == ''
else:
# specific checkpoint, pick one from saved ones
if save_top_k == 0:
with pytest.raises(FileNotFoundError):
trainer.test(ckpt_path='random.ckpt')
else:
ckpt_path = str(list((Path(tmpdir) / 'lightning_logs/version_0/checkpoints').iterdir())[0])
trainer.test(ckpt_path=ckpt_path)
assert loaded_checkpoint_path == ckpt_path

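The subclass above acts as a small spy: it intercepts `load_from_checkpoint` to record which file `test()` actually loads, then delegates to the real loader. The same pattern in isolation (`base_cls` is any class exposing a `load_from_checkpoint` classmethod; a hypothetical helper, not part of the PR):

```python
def make_loading_spy(base_cls):
    """Return a subclass of `base_cls` recording the last checkpoint path it loaded."""
    record = {'path': ''}

    class Spy(base_cls):
        @classmethod
        def load_from_checkpoint(cls, checkpoint_path, *args, **kwargs):
            record['path'] = checkpoint_path  # capture before delegating
            return super().load_from_checkpoint(checkpoint_path, *args, **kwargs)

    return Spy, record
```

For example, `Spy, record = make_loading_spy(EvalModelTemplate)` mirrors what `TestBestModel` does with its `nonlocal` variable.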

def test_disabled_validation():
"""Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
