to_torchscript method for LightningModule #3258

Merged: 40 commits, Sep 3, 2020

Changes from all commits

Commits
da527e8  script (awaelchli, Aug 29, 2020)
ea4b4e1  docs (awaelchli, Aug 29, 2020)
0ad4af0  simple test (awaelchli, Aug 29, 2020)
246e875  move test (awaelchli, Aug 29, 2020)
cee5416  fix doctest (awaelchli, Aug 29, 2020)
192f10f  no grad context (awaelchli, Aug 29, 2020)
c12aea2  extend tests (awaelchli, Aug 29, 2020)
d6f6437  datamodule test (awaelchli, Aug 29, 2020)
49d2166  clean up test (awaelchli, Aug 29, 2020)
e167490  docs (awaelchli, Aug 29, 2020)
7223e98  name (awaelchli, Aug 29, 2020)
4f0dbbc  fix import (awaelchli, Aug 29, 2020)
e5da609  update changelog (awaelchli, Aug 29, 2020)
92d2c5a  fix import (awaelchli, Aug 29, 2020)
26ad185  skip pytorch 1.3 in test (awaelchli, Aug 29, 2020)
ace4f4f  update codeblock (Aug 30, 2020)
b22dfc6  skip bugged 1.4 (Aug 30, 2020)
81bca94  typehints (Aug 30, 2020)
b883e97  doctest not working on all pytorch versions (Aug 30, 2020)
b7be254  rename TestGAN to prevent pytest interference (Aug 30, 2020)
2af314b  add note about pytorch version (Aug 30, 2020)
5d76d55  fix torchscript version inconsistency in tests (Aug 30, 2020)
4fd6cee  reset training state + tests (Aug 30, 2020)
26c28f9  update docstring (awaelchli, Aug 30, 2020)
9366029  Apply suggestions from code review (Borda, Sep 1, 2020)
721cb5e  update docstring, dict return (awaelchli, Sep 1, 2020)
46bbdab  Merge remote-tracking branch 'PyTorchLightning/feature/torchscript' i… (awaelchli, Sep 1, 2020)
da0f4f8  Merge branch 'master' into feature/torchscript (awaelchli, Sep 1, 2020)
3598f21  add docs to index (awaelchli, Sep 1, 2020)
c1da6bd  add link (awaelchli, Sep 1, 2020)
7d1124a  doc eval mode (awaelchli, Sep 1, 2020)
7f180c0  forward (awaelchli, Sep 1, 2020)
9687a29  optional save to file path (awaelchli, Sep 3, 2020)
868f8b4  optional (awaelchli, Sep 3, 2020)
713e477  test torchscript device (awaelchli, Sep 3, 2020)
c1fc408  test save load with file path (awaelchli, Sep 3, 2020)
f328959  pep (awaelchli, Sep 3, 2020)
1edd4e4  str (awaelchli, Sep 3, 2020)
1263b57  Commit typing suggestion (justusschock, Sep 3, 2020)
78e2d60  skip test if cuda not available (awaelchli, Sep 3, 2020)

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))

- Added `LightningModule.to_torchscript` to support exporting as `ScriptModule` ([#3258](https://github.com/PyTorchLightning/pytorch-lightning/pull/3258/))

### Changed


6 changes: 6 additions & 0 deletions docs/source/lightning-module.rst
@@ -770,6 +770,12 @@ to_onnx
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_onnx
:noindex:

to_torchscript
~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_torchscript
:noindex:

unfreeze
~~~~~~~~

18 changes: 18 additions & 0 deletions docs/source/production_inference.rst
@@ -28,3 +28,21 @@ Once you have the exported model, you can run it on your ONNX runtime in the following way:
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
ort_outs = ort_session.run(None, ort_inputs)


Exporting to TorchScript
------------------------

TorchScript allows you to serialize your models so that they can be loaded in non-Python environments.
The LightningModule has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`
that returns a scripted module which you can save or use directly.

.. code-block:: python

model = SimpleModel()
script = model.to_torchscript()

# save for use in production environment
torch.jit.save(script, "model.pt")

It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
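
For completeness, here is a minimal sketch of loading the exported module back for inference. It assumes the ``model.pt`` file from the snippet above and a model that accepts a ``(1, 64)`` input, as in the ONNX example earlier on this page:

.. code-block:: python

    import torch

    # load the serialized ScriptModule; no Python class definition is needed
    script = torch.jit.load("model.pt")

    # run inference on a dummy batch matching the assumed (1, 64) input shape
    with torch.no_grad():
        output = script(torch.randn(1, 64))
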
51 changes: 50 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -23,7 +23,7 @@

import torch
import torch.distributed as torch_distrib
from torch import Tensor
from torch import Tensor, ScriptModule
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer
@@ -184,6 +184,7 @@ def forward(self, batch):
return logits

"""
return super().forward(*args, **kwargs)

def training_step(self, *args, **kwargs):
r"""
@@ -1729,6 +1730,54 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):

torch.onnx.export(self, input_data, file_path, **kwargs)

def to_torchscript(self, file_path: Optional[str] = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]:
"""
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
If you would like to customize the modules that are scripted, or if you want to use tracing,
you should override this method. In case you want to return multiple modules, we
recommend using a dictionary.

Args:
file_path: Path where to save the exported TorchScript module. Default: None (no file saved).
**kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` function.

Note:
- Requires the implementation of the
:meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method.
- The exported script will be set to evaluation mode.
- It is recommended that you install the latest supported version of PyTorch
to use this feature without limitations. See also the :mod:`torch.jit`
documentation for supported features.

Example:
>>> class SimpleModel(LightningModule):
... def __init__(self):
... super().__init__()
... self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
... def forward(self, x):
... return torch.relu(self.l1(x.view(x.size(0), -1)))
...
>>> model = SimpleModel()
>>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP
>>> os.path.isfile("model.pt") # doctest: +SKIP
True

Return:
This LightningModule as a :class:`~torch.jit.ScriptModule`, regardless of whether
file_path is defined or not.
"""

mode = self.training
with torch.no_grad():
scripted_module = torch.jit.script(self.eval(), **kwargs)
self.train(mode)

if file_path is not None:
torch.jit.save(scripted_module, file_path)

return scripted_module

@property
def hparams(self) -> Union[AttributeDict, str]:
if not hasattr(self, '_hparams'):
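
The docstring above suggests overriding to_torchscript when tracing is preferred over scripting. As a minimal sketch, a hypothetical subclass (not part of this PR) could mirror the eval/no_grad/restore-state pattern of the default implementation while swapping torch.jit.script for torch.jit.trace:

import torch
from torch import nn
from pytorch_lightning import LightningModule


class TracedModel(LightningModule):
    """Hypothetical module used only to illustrate a tracing override."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(64, 4)
        self.example_input_array = torch.rand(1, 64)

    def forward(self, x):
        return torch.relu(self.l1(x))

    def to_torchscript(self, file_path=None, **kwargs):
        # trace with the example input instead of scripting the module,
        # then restore the original training mode
        mode = self.training
        with torch.no_grad():
            traced = torch.jit.trace(self.eval(), self.example_input_array, **kwargs)
        self.train(mode)
        if file_path is not None:
            torch.jit.save(traced, file_path)
        return traced

An override like this could also return a dictionary of traced submodules, which is why the return type is annotated as Union[ScriptModule, Dict[str, ScriptModule]].
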
8 changes: 5 additions & 3 deletions tests/base/models.py
@@ -18,7 +18,7 @@


class Generator(nn.Module):
def __init__(self, latent_dim: tuple, img_shape: tuple):
def __init__(self, latent_dim: int, img_shape: tuple):
super().__init__()
self.img_shape = img_shape

@@ -64,10 +64,10 @@ def forward(self, img):
return validity


class TestGAN(LightningModule):
class BasicGAN(LightningModule):
"""Implements a basic GAN for the purpose of illustrating multiple optimizers."""

def __init__(self, hidden_dim, learning_rate, b1, b2, **kwargs):
def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs):
super().__init__()
self.hidden_dim = hidden_dim
self.learning_rate = learning_rate
@@ -163,6 +163,7 @@ def __init__(self):
super().__init__()
self.rnn = nn.LSTM(10, 20, batch_first=True)
self.linear_out = nn.Linear(in_features=20, out_features=5)
self.example_input_array = torch.rand(2, 3, 10)

def forward(self, x):
seq, last = self.rnn(x)
@@ -189,6 +190,7 @@ def __init__(self):
self.c_d1_bn = nn.BatchNorm1d(128)
self.c_d1_drop = nn.Dropout(0.3)
self.c_d2 = nn.Linear(in_features=128, out_features=10)
self.example_input_array = torch.rand(2, 1, 28, 28)

def forward(self, x):
x = x.view(x.size(0), -1)
4 changes: 2 additions & 2 deletions tests/models/test_horovod.py
@@ -13,7 +13,7 @@
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
from tests.base.models import TestGAN
from tests.base.models import BasicGAN

try:
from horovod.common.util import nccl_built
@@ -145,7 +145,7 @@ def validation_step(self, batch, *args, **kwargs):

@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_horovod_multi_optimizer(tmpdir):
model = TestGAN(**EvalModelTemplate.get_default_hparams())
model = BasicGAN(**EvalModelTemplate.get_default_hparams())

# fit model
trainer = Trainer(
88 changes: 88 additions & 0 deletions tests/models/test_torchscript.py
@@ -0,0 +1,88 @@
from distutils.version import LooseVersion

import pytest
import torch

from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.models import ParityModuleRNN, BasicGAN


@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
ParityModuleRNN,
BasicGAN,
])
def test_torchscript_input_output(modelclass):
""" Test that scripted LightningModule forward works. """
model = modelclass()
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
model.eval()
model_output = model(model.example_input_array)
script_output = script(model.example_input_array)
assert torch.allclose(script_output, model_output)


@pytest.mark.parametrize("device", [
torch.device("cpu"),
torch.device("cuda", 0)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
def test_torchscript_device(device):
""" Test that scripted module is on the correct device. """
model = EvalModelTemplate().to(device)
script = model.to_torchscript()
assert next(script.parameters()).device == device
script_output = script(model.example_input_array.to(device))
assert script_output.device == device


def test_torchscript_retain_training_state():
""" Test that torchscript export does not alter the training mode of original model. """
model = EvalModelTemplate()
model.train(True)
script = model.to_torchscript()
assert model.training
assert not script.training
model.train(False)
_ = model.to_torchscript()
assert not model.training
assert not script.training


@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
ParityModuleRNN,
BasicGAN,
])
def test_torchscript_properties(modelclass):
""" Test that scripted LightningModule has unnecessary methods removed. """
model = modelclass()
model.datamodule = TrialMNISTDataModule()
script = model.to_torchscript()
assert not hasattr(script, "datamodule")
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")

if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
# only on torch >= 1.4 do these unused methods get removed
assert not callable(getattr(script, "training_step", None))


@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
ParityModuleRNN,
BasicGAN,
])
@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
reason="torch.save/load has bug loading script modules on torch <= 1.4",
)
def test_torchscript_save_load(tmpdir, modelclass):
""" Test that scripted LightningModules is correctly saved and can be loaded. """
model = modelclass()
output_file = str(tmpdir / "model.pt")
script = model.to_torchscript(file_path=output_file)
loaded_script = torch.jit.load(output_file)
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))