fixes for early stopping and checkpoint callbacks #1504

Merged
2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -144,8 +144,6 @@ def training_step(self, batch, batch_idx):

"""

import atexit
import signal
import subprocess
from abc import ABC, abstractmethod
from typing import Callable
84 changes: 84 additions & 0 deletions tests/callbacks/test_callbacks.py
@@ -160,6 +160,7 @@ def on_test_end(self, trainer, pl_module):
test_callback = TestCallback()

trainer_options = dict(
default_root_dir=tmpdir,
callbacks=[test_callback],
max_epochs=1,
limit_val_batches=0.1,
@@ -261,3 +262,86 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_validation_end_called
assert not test_callback.on_validation_batch_end_called
assert not test_callback.on_validation_batch_start_called


def test_early_stopping_no_val_step(tmpdir):
"""Test that early stopping callback falls back to training metrics when no validation defined."""

class CurrentModel(EvalModelTemplate):
def training_step(self, *args, **kwargs):
output = super().training_step(*args, **kwargs)
output.update({'my_train_metric': output['loss']}) # could be anything else
return output

model = CurrentModel()
model.validation_step = None
model.val_dataloader = None

stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=stopping,
overfit_batches=0.20,
max_epochs=2,
)
result = trainer.fit(model)

assert result == 1, 'training failed to complete'
assert trainer.current_epoch <= trainer.max_epochs
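
For context: with no validation loop, early stopping can only monitor keys produced by the training step, which is why the test logs 'my_train_metric' from training_step. A minimal sketch of that contract (values illustrative, not the library's implementation):

    # The monitored key must appear in the training-step output, since
    # there are no validation metrics to fall back on.
    train_output = {'loss': 0.42, 'my_train_metric': 0.42}
    monitor = 'my_train_metric'
    assert monitor in train_output, 'early stopping has nothing to monitor'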


def test_pickling(tmpdir):
import pickle
early_stopping = EarlyStopping()
ckpt = ModelCheckpoint(tmpdir)

early_stopping_pickled = pickle.dumps(early_stopping)
ckpt_pickled = pickle.dumps(ckpt)

early_stopping_loaded = pickle.loads(early_stopping_pickled)
ckpt_loaded = pickle.loads(ckpt_pickled)

assert vars(early_stopping) == vars(early_stopping_loaded)
assert vars(ckpt) == vars(ckpt_loaded)
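
For context: vars(obj) returns obj.__dict__, so these assertions verify that every attribute survives the pickle round trip, not merely that unpickling succeeds. The same idiom, shown standalone with a hypothetical callback-like class:

    import pickle

    class DummyCallback:  # hypothetical stand-in for a real callback
        def __init__(self, patience=3):
            self.patience = patience
            self.wait = 0

    original = DummyCallback()
    restored = pickle.loads(pickle.dumps(original))
    assert vars(original) == vars(restored)  # attribute-level equality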


@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
tutils.reset_seed()
model = EvalModelTemplate()

checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint,
overfit_batches=0.20,
max_epochs=2,
)
trainer.fit(model)

# These should be different if the dirpath has been overridden
assert trainer.ckpt_path != trainer.default_root_dir
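
For context: with filepath=None the callback is expected to derive its save directory from the trainer's logger rather than using default_root_dir verbatim, which is what the inequality above checks. A rough illustration of that kind of resolution (assumed directory layout, not the library's exact code):

    import os

    def resolve_ckpt_dir(root_dir, logger_name='lightning_logs', version='version_0'):
        # Checkpoints land under the logger's versioned subdirectory,
        # so the resolved path differs from root_dir itself.
        return os.path.join(root_dir, logger_name, version, 'checkpoints')

    assert resolve_ckpt_dir('/tmp/run') != '/tmp/run'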


@pytest.mark.parametrize(
'logger_version,expected',
[(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
"""Test that "version_" prefix is only added when logger's version is an integer"""
tutils.reset_seed()
model = EvalModelTemplate()
logger = TensorBoardLogger(str(tmpdir), version=logger_version)

trainer = Trainer(
default_root_dir=tmpdir,
overfit_batches=0.2,
max_epochs=2,
logger=logger,
)
trainer.fit(model)

ckpt_version = Path(trainer.ckpt_path).parent.name
assert ckpt_version == expected
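
For context: the naming rule under test fits in one line; a version of None is resolved by the logger to the next unused integer before formatting, which is why the first parametrized case expects 'version_0'. A sketch (illustrative, not the library's code):

    def format_version(version):
        # Integers get the 'version_' prefix; strings are used verbatim.
        return 'version_{}'.format(version) if isinstance(version, int) else str(version)

    assert format_version(1) == 'version_1'
    assert format_version('awesome') == 'awesome'
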
35 changes: 27 additions & 8 deletions tests/callbacks/test_early_stopping.py
@@ -28,7 +28,12 @@ def on_train_start(self, trainer, pl_module):
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(save_top_k=1)
early_stop_callback = EarlyStopping()
trainer = Trainer(checkpoint_callback=checkpoint_callback, early_stop_callback=early_stop_callback, max_epochs=4)
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
early_stop_callback=early_stop_callback,
max_epochs=4,
)
trainer.fit(model)
early_stop_callback_state = early_stop_callback.state_dict()

@@ -38,13 +43,16 @@ def on_train_start(self, trainer, pl_module):
assert checkpoint['early_stop_callback_state_dict'] == early_stop_callback_state
# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
new_trainer = Trainer(max_epochs=2,
resume_from_checkpoint=checkpoint_filepath,
early_stop_callback=early_stop_callback)
new_trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
resume_from_checkpoint=checkpoint_filepath,
early_stop_callback=early_stop_callback,
)
new_trainer.fit(model)


def test_early_stopping_no_extraneous_invocations():
def test_early_stopping_no_extraneous_invocations(tmpdir):
"""Test to ensure that callback methods aren't being invoked outside of the callback handler."""
class EarlyStoppingTestInvocations(EarlyStopping):
def __init__(self, expected_count):
@@ -61,7 +69,12 @@ def on_train_end(self, trainer, pl_module):
model = EvalModelTemplate()
expected_count = 4
early_stop_callback = EarlyStoppingTestInvocations(expected_count)
trainer = Trainer(early_stop_callback=early_stop_callback, val_check_interval=1.0, max_epochs=expected_count)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=early_stop_callback,
val_check_interval=1.0,
max_epochs=expected_count,
)
trainer.fit(model)


@@ -70,7 +83,7 @@ def on_train_end(self, trainer, pl_module):
([6, 5, 4, 4, 3, 3], 1, 3),
([6, 5, 6, 5, 5, 5], 3, 4),
])
def test_early_stopping_patience(loss_values, patience, expected_stop_epoch):
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
"""Test to ensure that early stopping is not triggered before patience is exhausted."""

class ModelOverrideValidationReturn(EvalModelTemplate):
@@ -84,7 +97,13 @@ def validation_epoch_end(self, outputs):

model = ModelOverrideValidationReturn()
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
trainer = Trainer(early_stop_callback=early_stop_callback, val_check_interval=1.0, num_sanity_val_steps=0, max_epochs=10)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=early_stop_callback,
val_check_interval=1.0,
num_sanity_val_steps=0,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
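
For context: the parametrized cases above encode the patience semantics: training stops once 'patience' consecutive checks pass without the monitored value improving by more than min_delta. A self-contained model that reproduces both expected stop epochs (illustrative, not the library's implementation):

    def stop_epoch(values, patience, min_delta=0.0):
        best, wait = float('inf'), 0
        for epoch, value in enumerate(values):
            if value < best - min_delta:  # an improvement resets the counter
                best, wait = value, 0
            else:
                wait += 1
                if wait >= patience:
                    return epoch
        return None  # never stopped

    assert stop_epoch([6, 5, 4, 4, 3, 3], patience=1) == 3
    assert stop_epoch([6, 5, 6, 5, 5, 5], patience=3) == 4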

8 changes: 4 additions & 4 deletions tests/callbacks/test_lr_logger.py
@@ -19,7 +19,7 @@ def test_lr_logger_single_lr(tmpdir):
max_epochs=2,
limit_val_batches=0.1,
limit_train_batches=0.5,
callbacks=[lr_logger]
callbacks=[lr_logger],
)
result = trainer.fit(model)
assert result
@@ -42,7 +42,7 @@ def test_lr_logger_no_lr(tmpdir):
max_epochs=2,
limit_val_batches=0.1,
limit_train_batches=0.5,
callbacks=[lr_logger]
callbacks=[lr_logger],
)

with pytest.warns(RuntimeWarning):
@@ -63,7 +63,7 @@ def test_lr_logger_multi_lrs(tmpdir):
max_epochs=2,
limit_val_batches=0.1,
limit_train_batches=0.5,
callbacks=[lr_logger]
callbacks=[lr_logger],
)
result = trainer.fit(model)
assert result
@@ -90,7 +90,7 @@ def test_lr_logger_param_groups(tmpdir):
max_epochs=2,
limit_val_batches=0.1,
limit_train_batches=0.5,
callbacks=[lr_logger]
callbacks=[lr_logger],
)
result = trainer.fit(model)
assert result
4 changes: 2 additions & 2 deletions tests/callbacks/test_model_checkpoint.py
@@ -22,7 +22,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
default_root_dir=tmpdir,
checkpoint_callback=checkpoint,
overfit_pct=0.20,
max_epochs=5
max_epochs=5,
)
trainer.fit(model)

@@ -44,7 +44,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
default_root_dir=tmpdir,
overfit_pct=0.2,
max_epochs=5,
logger=logger
logger=logger,
)
trainer.fit(model)

12 changes: 8 additions & 4 deletions tests/callbacks/test_progress_bar.py
@@ -13,10 +13,11 @@
([ProgressBar(refresh_rate=2)], 0),
([ProgressBar(refresh_rate=2)], 1),
])
def test_progress_bar_on(callbacks, refresh_rate):
def test_progress_bar_on(tmpdir, callbacks, refresh_rate):
"""Test different ways the progress bar can be turned on."""

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
progress_bar_refresh_rate=refresh_rate,
max_epochs=1,
@@ -54,12 +55,13 @@ def test_progress_bar_misconfiguration():
Trainer(callbacks=callbacks)


def test_progress_bar_totals():
def test_progress_bar_totals(tmpdir):
"""Test that the progress finishes with the correct total steps processed."""

model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=1,
limit_val_batches=1.0,
max_epochs=1,
@@ -105,10 +107,11 @@ def test_progress_bar_totals():
assert bar.test_batch_idx == k


def test_progress_bar_fast_dev_run():
def test_progress_bar_fast_dev_run(tmpdir):
model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)

@@ -136,7 +139,7 @@ def test_progress_bar_fast_dev_run():


@pytest.mark.parametrize('refresh_rate', [0, 1, 50])
def test_progress_bar_progress_refresh(refresh_rate):
def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
"""Test that the three progress bars get correctly updated when using different refresh rates."""

model = EvalModelTemplate()
@@ -172,6 +175,7 @@ def on_test_batch_end(self, trainer, pl_module):

progress_bar = CurrentProgressBar(refresh_rate=refresh_rate)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[progress_bar],
progress_bar_refresh_rate=101, # should not matter if custom callback provided
limit_train_batches=1.0,
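
For context: the refresh-rate convention exercised in this file is, roughly, that a rate of 0 disables drawing and a rate of n redraws the bar every n batches. A sketch of that convention (assumed semantics; the library's exact off-by-one handling may differ):

    def should_refresh(batch_idx, refresh_rate):
        # batch_idx is 1-based here; a refresh_rate of 0 disables updates.
        return refresh_rate > 0 and batch_idx % refresh_rate == 0

    assert not should_refresh(5, 0)  # a disabled bar never redraws
    assert should_refresh(4, 2)      # every 2nd batch at rate 2
    assert not should_refresh(3, 2)
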
1 change: 1 addition & 0 deletions tests/loggers/test_all.py
@@ -50,6 +50,7 @@ def log_metrics(self, metrics, step):
logger = StoreHistoryLogger(**logger_args)

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
limit_train_batches=0.2,
Expand Down
10 changes: 7 additions & 3 deletions tests/loggers/test_base.py
@@ -68,7 +68,7 @@ def test_custom_logger(tmpdir):
max_epochs=1,
limit_train_batches=0.05,
logger=logger,
default_root_dir=tmpdir
default_root_dir=tmpdir,
)
result = trainer.fit(model)
assert result == 1, "Training failed"
@@ -88,7 +88,7 @@ def test_multiple_loggers(tmpdir):
max_epochs=1,
limit_train_batches=0.05,
logger=[logger1, logger2],
default_root_dir=tmpdir
default_root_dir=tmpdir,
)
result = trainer.fit(model)
assert result == 1, "Training failed"
@@ -108,7 +108,11 @@ def test_multiple_loggers_pickle(tmpdir):
logger1 = CustomLogger()
logger2 = CustomLogger()

trainer = Trainer(max_epochs=1, logger=[logger1, logger2])
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=[logger1, logger2],
)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0}, 0)
2 changes: 1 addition & 1 deletion tests/loggers/test_neptune.py
@@ -83,7 +83,7 @@ def _run_training(logger):
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=0.05,
logger=logger
logger=logger,
)
trainer.fit(model)
return logger
4 changes: 2 additions & 2 deletions tests/loggers/test_trains.py
@@ -18,7 +18,7 @@ def test_trains_logger(tmpdir):
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=0.05,
logger=logger
logger=logger,
)
result = trainer.fit(model)

@@ -40,7 +40,7 @@ def test_trains_pickle(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger
logger=logger,
)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)