Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added override for hparams in load_from_ckpt #1797

Merged
merged 8 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
Expand Down Expand Up @@ -1439,6 +1439,7 @@ def load_from_checkpoint(
checkpoint_path: str,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
tags_csv: Optional[str] = None,
hparam_overrides: Optional[Dict] = None,
*args, **kwargs
) -> 'LightningModule':
r"""
Expand Down Expand Up @@ -1480,6 +1481,7 @@ def __init__(self, hparams):
use this method to pass in a .csv file with the hparams you'd like to use.
These will be converted into a :class:`~argparse.Namespace` and passed into your
:class:`LightningModule` for use.
hparam_overrides: A dictionary with keys to override in the hparams

Return:
:class:`LightningModule` with loaded weights and hyperparameters (if available).
Expand All @@ -1503,6 +1505,12 @@ def __init__(self, hparams):
tags_csv='/path/to/hparams_file.csv'
)

# override some of the params with new values
MyLightningModule.load_from_checkpoint(
PATH,
hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH}
)

# or load passing whatever args the model takes to load
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
Expand All @@ -1521,12 +1529,16 @@ def __init__(self, hparams):
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

# add the hparams from csv file to checkpoint
if tags_csv is not None:
# add the hparams from csv file to checkpoint
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)
checkpoint['hparams'] = vars(hparams)

# override the hparam keys that were passed in
if hparam_overrides is not None:
update_hparams(hparams, hparam_overrides)

model = cls._load_model_state(checkpoint, *args, **kwargs)
return model

Expand Down
31 changes: 31 additions & 0 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,37 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
"""


def update_hparams(hparams: dict, updates: dict) -> None:
    """
    Overrides hparams with new values

    >>> hparams = {'c': 4}
    >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
    >>> hparams['a']['b'], hparams['c']
    (2, 1)
    >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
    >>> hparams['a']['b'], hparams['c']
    (4, 7)

    Args:
        hparams: the original params and also target object
        updates: new params to be used as update

    """
    for k, v in updates.items():
        # if missing, add the key
        if k not in hparams:
            hparams[k] = v
            continue

        # recurse only if BOTH sides are dicts; if the existing value is not
        # a dict (e.g. an int being replaced by a nested dict), recursing
        # would raise TypeError on `k not in hparams` — overwrite instead
        if isinstance(v, dict) and isinstance(hparams[k], dict):
            update_hparams(hparams[k], v)
        else:
            # update the value in place
            hparams[k] = v


def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
if not os.path.isfile(tags_csv):
log.warning(f'Missing Tags: {tags_csv}.')
Expand Down