diff --git a/CHANGELOG.md b/CHANGELOG.md
index 51ab0c4aa1eb9..0a17244a7b6c9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,10 +30,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
 - Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
 - Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
+- Moved functionality of `LightningModule.load_from_metrics` into `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))
 
 ### Deprecated
 
-- None
+- Deprecated `LightningModule.load_from_metrics` in favour of `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))
 
 ### Removed
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 273b4d07234fd..1c1373a0b6c61 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -5,6 +5,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
+from typing import Optional, Union, Dict, Callable
 
 import torch
 import torch.distributed as dist
@@ -1090,77 +1091,35 @@ def val_dataloader(self):
     @classmethod
     def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
         r"""
-        You should use `load_from_checkpoint` instead!
-        However, if your .ckpt weights don't have the hyperparameters saved, use this method to pass
-        in a .csv with the hparams you'd like to use. These will be converted into a argparse.Namespace
-        and passed into your LightningModule for use.
-
-        Args:
-
-            weights_path (str): Path to a PyTorch checkpoint
-            tags_csv (str): Path to a .csv with two columns (key, value) as in this
-
-                Example::
-                    key,value
-                    drop_prob,0.2
-                    batch_size,32
-
-            map_location (dict | str | torch.device | function):
-                If your checkpoint saved a GPU model and you now load on CPUs
-                or a different number of GPUs, use this to map to the new setup
-                (example: {'cuda:1':'cuda:0'}).
-                The behaviour is the same as in
-                `torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
-
-        Return:
-            LightningModule with loaded weights and hyperparameters (if available).
-
-        Example
-        -------
-        .. code-block:: python
-
-            pretrained_model = MyLightningModule.load_from_metrics(
-                weights_path='/path/to/pytorch_checkpoint.ckpt',
-                tags_csv='/path/to/hparams_file.csv',
-                on_gpu=True,
-                map_location=None
-            )
-
-            # predict
-            pretrained_model.eval()
-            pretrained_model.freeze()
-            y_hat = pretrained_model(x)
+        Warning:
+            Deprecated in version 0.7.0.
+            You should use `load_from_checkpoint` instead.
+            Will be removed in v0.9.0.
         """
-
-        hparams = load_hparams_from_tags_csv(tags_csv)
-        hparams.__setattr__('on_gpu', False)
-
-        if map_location is not None:
-            checkpoint = torch.load(weights_path, map_location=map_location)
-        else:
-            checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
-
-        # add the hparams from csv file to checkpoint
-        checkpoint['hparams'] = vars(hparams)
-
-        model = cls._load_model_state(checkpoint)
-        return model
+        warnings.warn(
+            "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
+ " The deprecated method will be removed in v0.9.0.", DeprecationWarning + ) + return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location) @classmethod - def load_from_checkpoint(cls, checkpoint_path, map_location=None): + def load_from_checkpoint( + cls, + checkpoint_path: str, + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + tags_csv: Optional[str] = None, + ) -> 'LightningModule': r""" Primary way of loading model from a checkpoint. When Lightning saves a checkpoint - it stores the hyperparameters in the checkpoint if you initialized your LightningModule - with an argument called `hparams` which is a Namespace or dictionary of hyperparameters + it stores the hyperparameters in the checkpoint if you initialized your LightningModule + with an argument called `hparams` which is a Namespace (output of using argparse + to parse command line arguments) or dictionary of hyperparameters. Example ------- .. code-block:: python - # -------------- - # Case 1 - # when using Namespace (output of using Argparse to parse command line arguments) from argparse import Namespace hparams = Namespace(**{'learning_rate': 0.1}) @@ -1171,12 +1130,25 @@ def __init__(self, hparams): self.learning_rate = hparams.learning_rate Args: - checkpoint_path (str): Path to checkpoint. - map_location (dict | str | torch.device | function): + checkpoint_path: Path to checkpoint. + map_location: If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in `torch.load `_. + tags_csv: Optional path to a .csv file with two columns (key, value) + as in this example:: + + key,value + drop_prob,0.2 + batch_size,32 + + You most likely won't need this since Lightning will always save the hyperparameters + to the checkpoint. + However, if your checkpoint weights don't have the hyperparameters saved, + use this method to pass in a .csv file with the hparams you'd like to use. + These will be converted into a argparse.Namespace and passed into your + LightningModule for use. Return: LightningModule with loaded weights and hyperparameters (if available). @@ -1185,20 +1157,38 @@ def __init__(self, hparams): ------- .. code-block:: python - # load weights without mapping + # load weights without mapping ... MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt') - # load weights mapping all weights from GPU 1 to GPU 0 + # or load weights mapping all weights from GPU 1 to GPU 0 ... map_location = {'cuda:1':'cuda:0'} - MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt', map_location=map_location) + MyLightningModule.load_from_checkpoint( + 'path/to/checkpoint.ckpt', + map_location=map_location + ) - """ + # or load weights and hyperparameters from separate files. 
+            pretrained_model = MyLightningModule.load_from_checkpoint(
+                'path/to/checkpoint.ckpt',
+                tags_csv='/path/to/hparams_file.csv'
+            )
+
+            # predict
+            pretrained_model.eval()
+            pretrained_model.freeze()
+            y_hat = pretrained_model(x)
+        """
         if map_location is not None:
             checkpoint = torch.load(checkpoint_path, map_location=map_location)
         else:
             checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
 
+        if tags_csv is not None:
+            # add the hparams from csv file to checkpoint
+            hparams = load_hparams_from_tags_csv(tags_csv)
+            hparams.__setattr__('on_gpu', False)
+            checkpoint['hparams'] = vars(hparams)
+
         model = cls._load_model_state(checkpoint)
         return model
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 9a5920a32ebe5..8bdffc703c849 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1211,6 +1211,21 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
         return self.dataloader
 
 
+class _PatchDataLoader(object):
+    r'''
+    Callable object for patching dataloaders passed into trainer.fit().
+    Use this class to override model.*_dataloader() and be pickle-compatible.
+
+    Args:
+        dataloader: Dataloader object to return when called.
+    '''
+    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
+        self.dataloader = dataloader
+
+    def __call__(self) -> Union[List[DataLoader], DataLoader]:
+        return self.dataloader
+
+
 def _set_dataloader(model, dataloader, attribute):
     r'''
     Check dataloaders passed to .fit() method if they are pytorch DataLoader
diff --git a/tests/models/utils.py b/tests/models/utils.py
index df3c04110b57d..db517b3e28157 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -158,7 +158,10 @@ def load_model(exp, root_weights_dir, module_class=LightningTemplateModel, path_
     checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x]
     weights_dir = os.path.join(root_weights_dir, checkpoints[0])
 
-    trained_model = module_class.load_from_checkpoint(weights_dir)
+    trained_model = module_class.load_from_checkpoint(
+        checkpoint_path=weights_dir,
+        tags_csv=tags_path
+    )
 
     assert trained_model is not None, 'loading model failed'
 
diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py
index 19acd1fdd60f1..c0c429c240a1d 100644
--- a/tests/test_restore_models.py
+++ b/tests/test_restore_models.py
@@ -320,8 +320,10 @@ def test_model_saving_loading(tmpdir):
     # load new model
     tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
     tags_path = os.path.join(tags_path, 'meta_tags.csv')
-    model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
-                                                   tags_csv=tags_path)
+    model_2 = LightningTestModel.load_from_checkpoint(
+        checkpoint_path=new_weights_path,
+        tags_csv=tags_path
+    )
     model_2.eval()
 
     # make prediction
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 9134378c8fa45..a03e4087a702d 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -61,8 +61,10 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase):
     # load new model
     tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
     tags_path = os.path.join(tags_path, 'meta_tags.csv')
-    model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
-                                                   tags_csv=tags_path)
+    model_2 = LightningTestModel.load_from_checkpoint(
+        checkpoint_path=new_weights_path,
+        tags_csv=tags_path
+    )
     model_2.eval()
 
 
@@ -99,8 +101,10 @@ class CurrentTestModel(LightTrainDataloader, LightValidationStepMixin, TestModel
     # load new model
     tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
     tags_path = os.path.join(tags_path, 'meta_tags.csv')
-    model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
-                                                   tags_csv=tags_path)
+    model_2 = LightningTestModel.load_from_checkpoint(
+        checkpoint_path=new_weights_path,
+        tags_csv=tags_path
+    )
     model_2.eval()
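
Reviewer note (illustration only, not part of the patch): a minimal sketch of the
unified loading API this diff introduces. `MyLightningModule`, its layer, the
`learning_rate` hyperparameter, and all file paths below are hypothetical stand-ins;
only the `load_from_checkpoint` keyword arguments (`checkpoint_path`, `map_location`,
`tags_csv`) come from the signature added above.

    import torch.nn as nn
    import pytorch_lightning as pl


    class MyLightningModule(pl.LightningModule):
        def __init__(self, hparams):
            super().__init__()
            # keeping `hparams` on the module lets Lightning write them into the
            # checkpoint, so `tags_csv` is only needed for checkpoints saved without them
            self.hparams = hparams
            self.learning_rate = hparams.learning_rate
            self.l1 = nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.l1(x.view(x.size(0), -1))


    # common case: hyperparameters were stored inside the .ckpt file
    model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

    # fallback: read hyperparameters from a separate meta_tags.csv, optionally
    # remapping devices exactly as torch.load would
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='path/to/checkpoint.ckpt',
        map_location={'cuda:1': 'cuda:0'},
        tags_csv='path/to/meta_tags.csv',
    )

    model.eval()
    model.freeze()

Calling the deprecated `MyLightningModule.load_from_metrics(...)` still works until
v0.9.0, but it now emits a `DeprecationWarning` and simply forwards to
`load_from_checkpoint`.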