Skip to content

Commit

Permalink
fix loading with hparams (#2403)
Browse files Browse the repository at this point in the history
* fix #2386

* extra test

* extra case

* extra test

* chlog

* fix test
  • Loading branch information
Borda committed Jun 29, 2020
1 parent 058c500 commit 1e16681
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))

- Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403))

## [0.8.1] - 2020-06-19

### Fixed
Expand Down
16 changes: 10 additions & 6 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def load_from_checkpoint(

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
cls_spec = inspect.getfullargspec(cls.__init__)
cls_init_args_name = inspect.signature(cls).parameters.keys()
# pass in the values we saved automatically
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
model_args = {}
Expand All @@ -183,23 +185,25 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
model_args = _convert_loaded_hparams(model_args, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))

args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
cls_spec = inspect.getfullargspec(cls.__init__)
kwargs_identifier = cls_spec.varkw
cls_init_args_name = inspect.signature(cls).parameters.keys()

if args_name == 'kwargs':
# in case the class cannot take any extra argument filter only the possible
if not kwargs_identifier:
model_args = {k: v for k, v in model_args.items() if k in cls_init_args_name}
cls_kwargs.update(**model_args)
elif args_name:
if args_name in cls_init_args_name:
cls_kwargs.update({args_name: model_args})
else:
cls_args = (model_args,) + cls_args

# load the state_dict on the model automatically
if not cls_spec.varkw:
# filter kwargs according to class init unless it allows any argument via kwargs
cls_kwargs = {k: v for k, v in cls_kwargs.items() if k in cls_init_args_name}

# prevent passing positional arguments if class does not accept any
if len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs:
cls_args, cls_kwargs = [], {}
model = cls(*cls_args, **cls_kwargs)
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])

# give model a chance to load something
Expand Down
94 changes: 82 additions & 12 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import pytest
import torch
from omegaconf import OmegaConf, Container
from torch.nn import functional as F
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning.utilities import AttributeDict
from tests.base import EvalModelTemplate
from tests.base import EvalModelTemplate, TrialMNIST


class SaveHparamsModel(EvalModelTemplate):
Expand Down Expand Up @@ -103,16 +105,16 @@ def test_explicit_args_hparams(tmpdir):
"""

# define model
class TestModel(EvalModelTemplate):
class LocalModel(EvalModelTemplate):
def __init__(self, test_arg, test_arg2):
super().__init__()
self.save_hyperparameters('test_arg', 'test_arg2')

model = TestModel(test_arg=14, test_arg2=90)
model = LocalModel(test_arg=14, test_arg2=90)

# run standard test suite
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, TestModel)
model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)

# config specific tests
assert model.hparams.test_arg2 == 120
Expand All @@ -124,16 +126,16 @@ def test_implicit_args_hparams(tmpdir):
"""

# define model
class TestModel(EvalModelTemplate):
class LocalModel(EvalModelTemplate):
def __init__(self, test_arg, test_arg2):
super().__init__()
self.save_hyperparameters()

model = TestModel(test_arg=14, test_arg2=90)
model = LocalModel(test_arg=14, test_arg2=90)

# run standard test suite
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, TestModel)
model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)

# config specific tests
assert model.hparams.test_arg2 == 120
Expand All @@ -145,12 +147,12 @@ def test_explicit_missing_args_hparams(tmpdir):
"""

# define model
class TestModel(EvalModelTemplate):
class LocalModel(EvalModelTemplate):
def __init__(self, test_arg, test_arg2):
super().__init__()
self.save_hyperparameters('test_arg')

model = TestModel(test_arg=14, test_arg2=90)
model = LocalModel(test_arg=14, test_arg2=90)

# test proper property assignments
assert model.hparams.test_arg == 14
Expand All @@ -166,7 +168,7 @@ def __init__(self, test_arg, test_arg2):
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14

# verify that model loads correctly
model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
assert model.hparams.test_arg == 14
assert 'test_arg2' not in model.hparams # test_arg2 is not registered in class init

Expand Down Expand Up @@ -427,3 +429,71 @@ def test_hparams_save_yaml(tmpdir):

save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams


class NoArgsSubClassEvalModel(EvalModelTemplate):
def __init__(self):
super().__init__()


class SimpleNoArgsModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)

def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))

def training_step(self, batch, batch_nb):
x, y = batch
loss = F.cross_entropy(self(x), y)
return {'loss': loss, 'log': {'train_loss': loss}}

def test_step(self, batch, batch_nb):
x, y = batch
loss = F.cross_entropy(self(x), y)
return {'loss': loss, 'log': {'train_loss': loss}}

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)


@pytest.mark.parametrize("cls", [
SimpleNoArgsModel,
NoArgsSubClassEvalModel,
])
def test_model_nohparams_train_test(tmpdir, cls):
"""Test models that do not tae any argument in init."""

model = cls()
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
)

train_loader = DataLoader(TrialMNIST(os.getcwd(), train=True, download=True), batch_size=32)
trainer.fit(model, train_loader)

test_loader = DataLoader(TrialMNIST(os.getcwd(), train=False, download=True), batch_size=32)
trainer.test(test_dataloaders=test_loader)


def test_model_ignores_non_exist_kwargument(tmpdir):
"""Test that the model takes only valid class arguments."""

class LocalModel(EvalModelTemplate):
def __init__(self, batch_size=15):
super().__init__(batch_size=batch_size)
self.save_hyperparameters()

model = LocalModel()
assert model.hparams.batch_size == 15

# verify that the checkpoint saved the correct values
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
trainer.fit(model)

# verify that we can overwrite whatever we want
raw_checkpoint_path = _raw_checkpoint_path(trainer)
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, non_exist_kwarg=99)
assert 'non_exist_kwarg' not in model.hparams

0 comments on commit 1e16681

Please sign in to comment.