Skip to content

Commit

Permalink
tmpdir
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jun 28, 2020
1 parent df283e6 commit 677f70c
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 8 deletions.
84 changes: 84 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def on_test_end(self, trainer, pl_module):
test_callback = TestCallback()

trainer_options = dict(
default_root_dir=tmpdir,
callbacks=[test_callback],
max_epochs=1,
limit_val_batches=0.1,
Expand Down Expand Up @@ -261,3 +262,86 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_validation_end_called
assert not test_callback.on_validation_batch_end_called
assert not test_callback.on_validation_batch_start_called


def test_early_stopping_no_val_step(tmpdir):
    """Early stopping should fall back to training metrics when the model defines no validation loop."""

    class TrainMetricModel(EvalModelTemplate):
        def training_step(self, *args, **kwargs):
            out = super().training_step(*args, **kwargs)
            # surface an extra metric for the callback to monitor (any value works)
            out['my_train_metric'] = out['loss']
            return out

    model = TrainMetricModel()
    # strip the validation loop entirely so only training metrics exist
    model.validation_step = None
    model.val_dataloader = None

    early_stop = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=early_stop,
        overfit_batches=0.20,
        max_epochs=2,
    )
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'
    assert trainer.current_epoch <= trainer.max_epochs


def test_pickling(tmpdir):
    """Both callbacks must survive a pickle round-trip with their state intact."""
    import pickle

    for callback in (EarlyStopping(), ModelCheckpoint(tmpdir)):
        restored = pickle.loads(pickle.dumps(callback))
        # the reconstructed callback carries exactly the same attributes
        assert vars(callback) == vars(restored)


@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """A ``None`` filepath in the checkpoint callback is valid; the ckpt path is then derived."""
    tutils.reset_seed()
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=ModelCheckpoint(filepath=None, save_top_k=save_top_k),
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer.fit(model)

    # with the dirpath overridden, the derived ckpt path must differ from the root dir
    assert trainer.ckpt_path != trainer.default_root_dir


@pytest.mark.parametrize(
    'logger_version,expected',
    [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """The "version_" prefix is only added when the logger's version is an integer."""
    tutils.reset_seed()
    model = EvalModelTemplate()
    logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_batches=0.2,
        max_epochs=2,
        logger=logger,
    )
    trainer.fit(model)

    # the ckpt path's parent directory name encodes the logger version
    assert Path(trainer.ckpt_path).parent.name == expected
3 changes: 2 additions & 1 deletion tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def on_train_start(self, trainer, pl_module):
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
early_stop_callback=early_stop_callback,
max_epochs=4
max_epochs=4,
)
trainer.fit(model)
early_stop_callback_state = early_stop_callback.state_dict()
Expand All @@ -44,6 +44,7 @@ def on_train_start(self, trainer, pl_module):
# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
new_trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
resume_from_checkpoint=checkpoint_filepath,
early_stop_callback=early_stop_callback,
Expand Down
8 changes: 4 additions & 4 deletions tests/callbacks/test_lr_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_lr_logger_single_lr(tmpdir):
max_epochs=2,
limit_val_batches=0.1,
limit_train_batches=0.5,
callbacks=[lr_logger]
callbacks=[lr_logger],
)
result = trainer.fit(model)
assert result
Expand All @@ -42,7 +42,7 @@ def test_lr_logger_no_lr(tmpdir):
max_epochs=2,
limit_val_batches=0.1,
limit_train_batches=0.5,
callbacks=[lr_logger]
callbacks=[lr_logger],
)

with pytest.warns(RuntimeWarning):
Expand All @@ -63,7 +63,7 @@ def test_lr_logger_multi_lrs(tmpdir):
max_epochs=2,
limit_val_batches=0.1,
limit_train_batches=0.5,
callbacks=[lr_logger]
callbacks=[lr_logger],
)
result = trainer.fit(model)
assert result
Expand All @@ -90,7 +90,7 @@ def test_lr_logger_param_groups(tmpdir):
max_epochs=2,
limit_val_batches=0.1,
limit_train_batches=0.5,
callbacks=[lr_logger]
callbacks=[lr_logger],
)
result = trainer.fit(model)
assert result
Expand Down
4 changes: 2 additions & 2 deletions tests/callbacks/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
default_root_dir=tmpdir,
checkpoint_callback=checkpoint,
overfit_pct=0.20,
max_epochs=5
max_epochs=5,
)
trainer.fit(model)

Expand All @@ -44,7 +44,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
default_root_dir=tmpdir,
overfit_pct=0.2,
max_epochs=5,
logger=logger
logger=logger,
)
trainer.fit(model)

Expand Down
3 changes: 2 additions & 1 deletion tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ def test_progress_bar_totals(tmpdir):
assert bar.test_batch_idx == k


def test_progress_bar_fast_dev_run():
def test_progress_bar_fast_dev_run(tmpdir):
model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)

Expand Down

0 comments on commit 677f70c

Please sign in to comment.