to_torchscript method for LightningModule #3258

Merged: 40 commits, Sep 3, 2020
Changes from 12 commits
Commits (40):
- da527e8 script (awaelchli, Aug 29, 2020)
- ea4b4e1 docs (awaelchli, Aug 29, 2020)
- 0ad4af0 simple test (awaelchli, Aug 29, 2020)
- 246e875 move test (awaelchli, Aug 29, 2020)
- cee5416 fix doctest (awaelchli, Aug 29, 2020)
- 192f10f no grad context (awaelchli, Aug 29, 2020)
- c12aea2 extend tests (awaelchli, Aug 29, 2020)
- d6f6437 datamodule test (awaelchli, Aug 29, 2020)
- 49d2166 clean up test (awaelchli, Aug 29, 2020)
- e167490 docs (awaelchli, Aug 29, 2020)
- 7223e98 name (awaelchli, Aug 29, 2020)
- 4f0dbbc fix import (awaelchli, Aug 29, 2020)
- e5da609 update changelog (awaelchli, Aug 29, 2020)
- 92d2c5a fix import (awaelchli, Aug 29, 2020)
- 26ad185 skip pytorch 1.3 in test (awaelchli, Aug 29, 2020)
- ace4f4f update codeblock (Aug 30, 2020)
- b22dfc6 skip bugged 1.4 (Aug 30, 2020)
- 81bca94 typehints (Aug 30, 2020)
- b883e97 doctest not working on all pytorch versions (Aug 30, 2020)
- b7be254 rename TestGAN to prevent pytest interference (Aug 30, 2020)
- 2af314b add note about pytorch version (Aug 30, 2020)
- 5d76d55 fix torchscript version inconsistency in tests (Aug 30, 2020)
- 4fd6cee reset training state + tests (Aug 30, 2020)
- 26c28f9 update docstring (awaelchli, Aug 30, 2020)
- 9366029 Apply suggestions from code review (Borda, Sep 1, 2020)
- 721cb5e update docstring, dict return (awaelchli, Sep 1, 2020)
- 46bbdab Merge remote-tracking branch 'PyTorchLightning/feature/torchscript' i… (awaelchli, Sep 1, 2020)
- da0f4f8 Merge branch 'master' into feature/torchscript (awaelchli, Sep 1, 2020)
- 3598f21 add docs to index (awaelchli, Sep 1, 2020)
- c1da6bd add link (awaelchli, Sep 1, 2020)
- 7d1124a doc eval mode (awaelchli, Sep 1, 2020)
- 7f180c0 forward (awaelchli, Sep 1, 2020)
- 9687a29 optional save to file path (awaelchli, Sep 3, 2020)
- 868f8b4 optional (awaelchli, Sep 3, 2020)
- 713e477 test torchscript device (awaelchli, Sep 3, 2020)
- c1fc408 test save load with file path (awaelchli, Sep 3, 2020)
- f328959 pep (awaelchli, Sep 3, 2020)
- 1edd4e4 str (awaelchli, Sep 3, 2020)
- 1263b57 Commit typing suggestion (justusschock, Sep 3, 2020)
- 78e2d60 skip test if cuda not available (awaelchli, Sep 3, 2020)
docs/source/production_inference.rst (30 additions, 0 deletions)
@@ -1,3 +1,17 @@
.. testsetup:: *

import torch
from pytorch_lightning.core.lightning import LightningModule

class SimpleModel(LightningModule):

def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)

def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))

.. _production-inference:

Inference in Production
@@ -28,3 +42,19 @@ Once you have the exported model, you can run it on your ONNX runtime in the following way:
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
ort_outs = ort_session.run(None, ort_inputs)
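
For context, the diff view truncates the surrounding docs snippet; a complete round trip with ONNX Runtime looks roughly like the following. This is a sketch, not part of the diff: it assumes the model was exported to "model.onnx" via to_onnx, and that the onnxruntime and numpy packages are installed.

import numpy as np
import onnxruntime

# create an inference session from the exported ONNX file (assumed path)
ort_session = onnxruntime.InferenceSession("model.onnx")

# feed a random batch under the session's declared input name
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}

# None selects all of the model's outputs
ort_outs = ort_session.run(None, ort_inputs)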


Exporting to TorchScript
------------------------

TorchScript allows you to serialize your models so that they can be loaded in non-Python environments.
The LightningModule has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`
that returns a scripted module which you can save or use directly.

.. testcode::

model = SimpleModel()
script = model.to_torchscript()

# save for use in production environment
torch.jit.save(script, "model.pt")
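
Once saved, the scripted module can be loaded back without any Lightning or model-class code on the consumer side. A minimal sketch, assuming the "model.pt" file produced above and the 64-feature inputs that SimpleModel expects:

import torch

# load the serialized ScriptModule; the original Python class is not needed
script = torch.jit.load("model.pt")

# run inference without tracking gradients
with torch.no_grad():
    prediction = script(torch.rand(2, 64))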
pytorch_lightning/core/lightning.py (29 additions, 0 deletions)
@@ -184,6 +184,7 @@ def forward(self, batch):
return logits

"""
return super().forward(*args, **kwargs)

def training_step(self, *args, **kwargs):
r"""
@@ -1729,6 +1730,34 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):

torch.onnx.export(self, input_data, file_path, **kwargs)

def to_torchscript(self) -> torch.jit.ScriptModule:
"""
Compiles the model to a :class:`~torch.jit.ScriptModule`.
This can be overridden to support custom TorchScript module export.

Note:
Requires the implementation of the :meth:`LightningModule.forward` method.

Example:
>>> class SimpleModel(LightningModule):
... def __init__(self):
... super().__init__()
... self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
... def forward(self, x):
... return torch.relu(self.l1(x.view(x.size(0), -1)))
...
>>> model = SimpleModel()
>>> torch.jit.save(model.to_torchscript(), "model.pt")
>>> os.path.isfile("model.pt")
True
"""
mode = self.training
with torch.no_grad():
scripted_module = torch.jit.script(self.eval())
self.training = mode
return scripted_module

@property
def hparams(self) -> Union[AttributeDict, str]:
if not hasattr(self, '_hparams'):
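The docstring notes that to_torchscript can be overridden to support custom export, and a trace-based override is one natural variant. A hedged sketch, not part of this PR, reusing the docs' SimpleModel layout and the example_input_array convention (also used by the tests below) as the trace input:

import torch
from pytorch_lightning.core.lightning import LightningModule

class TracedModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)
        self.example_input_array = torch.rand(2, 64)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def to_torchscript(self) -> torch.jit.ScriptModule:
        # trace-based export: records the ops of one forward pass
        # instead of compiling the module's source with torch.jit.script
        mode = self.training
        with torch.no_grad():
            traced = torch.jit.trace(self.eval(), self.example_input_array)
        self.training = mode  # restore the caller's training/eval state
        return traced
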
tests/base/models.py (4 additions, 2 deletions)
@@ -18,7 +18,7 @@


class Generator(nn.Module):
def __init__(self, latent_dim: tuple, img_shape: tuple):
def __init__(self, latent_dim: int, img_shape: tuple):
super().__init__()
self.img_shape = img_shape

@@ -67,7 +67,7 @@ def forward(self, img):
class TestGAN(LightningModule):
"""Implements a basic GAN for the purpose of illustrating multiple optimizers."""

def __init__(self, hidden_dim, learning_rate, b1, b2, **kwargs):
def __init__(self, hidden_dim=128, learning_rate=0.001, b1=0.5, b2=0.999, **kwargs):
super().__init__()
self.hidden_dim = hidden_dim
self.learning_rate = learning_rate
@@ -163,6 +163,7 @@ def __init__(self):
super().__init__()
self.rnn = nn.LSTM(10, 20, batch_first=True)
self.linear_out = nn.Linear(in_features=20, out_features=5)
self.example_input_array = torch.rand(2, 3, 10)

def forward(self, x):
seq, last = self.rnn(x)
@@ -189,6 +190,7 @@ def __init__(self):
self.c_d1_bn = nn.BatchNorm1d(128)
self.c_d1_drop = nn.Dropout(0.3)
self.c_d2 = nn.Linear(in_features=128, out_features=10)
self.example_input_array = torch.rand(2, 1, 28, 28)

def forward(self, x):
x = x.view(x.size(0), -1)
tests/models/test_torchscript.py (52 additions, 0 deletions)
@@ -0,0 +1,52 @@
import pytest
import torch

from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.models import ParityModuleRNN, TestGAN


@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
ParityModuleRNN,
TestGAN,
])
def test_torchscript_input_output(modelclass):
""" Test that scripted LightningModule forward works. """
model = modelclass()
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
model_output = model(model.example_input_array)
script_output = script(model.example_input_array)
assert torch.allclose(script_output, model_output)


@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
ParityModuleRNN,
TestGAN,
])
def test_torchscript_properties(modelclass):
""" Test that scripted LightningModule has unnecessary methods removed. """
model = modelclass()
model.datamodule = TrialMNISTDataModule()
script = model.to_torchscript()
assert not hasattr(script, "datamodule")
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")
assert not callable(getattr(script, "training_step", None))


@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
ParityModuleRNN,
TestGAN,
])
def test_torchscript_save_load(tmpdir, modelclass):
""" Test that scripted LightningModules can be saved and loaded. """
model = modelclass()
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
output_file = str(tmpdir / "model.pt")
torch.jit.save(script, output_file)
torch.jit.load(output_file)