Merge load functions #995

Merged · 15 commits · Mar 3, 2020
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -30,10 +30,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
- Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
- Moved functionality of `LightningModule.load_from_metrics` into `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))

### Deprecated

- None
- Deprecated `LightningModule.load_from_metrics` in favour of `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))

### Removed

6 changes: 4 additions & 2 deletions README.md
@@ -52,7 +52,9 @@ pip install pytorch-lightning
[Copy and run this COLAB!](https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg)

## What is it?
Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. It's more of a style-guide than a framework. By refactoring your code, we can automate most of the non-research code. Lightning guarantees tested, correct, modern best practices for the automated parts.
Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. It's more of a style-guide than a framework.

By refactoring your code, we can automate most of the non-research code. Lightning guarantees tested, correct, modern best practices for the automated parts.

Here's an example of how to organize PyTorch code into the LightningModule.
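Since the actual README example is collapsed in this diff view, here is a minimal sketch of the kind of LightningModule the README shows. The `CoolSystem` name and the MNIST setup are illustrative stand-ins, not the verbatim README code:

```python
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import pytorch_lightning as pl


class CoolSystem(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # research code: the model itself
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        # research code: what happens on one batch
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        dataset = MNIST('.', train=True, download=True, transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32)
```

The engineering (training loop, device placement, checkpointing) is then handled by the Trainer, roughly `pl.Trainer(max_epochs=1).fit(CoolSystem())`.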

@@ -69,7 +71,7 @@ This is how lightning separates the science (red) from the engineering (blue).
![Overview](docs/source/_static/images/pl_overview.gif)

## How much effort is it to convert?
You're probably tired of switching frameworks at this point. But it is a very quick process to refactor into the Lightning format (i.e., hours). [Check out this tutorial](https://towardsdatascience.com/how-to-refactor-your-pytorch-code-to-get-these-42-benefits-of-pytorch-lighting-6fdd0dc97538).
You're probably tired of switching frameworks at this point. But it is a very quick process to refactor into the Lightning format (i.e., hours). [Check out this tutorial](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09).

## What are the differences with PyTorch?
If you're wondering what you gain out of refactoring your PyTorch code, [read this comparison!](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09)
122 changes: 56 additions & 66 deletions pytorch_lightning/core/lightning.py
@@ -5,6 +5,7 @@
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Optional, Union, Dict, Callable

import torch
import torch.distributed as dist
@@ -1090,77 +1091,35 @@ def val_dataloader(self):
    @classmethod
    def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
        r"""
        You should use `load_from_checkpoint` instead!
        However, if your .ckpt weights don't have the hyperparameters saved, use this method to pass
        in a .csv with the hparams you'd like to use. These will be converted into an argparse.Namespace
        and passed into your LightningModule for use.

        Args:

            weights_path (str): Path to a PyTorch checkpoint
            tags_csv (str): Path to a .csv with two columns (key, value) as in this

                Example::
                    key,value
                    drop_prob,0.2
                    batch_size,32

            map_location (dict | str | torch.device | function):
                If your checkpoint saved a GPU model and you now load on CPUs
                or a different number of GPUs, use this to map to the new setup
                (example: {'cuda:1':'cuda:0'}).
                The behaviour is the same as in
                `torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.

        Return:
            LightningModule with loaded weights and hyperparameters (if available).

        Example
        -------
        .. code-block:: python

            pretrained_model = MyLightningModule.load_from_metrics(
                weights_path='/path/to/pytorch_checkpoint.ckpt',
                tags_csv='/path/to/hparams_file.csv',
                on_gpu=True,
                map_location=None
            )

            # predict
            pretrained_model.eval()
            pretrained_model.freeze()
            y_hat = pretrained_model(x)
        Warning:
            Deprecated in version 0.7.0.
            You should use `load_from_checkpoint` instead.
            Will be removed in v0.9.0.
        """

        hparams = load_hparams_from_tags_csv(tags_csv)
        hparams.__setattr__('on_gpu', False)

        if map_location is not None:
            checkpoint = torch.load(weights_path, map_location=map_location)
        else:
            checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

        # add the hparams from csv file to checkpoint
        checkpoint['hparams'] = vars(hparams)

        model = cls._load_model_state(checkpoint)
        return model
        warnings.warn(
            "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
            " The deprecated method will be removed in v0.9.0.", DeprecationWarning
        )
        return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, map_location=None):
    def load_from_checkpoint(
            cls,
            checkpoint_path: str,
            map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
            tags_csv: Optional[str] = None,
    ) -> 'LightningModule':
        r"""

        Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
        it stores the hyperparameters in the checkpoint if you initialized your LightningModule
        with an argument called `hparams` which is a Namespace or dictionary of hyperparameters
        it stores the hyperparameters in the checkpoint if you initialized your LightningModule
        with an argument called `hparams` which is a Namespace (output of using argparse
        to parse command line arguments) or dictionary of hyperparameters.

        Example
        -------
        .. code-block:: python

            # --------------
            # Case 1
            # when using Namespace (output of using Argparse to parse command line arguments)
            from argparse import Namespace
            hparams = Namespace(**{'learning_rate': 0.1})

@@ -1171,12 +1130,25 @@ def __init__(self, hparams):
                    self.learning_rate = hparams.learning_rate

        Args:
            checkpoint_path (str): Path to checkpoint.
            map_location (dict | str | torch.device | function):
            checkpoint_path: Path to checkpoint.
            map_location:
                If your checkpoint saved a GPU model and you now load on CPUs
                or a different number of GPUs, use this to map to the new setup.
                The behaviour is the same as in
                `torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
            tags_csv: Optional path to a .csv file with two columns (key, value)
                as in this example::

                    key,value
                    drop_prob,0.2
                    batch_size,32

                You most likely won't need this since Lightning will always save the hyperparameters
                to the checkpoint.
                However, if your checkpoint weights don't have the hyperparameters saved,
                use this method to pass in a .csv file with the hparams you'd like to use.
                These will be converted into an argparse.Namespace and passed into your
                LightningModule for use.

        Return:
            LightningModule with loaded weights and hyperparameters (if available).
@@ -1185,20 +1157,38 @@
        -------
        .. code-block:: python

            # load weights without mapping
            # load weights without mapping ...
            MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

            # load weights mapping all weights from GPU 1 to GPU 0
            # or load weights mapping all weights from GPU 1 to GPU 0 ...
            map_location = {'cuda:1':'cuda:0'}
            MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt', map_location=map_location)
            MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                map_location=map_location
            )

        """
            # or load weights and hyperparameters from separate files.
            MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                tags_csv='/path/to/hparams_file.csv'
            )

            # predict
            pretrained_model.eval()
            pretrained_model.freeze()
            y_hat = pretrained_model(x)
        """
        if map_location is not None:
            checkpoint = torch.load(checkpoint_path, map_location=map_location)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

        if tags_csv is not None:
            # add the hparams from csv file to checkpoint
            hparams = load_hparams_from_tags_csv(tags_csv)
            hparams.__setattr__('on_gpu', False)
            checkpoint['hparams'] = vars(hparams)

        model = cls._load_model_state(checkpoint)
        return model
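Taken together, the merge means existing `load_from_metrics` call sites migrate to the unified `load_from_checkpoint`. A minimal before/after sketch based on the signatures in this diff (`MyLightningModule` and the paths are placeholders):

```python
# before: deprecated in v0.7.0, to be removed in v0.9.0
pretrained_model = MyLightningModule.load_from_metrics(
    weights_path='/path/to/pytorch_checkpoint.ckpt',
    tags_csv='/path/to/hparams_file.csv',
)

# after: one entry point; tags_csv is only needed when the checkpoint
# itself does not contain the hyperparameters
pretrained_model = MyLightningModule.load_from_checkpoint(
    '/path/to/pytorch_checkpoint.ckpt',
    tags_csv='/path/to/hparams_file.csv',
    map_location={'cuda:1': 'cuda:0'},  # optional, same semantics as torch.load
)

pretrained_model.eval()
pretrained_model.freeze()
```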

30 changes: 18 additions & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -1002,30 +1002,21 @@ def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_d
                m = 'You called .fit() with a train_dataloader but did not define training_step()'
                raise MisconfigurationException(m)

            def patch_train_dataloader():
                return train_dataloader

            model.train_dataloader = patch_train_dataloader
            model.train_dataloader = _PatchDataLoader(train_dataloader)

        if val_dataloaders is not None:
            if not self.is_overriden('validation_step', model):
                m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
                raise MisconfigurationException(m)

            def patch_val_dataloader():
                return val_dataloaders

            model.val_dataloader = patch_val_dataloader
            model.val_dataloader = _PatchDataLoader(val_dataloaders)

        if test_dataloaders is not None:
            if not self.is_overriden('test_step', model):
                m = 'You called .fit() with a test_dataloaders but did not define test_step()'
                raise MisconfigurationException(m)

            def patch_test_dataloader():
                return test_dataloaders

            model.test_dataloader = patch_test_dataloader
            model.test_dataloader = _PatchDataLoader(test_dataloaders)

    def init_optimizers(
            self,
@@ -1189,6 +1180,21 @@ def test(self, model: Optional[LightningModule] = None):
        self.run_evaluation(test_mode=True)


class _PatchDataLoader(object):
    r'''
    Callable object for patching dataloaders passed into trainer.fit().
    Use this class to override model.*_dataloader() and be pickle-compatible.

    Args:
        dataloader: Dataloader object to return when called.
    '''
    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
        self.dataloader = dataloader

    def __call__(self) -> Union[List[DataLoader], DataLoader]:
        return self.dataloader
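The point of swapping the closures for this class is picklability: when a model has to cross a process boundary (as in the DDP test added below), a function defined inside a method cannot be pickled, while an instance of a module-level class can. A self-contained sketch of that difference; the class is re-declared locally (minus the type hints) so the snippet runs on its own:

```python
import pickle


def make_patch(dataloader):
    # a closure like the old patch_train_dataloader
    def patch_dataloader():
        return dataloader
    return patch_dataloader


class PatchDataLoader(object):
    # stand-in for _PatchDataLoader above
    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        return self.dataloader


data = [1, 2, 3]  # stand-in for a (picklable) DataLoader

try:
    pickle.dumps(make_patch(data))
except (pickle.PicklingError, AttributeError) as err:
    print('closure does not pickle:', err)

patched = PatchDataLoader(data)
assert pickle.loads(pickle.dumps(patched))() == data  # round-trips fine
```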


def _set_dataloader(model, dataloader, attribute):
    r'''
    Check dataloaders passed to .fit() method if they are pytorch DataLoader
6 changes: 4 additions & 2 deletions tests/models/utils.py
@@ -158,8 +158,10 @@ def load_model(exp, root_weights_dir, module_class=LightningTemplateModel, path_
    checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x]
    weights_dir = os.path.join(root_weights_dir, checkpoints[0])

    trained_model = module_class.load_from_metrics(weights_path=weights_dir,
                                                   tags_csv=tags_path)
    trained_model = module_class.load_from_checkpoint(
        checkpoint_path=weights_dir,
        tags_csv=tags_path
    )

    assert trained_model is not None, 'loading model failed'

25 changes: 25 additions & 0 deletions tests/test_gpu_models.py
@@ -66,6 +66,31 @@ def test_multi_gpu_model_ddp(tmpdir):
    tutils.run_model_test(trainer_options, model)


def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """Make sure DDP works with dataloaders passed to fit()"""
    if not tutils.can_run_gpu_test():
        return

    tutils.reset_seed()
    tutils.set_random_master_port()

    model, hparams = tutils.get_model()
    trainer_options = dict(default_save_path=tmpdir,
                           show_progress_bar=False,
                           max_epochs=1,
                           train_percent_check=0.4,
                           val_percent_check=0.2,
                           gpus=[0, 1],
                           distributed_backend='ddp')

    fit_options = dict(train_dataloader=model.train_dataloader(),
                       val_dataloaders=model.val_dataloader())

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model, **fit_options)
    assert result == 1, "DDP doesn't work with dataloaders passed to fit()."


def test_optimizer_return_options():
    tutils.reset_seed()

6 changes: 4 additions & 2 deletions tests/test_restore_models.py
@@ -377,8 +377,10 @@ def test_model_saving_loading(tmpdir):
    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
                                                   tags_csv=tags_path)
    model_2 = LightningTestModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        tags_csv=tags_path
    )
    model_2.eval()

    # make prediction
12 changes: 8 additions & 4 deletions tests/trainer/test_trainer.py
@@ -61,8 +61,10 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase):
    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
                                                   tags_csv=tags_path)
    model_2 = LightningTestModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        tags_csv=tags_path
    )
    model_2.eval()


@@ -99,8 +101,10 @@ class CurrentTestModel(LightTrainDataloader, LightValidationStepMixin, TestModel
    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
                                                   tags_csv=tags_path)
    model_2 = LightningTestModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        tags_csv=tags_path
    )
    model_2.eval()

