to_torchscript method for LightningModule (#3258)
* script

* docs

* simple test

* move test

* fix doctest

* no grad context

* extend tests

* datamodule test

* clean up test

* docs

* name

* fix import

* update changelog

* fix import

* skip pytorch 1.3 in test

* update codeblock

* skip bugged 1.4

* typehints

* doctest not working on all pytorch versions

* rename TestGAN to prevent pytest interference

* add note about pytorch version

* fix torchscript version inconsistency in tests

* reset training state + tests

* update docstring

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* update docstring, dict return

* add docs to index

* add link

* doc eval mode

* forward

* optional save to file path

* optional

* test torchscript device

* test save load with file path

* pep

* str

* Commit typing suggestion

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* skip test if cuda not available

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
4 people committed Sep 3, 2020
1 parent 4a22fca commit 4ad5a78
Showing 7 changed files with 171 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))

- Added `LightningModule.to_torchscript` to support exporting as `ScriptModule` ([#3258](https://github.com/PyTorchLightning/pytorch-lightning/pull/3258/))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
6 changes: 6 additions & 0 deletions docs/source/lightning-module.rst
@@ -770,6 +770,12 @@ to_onnx
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_onnx
   :noindex:

to_torchscript
~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_torchscript
   :noindex:

unfreeze
~~~~~~~~

18 changes: 18 additions & 0 deletions docs/source/production_inference.rst
@@ -28,3 +28,21 @@ Once you have the exported model, you can run it on your ONNX runtime in the fol
    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)

Exporting to TorchScript
------------------------

TorchScript allows you to serialize your models so that they can be loaded in non-Python environments.
The LightningModule has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`
that returns a scripted module which you can save or directly use.

.. code-block:: python

    model = SimpleModel()
    script = model.to_torchscript()

    # save for use in production environment
    torch.jit.save(script, "model.pt")

It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
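
You can also pass a file path directly to :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`, which saves the scripted module in one step. The following is a minimal sketch (reusing the ``SimpleModel`` and the ``(1, 64)`` input shape from the ONNX example above) that also loads the module back with :func:`torch.jit.load`:

.. code-block:: python

    # save in one step via the optional file_path argument
    model.to_torchscript(file_path="model.pt")

    # load the scripted module back and run inference on a dummy input
    script = torch.jit.load("model.pt")
    output = script(torch.randn(1, 64))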
51 changes: 50 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -23,7 +23,7 @@

import torch
import torch.distributed as torch_distrib
from torch import Tensor
from torch import Tensor, ScriptModule
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer
@@ -184,6 +184,7 @@ def forward(self, batch):
            return logits
        """
        return super().forward(*args, **kwargs)

    def training_step(self, *args, **kwargs):
        r"""
@@ -1729,6 +1730,54 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg

        torch.onnx.export(self, input_data, file_path, **kwargs)

    def to_torchscript(self, file_path: Optional[str] = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]:
        """
        By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
        If you would like to customize the modules that are scripted or you want to use tracing
        you should override this method. In case you want to return multiple modules, we
        recommend using a dictionary.

        Args:
            file_path: Path where to save the torchscript. Default: None (no file saved).
            **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` function.

        Note:
            - Requires the implementation of the
              :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method.
            - The exported script will be set to evaluation mode.
            - It is recommended that you install the latest supported version of PyTorch
              to use this feature without limitations. See also the :mod:`torch.jit`
              documentation for supported features.

        Example:
            >>> class SimpleModel(LightningModule):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            ...
            ...     def forward(self, x):
            ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
            ...
            >>> model = SimpleModel()
            >>> torch.jit.save(model.to_torchscript(), "model.pt")  # doctest: +SKIP
            >>> os.path.isfile("model.pt")  # doctest: +SKIP
            True

        Return:
            This LightningModule as a torchscript, regardless of whether file_path is
            defined or not.
        """

        mode = self.training
        with torch.no_grad():
            scripted_module = torch.jit.script(self.eval(), **kwargs)
        self.train(mode)

        if file_path is not None:
            torch.jit.save(scripted_module, file_path)

        return scripted_module

    @property
    def hparams(self) -> Union[AttributeDict, str]:
        if not hasattr(self, '_hparams'):
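
The docstring above notes that you can override ``to_torchscript`` to use tracing or to return multiple modules as a dictionary. A minimal sketch of such an override, assuming a model with ``example_input_array`` set (the ``TracedModel`` name is hypothetical, not part of this commit):

    import torch
    from pytorch_lightning import LightningModule


    class TracedModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            # sample input consumed by torch.jit.trace below
            self.example_input_array = torch.rand(1, 64)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))

        def to_torchscript(self, file_path=None, **kwargs):
            # trace the forward pass instead of scripting it
            traced = torch.jit.trace(self.eval(), self.example_input_array)
            if file_path is not None:
                torch.jit.save(traced, file_path)
            # return a dictionary, as the docstring recommends for multiple modules
            return {"forward": traced}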
8 changes: 5 additions & 3 deletions tests/base/models.py
@@ -18,7 +18,7 @@


class Generator(nn.Module):
    def __init__(self, latent_dim: tuple, img_shape: tuple):
    def __init__(self, latent_dim: int, img_shape: tuple):
        super().__init__()
        self.img_shape = img_shape

@@ -64,10 +64,10 @@ def forward(self, img):
        return validity


class TestGAN(LightningModule):
class BasicGAN(LightningModule):
"""Implements a basic GAN for the purpose of illustrating multiple optimizers."""

def __init__(self, hidden_dim, learning_rate, b1, b2, **kwargs):
def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs):
super().__init__()
self.hidden_dim = hidden_dim
self.learning_rate = learning_rate
@@ -163,6 +163,7 @@ def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(10, 20, batch_first=True)
        self.linear_out = nn.Linear(in_features=20, out_features=5)
        self.example_input_array = torch.rand(2, 3, 10)

    def forward(self, x):
        seq, last = self.rnn(x)
@@ -189,6 +190,7 @@ def __init__(self):
        self.c_d1_bn = nn.BatchNorm1d(128)
        self.c_d1_drop = nn.Dropout(0.3)
        self.c_d2 = nn.Linear(in_features=128, out_features=10)
        self.example_input_array = torch.rand(2, 1, 28, 28)

    def forward(self, x):
        x = x.view(x.size(0), -1)
4 changes: 2 additions & 2 deletions tests/models/test_horovod.py
@@ -13,7 +13,7 @@
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
from tests.base.models import TestGAN
from tests.base.models import BasicGAN

try:
    from horovod.common.util import nccl_built
@@ -145,7 +145,7 @@ def validation_step(self, batch, *args, **kwargs):

@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_horovod_multi_optimizer(tmpdir):
    model = TestGAN(**EvalModelTemplate.get_default_hparams())
    model = BasicGAN(**EvalModelTemplate.get_default_hparams())

    # fit model
    trainer = Trainer(
88 changes: 88 additions & 0 deletions tests/models/test_torchscript.py
@@ -0,0 +1,88 @@
from distutils.version import LooseVersion

import pytest
import torch

from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.models import ParityModuleRNN, BasicGAN


@pytest.mark.parametrize("modelclass", [
    EvalModelTemplate,
    ParityModuleRNN,
    BasicGAN,
])
def test_torchscript_input_output(modelclass):
""" Test that scripted LightningModule forward works. """
model = modelclass()
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
model.eval()
model_output = model(model.example_input_array)
script_output = script(model.example_input_array)
assert torch.allclose(script_output, model_output)


@pytest.mark.parametrize("device", [
torch.device("cpu"),
torch.device("cuda", 0)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
def test_torchscript_device(device):
""" Test that scripted module is on the correct device. """
model = EvalModelTemplate().to(device)
script = model.to_torchscript()
assert next(script.parameters()).device == device
script_output = script(model.example_input_array.to(device))
assert script_output.device == device


def test_torchscript_retain_training_state():
""" Test that torchscript export does not alter the training mode of original model. """
model = EvalModelTemplate()
model.train(True)
script = model.to_torchscript()
assert model.training
assert not script.training
model.train(False)
_ = model.to_torchscript()
assert not model.training
assert not script.training


@pytest.mark.parametrize("modelclass", [
    EvalModelTemplate,
    ParityModuleRNN,
    BasicGAN,
])
def test_torchscript_properties(modelclass):
""" Test that scripted LightningModule has unnecessary methods removed. """
model = modelclass()
model.datamodule = TrialMNISTDataModule()
script = model.to_torchscript()
assert not hasattr(script, "datamodule")
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")

if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
# only on torch >= 1.4 do these unused methods get removed
assert not callable(getattr(script, "training_step", None))


@pytest.mark.parametrize("modelclass", [
    EvalModelTemplate,
    ParityModuleRNN,
    BasicGAN,
])
@pytest.mark.skipif(
    LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
    reason="torch.save/load has bug loading script modules on torch <= 1.4",
)
def test_torchscript_save_load(tmpdir, modelclass):
""" Test that scripted LightningModules is correctly saved and can be loaded. """
model = modelclass()
output_file = str(tmpdir / "model.pt")
script = model.to_torchscript(file_path=output_file)
loaded_script = torch.jit.load(output_file)
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
