fix loading with hparams #2403

Merged 6 commits on Jun 29, 2020
CHANGELOG.md (2 additions, 0 deletions)
@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))

- Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403))

## [0.8.1] - 2020-06-19

### Fixed
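For context, a minimal sketch of the user-facing case the "Fixed loading model without arguments" entry refers to; the class name and checkpoint path below are illustrative, not part of this PR. A `LightningModule` whose `__init__` takes no arguments should be restorable from a checkpoint without passing any hyperparameters, which could fail before this fix.

```python
import torch
from pytorch_lightning import LightningModule


class NoArgsModel(LightningModule):
    """Hypothetical model whose __init__ takes no arguments."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))


# "example.ckpt" stands in for a checkpoint produced by the Trainer for this model
model = NoArgsModel.load_from_checkpoint("example.ckpt")
```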
pytorch_lightning/core/saving.py (10 additions, 6 deletions)
@@ -171,6 +171,8 @@ def load_from_checkpoint(

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
cls_spec = inspect.getfullargspec(cls.__init__)
cls_init_args_name = inspect.signature(cls).parameters.keys()
# pass in the values we saved automatically
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
model_args = {}
@@ -183,23 +185,25 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
model_args = _convert_loaded_hparams(model_args, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))

args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
cls_spec = inspect.getfullargspec(cls.__init__)
kwargs_identifier = cls_spec.varkw
cls_init_args_name = inspect.signature(cls).parameters.keys()

if args_name == 'kwargs':
# in case the class cannot take any extra arguments, filter only the possible ones
if not kwargs_identifier:
model_args = {k: v for k, v in model_args.items() if k in cls_init_args_name}
cls_kwargs.update(**model_args)
elif args_name:
if args_name in cls_init_args_name:
cls_kwargs.update({args_name: model_args})
else:
cls_args = (model_args,) + cls_args

# load the state_dict on the model automatically
if not cls_spec.varkw:
# filter kwargs according to class init unless it allows any argument via kwargs
cls_kwargs = {k: v for k, v in cls_kwargs.items() if k in cls_init_args_name}

# prevent passing positional arguments if class does not accept any
if len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs:
cls_args, cls_kwargs = [], {}
model = cls(*cls_args, **cls_kwargs)
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])

# give model a chance to load something
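The core of the `_load_model_state` change is matching the hyperparameters stored in the checkpoint against the constructor signature before instantiating the class. A standalone sketch of that filtering idea, with hypothetical names, using only the `inspect` calls shown in the diff above:

```python
import inspect


class Plain:
    """Hypothetical class whose __init__ takes one named argument and no **kwargs."""

    def __init__(self, lr=0.1):
        self.lr = lr


def filter_init_kwargs(cls, loaded_hparams):
    """Keep only keys that cls.__init__ accepts, unless it takes arbitrary **kwargs."""
    cls_spec = inspect.getfullargspec(cls.__init__)
    init_arg_names = inspect.signature(cls).parameters.keys()
    if cls_spec.varkw:
        # __init__ declares **kwargs, so every loaded hyperparameter can pass through
        return dict(loaded_hparams)
    # otherwise drop anything the constructor would reject with a TypeError
    return {k: v for k, v in loaded_hparams.items() if k in init_arg_names}


# 'momentum' is dropped because Plain.__init__ does not accept it
assert filter_init_kwargs(Plain, {"lr": 0.01, "momentum": 0.9}) == {"lr": 0.01}
```

The diff additionally guards against positional arguments: if the class `__init__` only takes `self` and declares no keyword-only arguments, both `cls_args` and `cls_kwargs` are cleared before the class is instantiated.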
tests/models/test_hparams.py (82 additions, 12 deletions)
@@ -6,11 +6,13 @@
import pytest
import torch
from omegaconf import OmegaConf, Container
from torch.nn import functional as F
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning.utilities import AttributeDict
from tests.base import EvalModelTemplate
from tests.base import EvalModelTemplate, TrialMNIST


class SaveHparamsModel(EvalModelTemplate):
@@ -103,16 +105,16 @@ def test_explicit_args_hparams(tmpdir):
"""

# define model
class TestModel(EvalModelTemplate):
class LocalModel(EvalModelTemplate):
def __init__(self, test_arg, test_arg2):
super().__init__()
self.save_hyperparameters('test_arg', 'test_arg2')

model = TestModel(test_arg=14, test_arg2=90)
model = LocalModel(test_arg=14, test_arg2=90)

# run standard test suite
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, TestModel)
model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)

# config specific tests
assert model.hparams.test_arg2 == 120
@@ -124,16 +126,16 @@ def test_implicit_args_hparams(tmpdir):
"""

# define model
class TestModel(EvalModelTemplate):
class LocalModel(EvalModelTemplate):
def __init__(self, test_arg, test_arg2):
super().__init__()
self.save_hyperparameters()

model = TestModel(test_arg=14, test_arg2=90)
model = LocalModel(test_arg=14, test_arg2=90)

# run standard test suite
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, TestModel)
model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)

# config specific tests
assert model.hparams.test_arg2 == 120
@@ -145,12 +147,12 @@ def test_explicit_missing_args_hparams(tmpdir):
"""

# define model
class TestModel(EvalModelTemplate):
class LocalModel(EvalModelTemplate):
def __init__(self, test_arg, test_arg2):
super().__init__()
self.save_hyperparameters('test_arg')

model = TestModel(test_arg=14, test_arg2=90)
model = LocalModel(test_arg=14, test_arg2=90)

# test proper property assignments
assert model.hparams.test_arg == 14
@@ -166,7 +168,7 @@ def __init__(self, test_arg, test_arg2):
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14

# verify that model loads correctly
model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
assert model.hparams.test_arg == 14
assert 'test_arg2' not in model.hparams # test_arg2 is not registered in class init

@@ -427,3 +429,71 @@ def test_hparams_save_yaml(tmpdir):

save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams


class NoArgsSubClassEvalModel(EvalModelTemplate):
def __init__(self):
super().__init__()


class SimpleNoArgsModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)

def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))

def training_step(self, batch, batch_nb):
x, y = batch
loss = F.cross_entropy(self(x), y)
return {'loss': loss, 'log': {'train_loss': loss}}

def test_step(self, batch, batch_nb):
x, y = batch
loss = F.cross_entropy(self(x), y)
return {'loss': loss, 'log': {'train_loss': loss}}

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)


@pytest.mark.parametrize("cls", [
SimpleNoArgsModel,
NoArgsSubClassEvalModel,
])
def test_model_nohparams_train_test(tmpdir, cls):
"""Test models that do not tae any argument in init."""

model = cls()
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
)

train_loader = DataLoader(TrialMNIST(os.getcwd(), train=True, download=True), batch_size=32)
trainer.fit(model, train_loader)

test_loader = DataLoader(TrialMNIST(os.getcwd(), train=False, download=True), batch_size=32)
trainer.test(test_dataloaders=test_loader)


def test_model_ignores_non_exist_kwargument(tmpdir):
"""Test that the model takes only valid class arguments."""

class LocalModel(EvalModelTemplate):
def __init__(self, batch_size=15):
super().__init__(batch_size=batch_size)
self.save_hyperparameters()

model = LocalModel()
assert model.hparams.batch_size == 15

# verify that the checkpoint saved the correct values
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
trainer.fit(model)

# verify that we can overwrite whatever we want
raw_checkpoint_path = _raw_checkpoint_path(trainer)
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, non_exist_kwarg=99)
assert 'non_exist_kwarg' not in model.hparams
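As a usage note, the behaviour exercised by the new hparams tests looks roughly like this from the caller's side; the checkpoint path and values below are illustrative.

```python
# batch_size was registered via save_hyperparameters(); it can be overridden at
# load time, while arguments the constructor does not declare are filtered out.
model = LocalModel.load_from_checkpoint(
    "path/to/checkpoint.ckpt",
    batch_size=32,        # overrides the saved value of 15
    non_exist_kwarg=99,   # not an __init__ argument, so it is dropped
)
assert model.hparams.batch_size == 32
assert 'non_exist_kwarg' not in model.hparams
```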