Tests: refactor models (#1691)
* refactor default model

* drop redundant seeds

* drop redundant seeds

* refactor models tests

* refactor models tests

* imports

* fix conf

* Apply suggestions from code review
Borda committed May 4, 2020
1 parent d28b145 commit 1077159
Showing 8 changed files with 67 additions and 86 deletions.
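Every file below applies the same refactor: ad-hoc test models composed from mixins (TestModelBase, LightTrainDataloader, LightTestMixin, ...) give way to the shared EvalModelTemplate. A minimal sketch of the before/after pattern, assuming the repo's tests.base helpers are importable:

    import tests.base.utils as tutils
    from tests.base import EvalModelTemplate

    # before: each test assembled its own model from mixins
    # class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
    #     pass
    # model = CurrentTestModel(tutils.get_default_hparams())

    # after: one shared template covers the common case ...
    model = EvalModelTemplate(tutils.get_default_hparams())

    # ... and tests that need custom behavior subclass the template instead
    class BpttTestModel(EvalModelTemplate):
        def __init__(self, hparams):
            super().__init__(hparams)
            self.test_hidden = None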
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1535,8 +1535,8 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh
             hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
         else:
             rank_zero_warn(
-                f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
-                f"contains argument 'hparams'. Will pass in an empty Namespace instead."
+                f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
+                " contains argument 'hparams'. Will pass in an empty Namespace instead."
                 " Did you forget to store your model hyperparameters in self.hparams?"
             )
             hparams = Namespace()
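Both the old and new warning text rely on Python's implicit concatenation of adjacent string literals; the rewrite moves each separating space from the end of one literal to the start of the next, where it is visible rather than trailing. A standalone illustration (not library code):

    message = (
        "first part"      # no trailing space to overlook here
        " second part"    # the separator is explicit at the line start
    )
    assert message == "first part second part"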
3 changes: 1 addition & 2 deletions tests/callbacks/test_callbacks.py
@@ -249,8 +249,7 @@ def test_pickling(tmpdir):

 @pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
 def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
-    """ Test that None in checkpoint callback is valid and that chkp_path is
-    set correctly """
+    """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """
     tutils.reset_seed()

     class CurrentTestModel(LightTrainDataloader, TestModelBase):
6 changes: 2 additions & 4 deletions tests/models/test_amp.py
@@ -6,7 +6,7 @@
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import LightningTestModel, EvalModelTemplate
+from tests.base import EvalModelTemplate


 @pytest.mark.spawn
@@ -15,7 +15,6 @@
 def test_amp_single_gpu(tmpdir, backend):
     """Make sure DP/DDP + AMP work."""
-    tutils.reset_seed()

     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
@@ -63,8 +62,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     tutils.set_random_master_port()
     os.environ['SLURM_LOCALID'] = str(0)

-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())

     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
26 changes: 6 additions & 20 deletions tests/models/test_cpu.py
@@ -7,16 +7,8 @@

 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import (
-    EarlyStopping,
-)
-from tests.base import (
-    TestModelBase,
-    LightTrainDataloader,
-    LightningTestModel,
-    LightTestMixin,
-    EvalModelTemplate,
-)
+from pytorch_lightning.callbacks import EarlyStopping
+from tests.base import EvalModelTemplate


 def test_early_stopping_cpu_model(tmpdir):
@@ -106,8 +98,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):

 def test_running_test_after_fitting(tmpdir):
     """Verify test() on fitted model."""
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())

     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -138,11 +129,7 @@ def test_running_test_after_fitting(tmpdir):

 def test_running_test_no_val(tmpdir):
     """Verify `test()` works on a model with no `val_loader`."""
-    class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())

     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -220,8 +207,7 @@ def test_single_gpu_batch_parse():

 def test_simple_cpu(tmpdir):
     """Verify continue training session on CPU."""
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())

     # fit model
     trainer = Trainer(
@@ -285,7 +271,7 @@ def __getitem__(self, i):
         def __len__(self):
             return 1

-    class BpttTestModel(LightTrainDataloader, TestModelBase):
+    class BpttTestModel(EvalModelTemplate):
         def __init__(self, hparams):
             super().__init__(hparams)
             self.test_hidden = None
6 changes: 3 additions & 3 deletions tests/models/test_gpu.py
@@ -9,7 +9,7 @@
 from pytorch_lightning.core import memory
 from pytorch_lightning.trainer.distrib_parts import parse_gpu_ids, determine_root_gpu_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import LightningTestModel, EvalModelTemplate
+from tests.base import EvalModelTemplate

 PRETEND_N_OF_GPUS = 16

@@ -65,7 +65,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
 def test_cpu_slurm_save_load(tmpdir):
     """Verify model save/load/checkpoint on CPU."""
     hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)

     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -112,7 +112,7 @@ def test_cpu_slurm_save_load(tmpdir):
         logger=logger,
         checkpoint_callback=ModelCheckpoint(tmpdir),
     )
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)

     # set the epoch start hook so we can predict before the model does the full training
     def assert_pred_same():
16 changes: 3 additions & 13 deletions tests/models/test_hooks.py
@@ -2,29 +2,19 @@

 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
-from tests.base import (
-    LightTrainDataloader,
-    LightValidationMixin,
-    TestModelBase,
-    LightTestMixin)
+from tests.base import EvalModelTemplate


 @pytest.mark.parametrize('max_steps', [1, 2, 3])
 def test_on_before_zero_grad_called(max_steps):

-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValidationMixin,
-        LightTestMixin,
-        TestModelBase,
-    ):
+    class CurrentTestModel(EvalModelTemplate):
         on_before_zero_grad_called = 0

         def on_before_zero_grad(self, optimizer):
             self.on_before_zero_grad_called += 1

-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = CurrentTestModel(tutils.get_default_hparams())

     trainer = Trainer(
         max_steps=max_steps,
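The viewer cuts this test off after the Trainer arguments, but its shape is clear from what is shown: the hook increments a counter while the trainer runs exactly max_steps optimizer steps. A hedged sketch of how the tail of such a test typically ends (the actual assertion is not shown in this diff and is an assumption):

    trainer = Trainer(max_steps=max_steps)
    trainer.fit(model)
    # zero_grad is invoked once per optimizer step, so the hook fires max_steps times
    assert model.on_before_zero_grad_called == max_steps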
5 changes: 3 additions & 2 deletions tests/models/test_horovod.py
@@ -11,7 +11,7 @@
 from pytorch_lightning import Trainer

 import tests.base.utils as tutils
-from tests.base import LightningTestModel
+from tests.base import EvalModelTemplate
 from tests.base.models import TestGAN

 try:
@@ -107,7 +107,8 @@ def test_horovod_multi_gpu(tmpdir):
 @pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_horovod_transfer_batch_to_gpu(tmpdir):
-    class TestTrainingStepModel(LightningTestModel):
+
+    class TestTrainingStepModel(EvalModelTemplate):
         def training_step(self, batch, *args, **kwargs):
             x, y = batch
             assert str(x.device) != 'cpu'
87 changes: 47 additions & 40 deletions tests/models/test_restore.py
@@ -9,11 +9,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import (
-    LightningTestModel,
-    LightningTestModelWithoutHyperparametersArg,
-    LightningTestModelWithUnusedHyperparametersArg
-)
+from tests.base import EvalModelTemplate


 @pytest.mark.spawn
@@ -23,8 +19,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
     """Verify `test()` on pretrained model."""
     tutils.set_random_master_port()

-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())

     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -53,7 +48,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
     assert result == 1, 'training failed to complete'
     pretrained_model = tutils.load_model(logger,
                                          trainer.checkpoint_callback.dirpath,
-                                         module_class=LightningTestModel)
+                                         module_class=EvalModelTemplate)

     # run test set
     new_trainer = Trainer(**trainer_options)
@@ -72,8 +67,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):

 def test_running_test_pretrained_model_cpu(tmpdir):
     """Verify test() on pretrained model."""
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())

     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -97,7 +91,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
     # correct result and ok accuracy
     assert result == 1, 'training failed to complete'
     pretrained_model = tutils.load_model(
-        logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
+        logger, trainer.checkpoint_callback.dirpath, module_class=EvalModelTemplate
     )

     new_trainer = Trainer(**trainer_options)
@@ -110,7 +104,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
 def test_load_model_from_checkpoint(tmpdir):
     """Verify test() on pretrained model."""
     hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)

     trainer_options = dict(
         progress_bar_refresh_rate=0,
@@ -131,7 +125,7 @@

     # load last checkpoint
     last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
-    pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)
+    pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint)

     # test that hparams loaded correctly
     for k, v in vars(hparams).items():
@@ -152,7 +146,13 @@ def test_load_model_from_checkpoint(tmpdir):
 def test_dp_resume(tmpdir):
     """Make sure DP continues training correctly."""
     hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)
+
+    trainer_options = dict(
+        max_epochs=1,
+        gpus=2,
+        distributed_backend='dp',
+    )

     # get logger
     logger = tutils.get_default_logger(tmpdir)
@@ -161,13 +161,9 @@
     # logger file to get weights
     checkpoint = tutils.init_checkpoint_callback(logger)

-    trainer_options = dict(
-        max_epochs=1,
-        gpus=2,
-        distributed_backend='dp',
-        logger=logger,
-        checkpoint_callback=checkpoint,
-    )
+    # add these to the trainer options
+    trainer_options['logger'] = logger
+    trainer_options['checkpoint_callback'] = checkpoint

     # fit model
     trainer = Trainer(**trainer_options)
@@ -188,13 +184,11 @@

     # init new trainer
     new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
-    trainer_options.update(
-        logger=new_logger,
-        checkpoint_callback=ModelCheckpoint(tmpdir),
-        train_percent_check=0.5,
-        val_percent_check=0.2,
-        max_epochs=1,
-    )
+    trainer_options['logger'] = new_logger
+    trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
+    trainer_options['train_percent_check'] = 0.5
+    trainer_options['val_percent_check'] = 0.2
+    trainer_options['max_epochs'] = 1
     new_trainer = Trainer(**trainer_options)

     # set the epoch start hook so we can predict before the model does the full training
@@ -210,7 +204,7 @@ def assert_good_acc():
         tutils.run_prediction(dataloader, dp_model, dp=True)

     # new model
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)
     model.on_train_start = assert_good_acc

     # fit new model which should load hpc weights
@@ -223,18 +217,19 @@ def assert_good_acc():

 def test_model_saving_loading(tmpdir):
     """Tests use case where trainer saves the model, and user loads it from tags independently."""
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())

     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)

-    # fit model
-    trainer = Trainer(
+    trainer_options = dict(
         max_epochs=1,
         logger=logger,
         checkpoint_callback=ModelCheckpoint(tmpdir)
     )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)

     # traning complete
@@ -263,7 +258,7 @@ def test_model_saving_loading(tmpdir):
     # load new model
     tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
     tags_path = os.path.join(tags_path, 'meta_tags.csv')
-    model_2 = LightningTestModel.load_from_checkpoint(
+    model_2 = EvalModelTemplate.load_from_checkpoint(
         checkpoint_path=new_weights_path,
         tags_csv=tags_path
     )
@@ -276,31 +271,43 @@


 def test_load_model_with_missing_hparams(tmpdir):
-    # fit model
-    trainer = Trainer(
+    trainer_options = dict(
         progress_bar_refresh_rate=0,
         max_epochs=1,
         checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
         logger=False,
         default_root_dir=tmpdir,
     )

-    model = LightningTestModelWithoutHyperparametersArg()
+    # fit model
+    trainer = Trainer(**trainer_options)
+
+    class CurrentModelWithoutHparams(EvalModelTemplate):
+        def __init__(self):
+            hparams = tutils.get_default_hparams()
+            super().__init__(hparams)
+
+    class CurrentModelUnusedHparams(EvalModelTemplate):
+        def __init__(self, hparams):
+            hparams = tutils.get_default_hparams()
+            super().__init__(hparams)
+
+    model = CurrentModelWithoutHparams()
     trainer.fit(model)
     last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]

     # try to load a checkpoint that has hparams but model is missing hparams arg
     with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"):
-        LightningTestModelWithoutHyperparametersArg.load_from_checkpoint(last_checkpoint)
+        CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)

     # create a checkpoint without hyperparameters
     # if the model does not take a hparams argument, it should not throw an error
     ckpt = torch.load(last_checkpoint)
     del(ckpt['hparams'])
     torch.save(ckpt, last_checkpoint)
-    LightningTestModelWithoutHyperparametersArg.load_from_checkpoint(last_checkpoint)
+    CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)

     # load checkpoint without hparams again
     # warn if user's model has hparams argument
     with pytest.warns(UserWarning, match=r".*Will pass in an empty Namespace instead."):
-        LightningTestModelWithUnusedHyperparametersArg.load_from_checkpoint(last_checkpoint)
+        CurrentModelUnusedHparams.load_from_checkpoint(last_checkpoint)
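Together with the lightning.py change above, this rewritten test pins down the hparams contract at load time: a checkpoint that stores hparams raises MisconfigurationException when the model's __init__ cannot accept them; a checkpoint without hparams loads cleanly into such a model; and a model that does accept hparams gets a UserWarning plus an empty Namespace. A simplified paraphrase of the warning branch from the hunk at the top of this commit (not the exact implementation; init_takes_hparams is a stand-in name for illustration):

    ckpt_hparams = checkpoint.get('hparams')
    if ckpt_hparams is not None:
        # restore saved hyperparameters
        hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
    elif init_takes_hparams:
        # model asked for hparams, but the checkpoint has none
        rank_zero_warn(
            f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
            " contains argument 'hparams'. Will pass in an empty Namespace instead."
            " Did you forget to store your model hyperparameters in self.hparams?"
        )
        hparams = Namespace()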
