Commit 35fe2ef

added override for hparams in load_from_ckpt (#1797)
* added override for hparams in load_from_ckpt

* override hparams

* override hparams

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* update doctest

* typo

* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
4 people committed May 13, 2020
1 parent 10ce1c0 commit 35fe2ef
Showing 3 changed files with 47 additions and 2 deletions.
CHANGELOG.md (2 additions, 0 deletions)

@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added option to provide seed to random generators to ensure reproducibility ([#1572](https://github.com/PyTorchLightning/pytorch-lightning/pull/1572))

+- Added override for hparams in `load_from_ckpt` ([#1797](https://github.com/PyTorchLightning/pytorch-lightning/pull/1797))
+
 ### Changed

 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
pytorch_lightning/core/lightning.py (14 additions, 2 deletions)

@@ -16,8 +16,8 @@
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
+from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
 from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
-from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import rank_zero_warn

@@ -1439,6 +1439,7 @@ def load_from_checkpoint(
             checkpoint_path: str,
             map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
             tags_csv: Optional[str] = None,
+            hparam_overrides: Optional[Dict] = None,
             *args, **kwargs
     ) -> 'LightningModule':
         r"""

@@ -1480,6 +1481,7 @@ def __init__(self, hparams):
                 use this method to pass in a .csv file with the hparams you'd like to use.
                 These will be converted into a :class:`~argparse.Namespace` and passed into your
                 :class:`LightningModule` for use.
+            hparam_overrides: A dictionary with keys to override in the hparams

         Return:
             :class:`LightningModule` with loaded weights and hyperparameters (if available).

@@ -1503,6 +1505,12 @@ def __init__(self, hparams):
                     tags_csv='/path/to/hparams_file.csv'
                 )

+                # override some of the params with new values
+                MyLightningModule.load_from_checkpoint(
+                    PATH,
+                    hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH}
+                )
+
                 # or load passing whatever args the model takes to load
                 MyLightningModule.load_from_checkpoint(
                     'path/to/checkpoint.ckpt',

@@ -1521,12 +1529,16 @@ def __init__(self, hparams):
         else:
             checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

+        # add the hparams from csv file to checkpoint
         if tags_csv is not None:
-            # add the hparams from csv file to checkpoint
             hparams = load_hparams_from_tags_csv(tags_csv)
             hparams.__setattr__('on_gpu', False)
             checkpoint['hparams'] = vars(hparams)

+        # override the hparam keys that were passed in
+        if hparam_overrides is not None:
+            update_hparams(hparams, hparam_overrides)
+
         model = cls._load_model_state(checkpoint, *args, **kwargs)
         return model
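As the last hunk shows, the override runs after any `tags_csv` hparams are loaded, so the two arguments compose: the CSV supplies the base hparams and `hparam_overrides` replaces individual keys on top. A minimal sketch of such a call, assuming a `MyLightningModule` like the one in the docstring example and placeholder paths:

    # Hypothetical module and paths, for illustration only.
    # tags_csv hparams are loaded first; hparam_overrides is then
    # merged on top via update_hparams.
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        tags_csv='path/to/hparams_file.csv',
        hparam_overrides={'num_layers': 128},
    )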
pytorch_lightning/core/saving.py (31 additions, 0 deletions)

@@ -48,6 +48,37 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
         """


+def update_hparams(hparams: dict, updates: dict) -> None:
+    """
+    Overrides hparams with new values
+
+    >>> hparams = {'c': 4}
+    >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
+    >>> hparams['a']['b'], hparams['c']
+    (2, 1)
+    >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
+    >>> hparams['a']['b'], hparams['c']
+    (4, 7)
+
+    Args:
+        hparams: the original params and also target object
+        updates: new params to be used as update
+
+    """
+    for k, v in updates.items():
+        # if missing, add the key
+        if k not in hparams:
+            hparams[k] = v
+            continue
+
+        # recurse if dictionary
+        if isinstance(v, dict):
+            update_hparams(hparams[k], updates[k])
+        else:
+            # update the value
+            hparams.update({k: v})
+
+
 def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
     if not os.path.isfile(tags_csv):
         log.warning(f'Missing Tags: {tags_csv}.')
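Because `update_hparams` mutates its first argument in place and recurses into nested dicts, scalar values are replaced while sibling keys inside a nested dict survive. A small standalone demo of those merge semantics, assuming a build that includes this commit so the function is importable:

    from pytorch_lightning.core.saving import update_hparams

    hparams = {'optimizer': {'lr': 0.001, 'momentum': 0.9}, 'num_layers': 64}
    update_hparams(hparams, {'optimizer': {'lr': 0.01}, 'num_layers': 128, 'dropout': 0.1})

    # Nested dicts merge key-by-key, scalars are replaced, missing keys are added.
    assert hparams == {
        'optimizer': {'lr': 0.01, 'momentum': 0.9},
        'num_layers': 128,
        'dropout': 0.1,
    }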
