Commit c8f2c83

fix #2386
Borda committed Jun 28, 2020
1 parent 66ffbad commit c8f2c83
Showing 2 changed files with 50 additions and 3 deletions.
7 changes: 5 additions & 2 deletions pytorch_lightning/core/saving.py
@@ -171,6 +171,7 @@ def load_from_checkpoint(

     @classmethod
     def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
+        cls_spec = inspect.getfullargspec(cls.__init__)
         # pass in the values we saved automatically
         if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
             model_args = {}
@@ -184,7 +185,6 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
                 model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)

             args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
-            cls_spec = inspect.getfullargspec(cls.__init__)
             kwargs_identifier = cls_spec.varkw
             cls_init_args_name = inspect.signature(cls).parameters.keys()

@@ -199,8 +199,11 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
             else:
                 cls_args = (model_args,) + cls_args

-        # load the state_dict on the model automatically
+        # prevent passing positional arguments if class does not accept any
+        if len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs:
+            cls_args, cls_kwargs = [], {}
         model = cls(*cls_args, **cls_kwargs)
+        # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])

         # give model a chance to load something
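
For context, a minimal sketch (not part of this commit) of why the new guard works. It assumes a hypothetical class, NoArgsModel, whose __init__ accepts nothing beyond self; for such a class inspect.getfullargspec reports only ['self'] in args and no keyword-only arguments, which is exactly the condition the added check uses before dropping any restored constructor arguments.

import inspect

class NoArgsModel:
    """Hypothetical stand-in for a LightningModule whose __init__ takes no arguments."""
    def __init__(self):
        self.layer = None

spec = inspect.getfullargspec(NoArgsModel.__init__)
# spec.args == ['self'] and spec.kwonlyargs == [], so the guard added in this
# commit (`len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs`) evaluates to
# True and the checkpointed hparams are not passed to the constructor.
assert len(spec.args) <= 1 and not spec.kwonlyargs
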
46 changes: 45 additions & 1 deletion tests/models/test_hparams.py
@@ -6,11 +6,13 @@
 import pytest
 import torch
 from omegaconf import OmegaConf, Container
+from torch.nn import functional as F
+from torch.utils.data import DataLoader

 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
 from pytorch_lightning.utilities import AttributeDict
-from tests.base import EvalModelTemplate
+from tests.base import EvalModelTemplate, TrialMNIST


 class SaveHparamsModel(EvalModelTemplate):
@@ -238,13 +240,19 @@ def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')),
         self.save_hyperparameters()


+class NoArgsSubClassEvalModel(EvalModelTemplate):
+    def __init__(self):
+        super().__init__()
+
+
 @pytest.mark.parametrize("cls", [
     EvalModelTemplate,
     SubClassEvalModel,
     SubSubClassEvalModel,
     AggSubClassEvalModel,
     UnconventionalArgsEvalModel,
     DictConfSubClassEvalModel,
+    # NoArgsSubClassEvalModel,
 ])
 def test_collect_init_arguments(tmpdir, cls):
     """ Test that the model automatically saves the arguments passed into the constructor """
@@ -425,3 +433,39 @@ def test_hparams_save_yaml(tmpdir):

     save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
     assert load_hparams_from_yaml(path_yaml) == hparams
+
+
+def test_nohparams_train_test(tmpdir):
+
+    class SimpleNoArgsModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.l1 = torch.nn.Linear(28 * 28, 10)
+
+        def forward(self, x):
+            return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+        def training_step(self, batch, batch_nb):
+            x, y = batch
+            loss = F.cross_entropy(self(x), y)
+            return {'loss': loss, 'log': {'train_loss': loss}}
+
+        def test_step(self, batch, batch_nb):
+            x, y = batch
+            loss = F.cross_entropy(self(x), y)
+            return {'loss': loss, 'log': {'train_loss': loss}}
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(self.parameters(), lr=0.02)
+
+    model = SimpleNoArgsModel()
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+    )
+
+    train_loader = DataLoader(TrialMNIST(os.getcwd(), train=True, download=True), batch_size=32)
+    trainer.fit(model, train_loader)
+
+    test_loader = DataLoader(TrialMNIST(os.getcwd(), train=False, download=True), batch_size=32)
+    trainer.test(test_dataloaders=test_loader)
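
As a hedged illustration of what the fix enables (not part of this commit), a module like SimpleNoArgsModel above can also be restored with load_from_checkpoint without Lightning trying to feed saved hyperparameters into its argument-less constructor; the checkpoint path below is a placeholder.

# Usage sketch with a placeholder path; assumes a checkpoint previously saved
# by the Trainer for SimpleNoArgsModel.
restored = SimpleNoArgsModel.load_from_checkpoint('example.ckpt')
restored.eval()  # weights from the checkpoint's state_dict are loaded despite the no-arg __init__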
