
Customizing hparams after loading checkpoint #1474

Closed
tullie opened this issue Apr 13, 2020 · 1 comment · Fixed by #1797
Labels
question Further information is requested

Comments

@tullie
Contributor

tullie commented Apr 13, 2020

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

I'm wondering what the best practice is for loading a model with different hparams than the ones stored in the checkpoint.

I realize I could just load the model and set them afterwards e.g.:

model = model.load_from_checkpoint(args.checkpoint_file) # Load model

# Set hparams etc..
model.hparams.arg1 = 0.0
model.hparams.arg2 = 1.0

The problem is that my model's __init__ depends on the hparams arg1 and arg2, so setting them after loading is too late.
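To illustrate why overriding after loading is too late, here is a minimal plain-Python sketch. `ToyModel` and its simplified `load_from_checkpoint` are hypothetical stand-ins (not Lightning's API), and the checkpoint is a plain dict so the example runs without torch. Anything `__init__` derives from the hparams is computed once, at construction time.

```python
from argparse import Namespace


class ToyModel:
    """Hypothetical stand-in for a LightningModule whose __init__ reads hparams."""

    def __init__(self, hparams: Namespace):
        self.hparams = hparams
        # Derived state is computed once, when the model is constructed.
        self.scale = hparams.arg1 * hparams.arg2

    @classmethod
    def load_from_checkpoint(cls, checkpoint: dict) -> "ToyModel":
        # Simplified loader: rebuild the model from the stored hparams only.
        return cls(Namespace(**checkpoint["hparams"]))


checkpoint = {"hparams": {"arg1": 2.0, "arg2": 3.0}}
model = ToyModel.load_from_checkpoint(checkpoint)

# Overriding hparams after loading changes the attributes...
model.hparams.arg1 = 0.0
model.hparams.arg2 = 1.0
# ...but state built in __init__ still reflects the stored values.
print(model.hparams.arg1, model.scale)  # 0.0 6.0
```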

Alternatively, I could do:

checkpoint = torch.load(args.checkpoint_file)
checkpoint['hparams']['arg1'] = 0.0
checkpoint['hparams']['arg2'] = 1.0
model = model._load_state_dict(checkpoint)

The problem here is that I'm using the protected function _load_state_dict. Is there another way of solving this that I've missed? Or could we consider making _load_state_dict public?
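One way to get the same effect with only public pieces is to apply the overrides to the stored hparams, construct the model, and then restore the weights. The sketch below is plain Python under stated assumptions: the checkpoint is a dict with `hparams` and `state_dict` keys, and `ToyModel.load_state_dict` is a hypothetical, simplified mirror of `torch.nn.Module.load_state_dict`. `load_with_overrides` is a hypothetical helper, not a Lightning function. With real torch, the dict would come from `torch.load`.

```python
import copy
from argparse import Namespace


class ToyModel:
    """Hypothetical stand-in for a LightningModule; not the real API."""

    def __init__(self, hparams: Namespace):
        self.hparams = hparams
        self.scale = hparams.arg1 * hparams.arg2  # derived at construction time
        self.weights = {}

    def load_state_dict(self, state_dict: dict) -> None:
        # Simplified mirror of torch.nn.Module.load_state_dict.
        self.weights = dict(state_dict)


def load_with_overrides(checkpoint: dict, cls, **overrides):
    """Build `cls` from checkpoint hparams with `overrides` applied first."""
    hparams = copy.deepcopy(checkpoint["hparams"])  # leave the checkpoint untouched
    hparams.update(overrides)
    model = cls(Namespace(**hparams))  # __init__ now sees the new values
    model.load_state_dict(checkpoint["state_dict"])
    return model


checkpoint = {"hparams": {"arg1": 2.0, "arg2": 3.0}, "state_dict": {"w": 1.0}}
model = load_with_overrides(checkpoint, ToyModel, arg1=0.0, arg2=1.0)
print(model.scale, model.weights)  # 0.0 {'w': 1.0}
```

Copying the stored hparams before updating them keeps the loaded checkpoint reusable, for example to load a second model with different overrides.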

@tullie tullie added the question Further information is requested label Apr 13, 2020
@williamFalcon
Contributor

@tullie Good question. Maybe we could have a flag for disabling hparam use?

load_from_checkpoint(PATH, auto_hparam=False, arg1=my_arg, arg2=my_arg2, hparam={...})

load_from_checkpoint already allows passing in the args directly, so without the hparam dict it would look like this:

load_from_checkpoint(PATH, auto_hparam=False, arg1=my_arg, arg2=my_arg2)

So in this case each hparam would be a regular arg, which you could then construct however you want.
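A minimal sketch of how the proposed flag could behave, in plain Python. `auto_hparam` is only a proposal in this thread, and `ToyModel` with its simplified loader is hypothetical: when the flag is on, the model is rebuilt from the stored hparams; when it is off, the caller's keyword args go straight to `__init__`. The weights are restored in both cases.

```python
class ToyModel:
    """Hypothetical model whose __init__ takes plain keyword arguments."""

    def __init__(self, arg1: float, arg2: float):
        self.arg1 = arg1
        self.arg2 = arg2
        self.scale = arg1 * arg2  # derived at construction time
        self.weights = {}

    @classmethod
    def load_from_checkpoint(cls, checkpoint: dict, auto_hparam: bool = True, **kwargs):
        """Sketch of the proposed flag; `auto_hparam` is hypothetical."""
        if auto_hparam:
            # Default: rebuild from the hparams saved in the checkpoint.
            model = cls(**checkpoint["hparams"])
        else:
            # Flag off: ignore stored hparams and pass the caller's args to __init__.
            model = cls(**kwargs)
        model.weights = dict(checkpoint["state_dict"])  # restore weights either way
        return model


checkpoint = {"hparams": {"arg1": 2.0, "arg2": 3.0}, "state_dict": {"w": 1.0}}
default = ToyModel.load_from_checkpoint(checkpoint)
custom = ToyModel.load_from_checkpoint(checkpoint, auto_hparam=False, arg1=0.0, arg2=1.0)
print(default.scale, custom.scale)  # 6.0 0.0
```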

Alternative 2:
We add an hparam_updates arg that applies those updates to the stored hparams before the model is constructed:

load_from_checkpoint(PATH, hparam_updates={'arg1': 0.0, 'arg2': 0.0})
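A minimal sketch of alternative 2 in plain Python. `hparam_updates` is the name proposed above; `load_with_updates` and `ToyModel` are hypothetical. The key point is that the updates are merged into a copy of the stored hparams before `__init__` runs. Rejecting keys that are not already in the checkpoint is an added design choice (an assumption, not part of the proposal) meant to catch typos early.

```python
from argparse import Namespace
from typing import Optional


class ToyModel:
    """Hypothetical stand-in for a model whose __init__ reads hparams."""

    def __init__(self, hparams: Namespace):
        self.hparams = hparams
        self.scale = hparams.arg1 * hparams.arg2  # derived at construction time


def load_with_updates(checkpoint: dict, hparam_updates: Optional[dict] = None) -> ToyModel:
    """Sketch of the proposed `hparam_updates` argument (hypothetical helper)."""
    merged = dict(checkpoint["hparams"])  # copy so the checkpoint is not mutated
    if hparam_updates:
        unknown = set(hparam_updates) - set(merged)
        if unknown:
            # Assumed design choice: fail loudly on hparams the checkpoint never had.
            raise KeyError(f"unknown hparams: {sorted(unknown)}")
        merged.update(hparam_updates)
    # Updates are applied before __init__ runs, so derived state sees them.
    return ToyModel(Namespace(**merged))


checkpoint = {"hparams": {"arg1": 2.0, "arg2": 3.0}}
model = load_with_updates(checkpoint, hparam_updates={"arg1": 0.0, "arg2": 0.0})
print(model.scale)  # 0.0
```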

@tullie
Contributor Author

tullie commented Apr 21, 2020

I've been experimenting with both of these options and prefer alternative 2.
I'll send out a PR soon.

2 participants