Error when loading with LitModel.load_from_checkpoint(path)? #2364

Closed
MicPie opened this issue Jun 25, 2020 · 6 comments · Fixed by #2403
Labels: duplicate (This issue or pull request already exists), question (Further information is requested)

Comments

MicPie commented Jun 25, 2020

❓ Questions and Help

Before asking:

  1. search the issues. Done
  2. search the docs. Done

What is your question?

When I'm loading a Lightning model with LitModel.load_from_checkpoint(path), I always get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-11-5cfc485836df> in <module>
      1 # OPTION 2:
      2 # test after loading weights
----> 3 model = LitModel.load_from_checkpoint("lightning_logs/version_0/checkpoints/epoch=2.ckpt")

~/anaconda3/envs/pytorch-lightning/lib/python3.7/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)
    169         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    170 
--> 171         model = cls._load_model_state(checkpoint, *args, **kwargs)
    172         return model
    173 

~/anaconda3/envs/pytorch-lightning/lib/python3.7/site-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, *args, **kwargs)
    195 
    196         # load the state_dict on the model automatically
--> 197         model = cls(*args, **kwargs)
    198         model.load_state_dict(checkpoint['state_dict'])
    199 

TypeError: __init__() takes 1 positional argument but 2 were given
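For context, the failing call reduces to plain Python: a class whose __init__ accepts only self is being handed one extra positional argument. A minimal, Lightning-free sketch of the same error (the class and argument here are hypothetical):

class Model:
    def __init__(self):  # accepts only self, i.e. one positional argument
        pass

Model("hparams")  # TypeError: __init__() takes 1 positional argument but 2 were given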

Code

The simple example is from https://pytorch-lightning.readthedocs.io/en/latest/new-project.html.
My custom LightningModule is very similar, but I am using different datasets and loaders and a torchvision ResNet.

What have you tried?

I tried different checkpoints as well as relative and absolute paths.
The checkpoints are loaded from lightning_logs/version_0/checkpoints/, where they were saved automatically after stopping training.

Setup and training themselves work without problems.

I also get the same error with the more complex setup.

Maybe I'm overlooking something obvious?

Thank you very much & kind regards

What's your environment?

  • OS: Ubuntu 18.04
  • Packaging: Conda
  • Version:
    pytorch_select 0.2
    pytorch 1.3.1
    pytorch-lightning 0.8.0
MicPie added the question (Further information is requested) label Jun 25, 2020
github-actions (Contributor) commented Jun 25, 2020

Hi! Thanks for your contribution, great first issue!

MicPie changed the title from "Error when loading Lightning model with LitModel.load_from_checkpoint(path)?" to "Error when loading with LitModel.load_from_checkpoint(path)?" Jun 25, 2020
dscarmo (Contributor) commented Jun 25, 2020

Might be related to #2334.
If you put an hparams argument in your LitModel, does it work? What about using 0.8.1?

MicPie (Author) commented Jun 25, 2020

Thank you for your reply!

This code snippet gives me the same error:

model = LitModel.load_from_checkpoint(checkpoint_path="lightning_logs/version_0/checkpoints/epoch=2.ckpt",
                                      hparams_file="lightning_logs/version_0/hparams.yaml")

Via conda I get no update, and I also tried to upgrade with pip, but it is still on 0.8.
Do I need a pip editable install from the GitHub repo to get 0.8.1?

dscarmo (Contributor) commented Jun 25, 2020

Sorry, I meant changing the init to (self, hparams); see the sketch below.

To me it looks like load_from_checkpoint is trying to force an hparams argument onto the model (hence the two positional arguments in the error), but this might have been fixed in 0.8.1.

Try pip uninstall pytorch-lightning and/or pip install pytorch-lightning --upgrade.
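A minimal sketch of the suggested change, assuming the 0.8.x convention where hyperparameters are stored on self.hparams; the layer is hypothetical:

import pytorch_lightning as pl
import torch.nn as nn

class LitModel(pl.LightningModule):
    def __init__(self, hparams):  # accept the hparams that load_from_checkpoint passes in
        super().__init__()
        self.hparams = hparams  # 0.8.x style: store the hyperparameters on the module
        self.layer = nn.Linear(28 * 28, 10)  # hypothetical layer for illustration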

MicPie (Author) commented Jun 26, 2020

Thank you, I was able to install version 0.8.1 with this trick.

When I add hparams to def __init__(self, hparams) in my LightningModule, do I have to supply it during object creation too?

When I load the checkpoint with torch.load, everything seems to be there (see the sketch below).
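For reference, a minimal sketch of that inspection, using the checkpoint path from the traceback above:

import torch

ckpt = torch.load("lightning_logs/version_0/checkpoints/epoch=2.ckpt", map_location="cpu")
print(ckpt.keys())                # top-level checkpoint entries, e.g. 'state_dict'
print(ckpt["state_dict"].keys())  # the model weights are all present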

I suspect the args supplied to cls._load_model_state(checkpoint, *args, **kwargs) are the reason this happens (kwargs are empty); I will check later in detail.

Borda (Member) commented Jun 28, 2020

> I suspect the args supplied to cls._load_model_state(checkpoint, *args, **kwargs) are the reason this happens (kwargs are empty); I will check later in detail.

Yes, you need to allow your class to pass arguments through to the parent class (a sketch is below); otherwise, this seems to be a duplicate of #2386.
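A minimal sketch of what Borda describes, assuming the 0.8.x hparams style: __init__ accepts hparams and forwards any remaining arguments to the parent class.

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, hparams, *args, **kwargs):
        super().__init__(*args, **kwargs)  # forward extra arguments to the parent class
        self.hparams = hparams

With that in place, load_from_checkpoint can call cls(hparams, ...) without hitting the TypeError above.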

Borda added the duplicate (This issue or pull request already exists) label Jun 28, 2020
Borda mentioned this issue Jun 28, 2020