Merge load functions #995

Merged (15 commits, Mar 3, 2020)
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -30,10 +30,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
- Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
- Moved functionality of `LightningModule.load_from_metrics` into `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))

### Deprecated

- None
- Deprecated `LightningModule.load_from_metrics` in favour of `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))

### Removed

122 changes: 56 additions & 66 deletions pytorch_lightning/core/lightning.py
@@ -5,6 +5,7 @@
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Optional, Union, Dict, Callable

import torch
import torch.distributed as dist
@@ -1090,77 +1091,35 @@ def val_dataloader(self):
@classmethod
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
r"""
You should use `load_from_checkpoint` instead!
However, if your .ckpt weights don't have the hyperparameters saved, use this method to pass
in a .csv with the hparams you'd like to use. These will be converted into an argparse.Namespace
and passed into your LightningModule for use.

Args:

weights_path (str): Path to a PyTorch checkpoint
tags_csv (str): Path to a .csv with two columns (key, value) as in this

Example::
key,value
drop_prob,0.2
batch_size,32

map_location (dict | str | torch.device | function):
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup
(example: {'cuda:1':'cuda:0'}).
The behaviour is the same as in
`torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.

Return:
LightningModule with loaded weights and hyperparameters (if available).

Example
-------
.. code-block:: python

pretrained_model = MyLightningModule.load_from_metrics(
weights_path='/path/to/pytorch_checkpoint.ckpt',
tags_csv='/path/to/hparams_file.csv',
on_gpu=True,
map_location=None
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
Warning:
Deprecated in version 0.7.0.
You should use `load_from_checkpoint` instead.
Will be removed in v0.9.0.
"""

hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)

if map_location is not None:
checkpoint = torch.load(weights_path, map_location=map_location)
else:
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

# add the hparams from csv file to checkpoint
checkpoint['hparams'] = vars(hparams)

model = cls._load_model_state(checkpoint)
return model
warnings.warn(
"`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
" The deprecated method will be removed in v0.9.0.", DeprecationWarning
)
return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location)
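For callers of the old API, the change amounts to switching the entry point. A minimal migration sketch, assuming a LightningModule subclass named MyModule and existing checkpoint/csv paths (names are illustrative, not taken from this diff):

# before: hyperparameters supplied through the dedicated metrics loader
model = MyModule.load_from_metrics(weights_path='weights.ckpt', tags_csv='meta_tags.csv')

# after: the unified loader accepts the same csv as an optional argument
model = MyModule.load_from_checkpoint('weights.ckpt', tags_csv='meta_tags.csv')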

@classmethod
def load_from_checkpoint(cls, checkpoint_path, map_location=None):
def load_from_checkpoint(
cls,
checkpoint_path: str,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
tags_csv: Optional[str] = None,
) -> 'LightningModule':
r"""

Primary way of loading model from a checkpoint. When Lightning saves a checkpoint
it stores the hyperparameters in the checkpoint if you initialized your LightningModule
with an argument called `hparams` which is a Namespace (output of using argparse
to parse command line arguments) or dictionary of hyperparameters.

Example
-------
.. code-block:: python

# --------------
# Case 1
# when using Namespace (output of using Argparse to parse command line arguments)
from argparse import Namespace
hparams = Namespace(**{'learning_rate': 0.1})

@@ -1171,12 +1130,25 @@ def __init__(self, hparams):
self.learning_rate = hparams.learning_rate

Args:
checkpoint_path: Path to checkpoint.
map_location:
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.
The behaviour is the same as in
`torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
tags_csv: Optional path to a .csv file with two columns (key, value)
as in this example::

key,value
drop_prob,0.2
batch_size,32

You most likely won't need this since Lightning will always save the hyperparameters
to the checkpoint.
However, if your checkpoint weights don't have the hyperparameters saved,
use this method to pass in a .csv file with the hparams you'd like to use.
These will be converted into an argparse.Namespace and passed into your
LightningModule for use.

Return:
LightningModule with loaded weights and hyperparameters (if available).
@@ -1185,20 +1157,38 @@ def __init__(self, hparams):
-------
.. code-block:: python

# load weights without mapping ...
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1': 'cuda:0'}
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
map_location=map_location
)

# or load weights and hyperparameters from separate files.
pretrained_model = MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
tags_csv='/path/to/hparams_file.csv'
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
"""
if map_location is not None:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

if tags_csv is not None:
# add the hparams from csv file to checkpoint
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)
checkpoint['hparams'] = vars(hparams)

model = cls._load_model_state(checkpoint)
return model
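The csv handling above relies on `load_hparams_from_tags_csv` turning the two-column file into a Namespace. A rough, assumption-laden sketch of that step for the `key,value` layout shown in the docstring (not the library's actual implementation):

import csv
from argparse import Namespace

def tags_csv_to_namespace(tags_csv):
    # illustrative only: read key,value rows, skip the header row, coerce obvious numerics
    with open(tags_csv) as fp:
        rows = list(csv.reader(fp))[1:]
    hparams = {}
    for key, value in rows:
        for cast in (int, float):
            try:
                value = cast(value)
                break
            except ValueError:
                continue
        hparams[key] = value
    return Namespace(**hparams)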

15 changes: 15 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1211,6 +1211,21 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


class _PatchDataLoader(object):
r'''
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.

Args:
dataloader: Dataloader object to return when called.
'''
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader

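A plain lambda or nested function holding the dataloader would make the model unpicklable, which matters once the Trainer spawns worker processes; hence the small callable class. A usage sketch, assuming `model` is a LightningModule and `some_dataset` a Dataset (hypothetical names; the real call site lives inside trainer.fit()):

loader = DataLoader(some_dataset, batch_size=32)
model.train_dataloader = _PatchDataLoader(loader)  # picklable, unlike `lambda: loader`
assert model.train_dataloader() is loader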

def _set_dataloader(model, dataloader, attribute):
r'''
Check dataloaders passed to .fit() method if they are pytorch DataLoader
5 changes: 4 additions & 1 deletion tests/models/utils.py
@@ -158,7 +158,10 @@ def load_model(exp, root_weights_dir, module_class=LightningTemplateModel, path_
checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x]
weights_dir = os.path.join(root_weights_dir, checkpoints[0])

trained_model = module_class.load_from_checkpoint(weights_dir)
trained_model = module_class.load_from_checkpoint(
checkpoint_path=weights_dir,
tags_csv=tags_path
)

assert trained_model is not None, 'loading model failed'

6 changes: 4 additions & 2 deletions tests/test_restore_models.py
@@ -320,8 +320,10 @@ def test_model_saving_loading(tmpdir):
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path)
model_2 = LightningTestModel.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
)
model_2.eval()

# make prediction
12 changes: 8 additions & 4 deletions tests/trainer/test_trainer.py
@@ -61,8 +61,10 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase):
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path)
model_2 = LightningTestModel.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
)
model_2.eval()


@@ -99,8 +101,10 @@ class CurrentTestModel(LightTrainDataloader, LightValidationStepMixin, TestModel
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path)
model_2 = LightningTestModel.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
)
model_2.eval()

