diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bc697f5e4008..a39ad11f69d59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528)) +- Added `LightningModule.to_torchscript` to support exporting as `ScriptModule` ([#3258](https://github.com/PyTorchLightning/pytorch-lightning/pull/3258/)) + ### Changed - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251)) diff --git a/docs/source/lightning-module.rst b/docs/source/lightning-module.rst index 6be28af55523d..98312ccc84820 100644 --- a/docs/source/lightning-module.rst +++ b/docs/source/lightning-module.rst @@ -770,6 +770,12 @@ to_onnx .. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_onnx :noindex: +to_torchscript +~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_torchscript + :noindex: + unfreeze ~~~~~~~~ diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst index d3ab26b93e419..f81ffe5c8b979 100644 --- a/docs/source/production_inference.rst +++ b/docs/source/production_inference.rst @@ -28,3 +28,21 @@ Once you have the exported model, you can run it on your ONNX runtime in the fol input_name = ort_session.get_inputs()[0].name ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)} ort_outs = ort_session.run(None, ort_inputs) + + +Exporting to TorchScript +------------------------ + +TorchScript allows you to serialize your models in a way that they can be loaded in non-Python environments. +The LightningModule has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` +that returns a scripted module which you can save or directly use. + +.. 
code-block:: python + + model = SimpleModel() + script = model.to_torchscript() + + # save for use in production environment + torch.jit.save(script, "model.pt") + +It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f7723e6945e23..9b41daa1cd470 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -23,7 +23,7 @@ import torch import torch.distributed as torch_distrib -from torch import Tensor +from torch import Tensor, ScriptModule from torch.nn import Module from torch.nn.parallel import DistributedDataParallel from torch.optim.optimizer import Optimizer @@ -184,6 +184,7 @@ def forward(self, batch): return logits """ + return super().forward(*args, **kwargs) def training_step(self, *args, **kwargs): r""" @@ -1729,6 +1730,57 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg torch.onnx.export(self, input_data, file_path, **kwargs) + def to_torchscript(self, file_path: Optional[str] = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]: + """ + By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. + If you would like to customize the modules that are scripted or you want to use tracing + you should override this method. In case you want to return multiple modules, we + recommend using a dictionary. + + Args: + file_path: Path where to save the torchscript. Default: None (no file saved). + **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` function. + + Note: + - Requires the implementation of the + :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. + - The exported script will be set to evaluation mode. + - It is recommended that you install the latest supported version of PyTorch + to use this feature without limitations. 
See also the :mod:`torch.jit` + documentation for supported features. + + Example: + >>> class SimpleModel(LightningModule): + ... def __init__(self): + ... super().__init__() + ... self.l1 = torch.nn.Linear(in_features=64, out_features=4) + ... + ... def forward(self, x): + ... return torch.relu(self.l1(x.view(x.size(0), -1))) + ... + >>> model = SimpleModel() + >>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP + >>> os.path.isfile("model.pt") # doctest: +SKIP + True + + Return: + This LightningModule as a torchscript, regardless of whether file_path is + defined or not. + """ + + mode = self.training + try: + with torch.no_grad(): + scripted_module = torch.jit.script(self.eval(), **kwargs) + finally: + # restore the original train/eval mode even if scripting raises + self.train(mode) + + if file_path is not None: + torch.jit.save(scripted_module, file_path) + + return scripted_module + @property def hparams(self) -> Union[AttributeDict, str]: if not hasattr(self, '_hparams'): diff --git a/tests/base/models.py b/tests/base/models.py index 7f295da59fd65..9c319add4aca1 100644 --- a/tests/base/models.py +++ b/tests/base/models.py @@ -18,7 +18,7 @@ class Generator(nn.Module): - def __init__(self, latent_dim: tuple, img_shape: tuple): + def __init__(self, latent_dim: int, img_shape: tuple): super().__init__() self.img_shape = img_shape @@ -64,10 +64,10 @@ def forward(self, img): return validity -class TestGAN(LightningModule): +class BasicGAN(LightningModule): """Implements a basic GAN for the purpose of illustrating multiple optimizers.""" - def __init__(self, hidden_dim, learning_rate, b1, b2, **kwargs): + def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs): super().__init__() self.hidden_dim = hidden_dim self.learning_rate = learning_rate @@ -163,6 +163,7 @@ def __init__(self): super().__init__() self.rnn = nn.LSTM(10, 20, batch_first=True) self.linear_out = nn.Linear(in_features=20, 
out_features=5) + self.example_input_array = torch.rand(2, 3, 10) def 
forward(self, x): seq, last = self.rnn(x) @@ -189,6 +190,7 @@ def __init__(self): self.c_d1_bn = nn.BatchNorm1d(128) self.c_d1_drop = nn.Dropout(0.3) self.c_d2 = nn.Linear(in_features=128, out_features=10) + self.example_input_array = torch.rand(2, 1, 28, 28) def forward(self, x): x = x.view(x.size(0), -1) diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index f48db196c104a..7c6dc3b7417c5 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -13,7 +13,7 @@ import tests.base.develop_utils as tutils from pytorch_lightning import Trainer from tests.base import EvalModelTemplate -from tests.base.models import TestGAN +from tests.base.models import BasicGAN try: from horovod.common.util import nccl_built @@ -145,7 +145,7 @@ def validation_step(self, batch, *args, **kwargs): @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") def test_horovod_multi_optimizer(tmpdir): - model = TestGAN(**EvalModelTemplate.get_default_hparams()) + model = BasicGAN(**EvalModelTemplate.get_default_hparams()) # fit model trainer = Trainer( diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py new file mode 100644 index 0000000000000..a57a931820c55 --- /dev/null +++ b/tests/models/test_torchscript.py @@ -0,0 +1,88 @@ +from distutils.version import LooseVersion + +import pytest +import torch + +from tests.base import EvalModelTemplate +from tests.base.datamodules import TrialMNISTDataModule +from tests.base.models import ParityModuleRNN, BasicGAN + + +@pytest.mark.parametrize("modelclass", [ + EvalModelTemplate, + ParityModuleRNN, + BasicGAN, +]) +def test_torchscript_input_output(modelclass): + """ Test that scripted LightningModule forward works. 
""" + model = modelclass() + script = model.to_torchscript() + assert isinstance(script, torch.jit.ScriptModule) + model.eval() + model_output = model(model.example_input_array) + script_output = script(model.example_input_array) + assert torch.allclose(script_output, model_output) + + +@pytest.mark.parametrize("device", [ + torch.device("cpu"), + pytest.param(torch.device("cuda", 0), marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="requires GPU machine")), +]) +def test_torchscript_device(device): + """ Test that scripted module is on the correct device. """ + model = EvalModelTemplate().to(device) + script = model.to_torchscript() + assert next(script.parameters()).device == device + script_output = script(model.example_input_array.to(device)) + assert script_output.device == device + + +def test_torchscript_retain_training_state(): + """ Test that torchscript export does not alter the training mode of original model. """ + model = EvalModelTemplate() + model.train(True) + script = model.to_torchscript() + assert model.training + assert not script.training + model.train(False) + script = model.to_torchscript() + assert not model.training + assert not script.training + + +@pytest.mark.parametrize("modelclass", [ + EvalModelTemplate, + ParityModuleRNN, + BasicGAN, +]) +def test_torchscript_properties(modelclass): + """ Test that scripted LightningModule has unnecessary methods removed. 
""" + model = modelclass() + model.datamodule = TrialMNISTDataModule() + script = model.to_torchscript() + assert not hasattr(script, "datamodule") + assert not hasattr(model, "batch_size") or hasattr(script, "batch_size") + assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate") + + if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): + # only on torch >= 1.4 do these unused methods get removed + assert not callable(getattr(script, "training_step", None)) + + +@pytest.mark.parametrize("modelclass", [ + EvalModelTemplate, + ParityModuleRNN, + BasicGAN, +]) +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.5.0"), + reason="torch.save/load has bug loading script modules on torch <= 1.4", +) +def test_torchscript_save_load(tmpdir, modelclass): + """ Test that scripted LightningModules are correctly saved and can be loaded. """ + model = modelclass() + output_file = str(tmpdir / "model.pt") + script = model.to_torchscript(file_path=output_file) + loaded_script = torch.jit.load(output_file) + assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))