An Extra argument passed to the class, loaded from load_from_checkpoint. #2386

Closed · nischal-sanil opened this issue on Jun 27, 2020 · 1 comment · Fixed by #2403
Labels: bug (Something isn't working) · help wanted (Open to be worked on)

nischal-sanil commented on Jun 27, 2020

🐛 Bug

Hello,
While using the trainer.test() function I ran into a few issues. On debugging, I found that the problem lies in the _load_model_state class method, which is called by load_from_checkpoint.

Code for reference:

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
    # pass in the values we saved automatically
    if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
        model_args = {}

        # add some back compatibility, the actual one shall be last
        for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
            if hparam_key in checkpoint:
                model_args.update(checkpoint[hparam_key])

        if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
            model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)

        args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
        init_args_name = inspect.signature(cls).parameters.keys()

        if args_name == 'kwargs':
            cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
            kwargs.update(**cls_kwargs)
        elif args_name:
            if args_name in init_args_name:
                kwargs.update({args_name: model_args})
        else:
            args = (model_args, ) + args

    # load the state_dict on the model automatically
    model = cls(*args, **kwargs)
    model.load_state_dict(checkpoint['state_dict'])

    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)

    return model
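
As a side note, here is how the inspect.signature(cls) call above behaves; Demo is just a hypothetical class, not part of Lightning:

import inspect

class Demo:
    def __init__(self, hparams, lr=0.02, **kwargs):
        pass

# the parameter names that the hparam keys are matched against
print(list(inspect.signature(Demo).parameters.keys()))
# prints: ['hparams', 'lr', 'kwargs']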

Consider the case where the model's __init__ takes no arguments, which corresponds to LightModel.load_from_checkpoint('path'). Here the else clause of the if-elif is executed, and the args variable is updated from an empty tuple to a tuple containing an empty dictionary, args = (model_args, ) + args (since model_args == {}). Therefore, while unpacking the args and kwargs (model = cls(*args, **kwargs)), an extra positional argument is passed to __init__, which raises TypeError: __init__() takes 1 positional argument but 2 were given. #2364
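
A minimal standalone sketch of this failure mode (NoArgModel is a hypothetical stand-in for a LightningModule whose __init__ takes no arguments):

class NoArgModel:
    def __init__(self):
        pass

model_args = {}            # the hparams dict recovered from the checkpoint
args = (model_args,) + ()  # the else branch prepends the empty dict

NoArgModel(*args)  # TypeError: __init__() takes 1 positional argument but 2 were given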

In other cases, if the model's __init__ takes an argument and the user has forgotten to pass it to load_from_checkpoint, the empty dictionary is passed in its place and raises other errors depending on the model code. For example, in issue #2359 an empty dict is passed while loading the model, which raises RuntimeError: Error(s) in loading state_dict for Model:.
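
Here is a hedged sketch of that second failure mode; DictModel is hypothetical and reads its layer sizes out of an hparams dict with defaults, so the empty dict constructs silently but with the wrong shapes:

import torch

class DictModel(torch.nn.Module):
    def __init__(self, hparams):
        super().__init__()
        self.l1 = torch.nn.Linear(hparams.get('in_dim', 784),
                                  hparams.get('out_dim', 10))

trained = DictModel({'in_dim': 784, 'out_dim': 20})
state = trained.state_dict()

reloaded = DictModel({})         # the empty dict slips through silently
reloaded.load_state_dict(state)  # RuntimeError: Error(s) in loading state_dict for DictModel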

I do not fully understand everything that is happening in this function, so it would be great if someone could suggest the changes to make in the comments; I can then start working on them in my forked repo.
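
For what it is worth, one direction I could imagine (an idea sketch only, not necessarily what the eventual fix in #2403 does) is to guard the else branch so an empty hparams dict is never injected as a positional argument:

def _prepend_model_args(model_args, args):
    # only prepend when there is actually something to pass
    if model_args:
        return (model_args,) + args
    return args

assert _prepend_model_args({}, ()) == ()                         # no spurious argument
assert _prepend_model_args({'lr': 0.02}, ()) == ({'lr': 0.02},)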

Steps to reproduce

!pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

class MNISTModel(pl.LightningModule):

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

mnist_model = MNISTModel()
trainer = pl.Trainer(gpus=1, max_epochs=3)
trainer.fit(mnist_model, train_loader)

test_loader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
trainer.test(test_dataloaders=test_loader)

This returns:

TypeError                                 Traceback (most recent call last)

<ipython-input-5-50449ee4f6cc> in <module>()
      1 test_loader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
----> 2 trainer.test(test_dataloaders=test_loader)

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py in test(self, model, test_dataloaders, ckpt_path)
   1168             if ckpt_path == 'best':
   1169                 ckpt_path = self.checkpoint_callback.best_model_path
-> 1170             model = self.get_model().load_from_checkpoint(ckpt_path)
   1171 
   1172         self.testing = True

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)
    167         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    168 
--> 169         model = cls._load_model_state(checkpoint, *args, **kwargs)
    170         return model
    171 

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, *cls_args, **cls_kwargs)
    201 
    202         # load the state_dict on the model automatically
--> 203         model = cls(*cls_args, **cls_kwargs)
    204         model.load_state_dict(checkpoint['state_dict'])
    205 

TypeError: __init__() takes 1 positional argument but 2 were given

Expected behavior

Testing starts without errors.

@github-actions (Contributor) commented:

Hi! thanks for your contribution!, great first issue!

Borda mentioned this issue and added commits referencing it on Jun 28, 2020. williamFalcon pushed a commit that referenced this issue on Jun 29, 2020:

* fix #2386
* extra test
* extra case
* extra test
* chlog
* fix test