diff --git a/CHANGELOG.md b/CHANGELOG.md
index ff333c0fcabb8..c9108968845ba 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
 
+- Fixed LightningModule - Mixing hparams and arguments in `LightningModule.__init__()` crashes load_from_checkpoint() ([#1505](https://github.com/PyTorchLightning/pytorch-lightning/pull/1505))
+
 - Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
 
 -
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 71ea357ea3944..4304bcd45101d 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1434,6 +1434,7 @@ def load_from_checkpoint(
         it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
         with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
         (output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
+        Any other arguments specified through \*args and \*\*kwargs will be passed to the model.
 
         Example:
             .. code-block:: python
@@ -1493,7 +1494,7 @@ def __init__(self, hparams):
                 # or load passing whatever args the model takes to load
                 MyLightningModule.load_from_checkpoint(
                     'path/to/checkpoint.ckpt',
-                    learning_rate=0.1,
+                    learning_rate=0.1, # These arguments will be passed to the model using **kwargs
                     layers=2,
                     pretrained_model=some_model
                 )
@@ -1544,10 +1545,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh
 
         # load the state_dict on the model automatically
         model_args = [hparams] if hparams else []
-        if len(model_args) > 0:
-            model = cls(*model_args)
-        else:
-            model = cls(*args, **kwargs)
+        model = cls(*model_args, *args, **kwargs)
         model.load_state_dict(checkpoint['state_dict'])
 
         # give model a chance to load something
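
A minimal sketch (not part of the diff) of the behaviour this change enables. The module, its `pretrained_model` argument, the hparams fields, and the checkpoint path are all hypothetical; the point is that a stored `hparams` Namespace and extra constructor arguments can now be combined in one `load_from_checkpoint()` call.

```python
from torch import nn
import pytorch_lightning as pl


class MyLightningModule(pl.LightningModule):
    """Toy module whose __init__ mixes ``hparams`` with an extra argument."""

    def __init__(self, hparams, pretrained_model=None):
        super().__init__()
        # hparams (e.g. Namespace(input_dim=32, output_dim=4)) is saved in the checkpoint
        self.hparams = hparams
        # pretrained_model is NOT saved in the checkpoint; it must be passed again on restore
        self.pretrained_model = pretrained_model
        self.l1 = nn.Linear(hparams.input_dim, hparams.output_dim)

    def forward(self, x):
        return self.l1(x)


# hparams is recovered from the checkpoint, while the extra keyword argument is
# forwarded to __init__ via **kwargs. Before this fix, extra arguments were ignored
# whenever hparams were present (crashing if __init__ required them); after the fix
# both are forwarded: cls(*model_args, *args, **kwargs).
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',         # illustrative path
    pretrained_model=nn.Linear(4, 4),  # illustrative extra argument
)
```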