-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
to_torchscript method for LightningModule (#3258)
* script * docs * simple test * move test * fix doctest * no grad context * extend tests test test * datamodule test * clean up test * docs * name * fix import * update changelog * fix import * skip pytorch 1.3 in test * update codeblock * skip bugged 1.4 * typehints * doctest not working on all pytorch versions * rename TestGAN to prevent pytest interference * add note about pytorch version * fix torchscript version inconsistency in tests * reset training state + tests * update docstring * Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * update docstring, dict return * add docs to index * add link * doc eval mode * forward * optional save to file path * optional * test torchscript device * test save load with file path * pep * str * Commit typing suggestion Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> * skip test if cuda not available Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
- Loading branch information
1 parent
4a22fca
commit 4ad5a78
Showing
7 changed files
with
171 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from distutils.version import LooseVersion | ||
|
||
import pytest | ||
import torch | ||
|
||
from tests.base import EvalModelTemplate | ||
from tests.base.datamodules import TrialMNISTDataModule | ||
from tests.base.models import ParityModuleRNN, BasicGAN | ||
|
||
|
||
@pytest.mark.parametrize("modelclass", [
    EvalModelTemplate,
    ParityModuleRNN,
    BasicGAN,
])
def test_torchscript_input_output(modelclass):
    """ Test that scripted LightningModule forward works. """
    model = modelclass()

    # Export first, then compare the eager and scripted forward passes
    # on the module's own example input.
    scripted = model.to_torchscript()
    assert isinstance(scripted, torch.jit.ScriptModule)

    model.eval()
    example = model.example_input_array
    eager_out = model(example)
    scripted_out = scripted(example)
    assert torch.allclose(scripted_out, eager_out)
|
||
|
||
@pytest.mark.parametrize("device", [
    torch.device("cpu"),
    torch.device("cuda", 0)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
def test_torchscript_device(device):
    """ Test that scripted module is on the correct device. """
    module = EvalModelTemplate().to(device)
    scripted = module.to_torchscript()

    # The exported module's parameters must live on the same device as
    # the original module, and its outputs must come back on that device.
    first_param = next(scripted.parameters())
    assert first_param.device == device

    out = scripted(module.example_input_array.to(device))
    assert out.device == device
|
||
|
||
def test_torchscript_retain_training_state():
    """ Test that torchscript export does not alter the training mode of original model.

    Also verifies that every exported script is in eval mode, regardless of
    the training mode of the source module at export time.
    """
    model = EvalModelTemplate()

    # Export from training mode: the original must stay in training mode,
    # while the exported script defaults to eval mode.
    model.train(True)
    script = model.to_torchscript()
    assert model.training
    assert not script.training

    # Export from eval mode: the original must stay in eval mode.
    # Fix: the original test discarded this second export (`_ = ...`) and
    # re-asserted the *first* script, leaving the eval-mode export path
    # unchecked — capture it and verify its training flag too.
    model.train(False)
    script_from_eval = model.to_torchscript()
    assert not model.training
    assert not script.training
    assert not script_from_eval.training
|
||
|
||
@pytest.mark.parametrize("modelclass", [
    EvalModelTemplate,
    ParityModuleRNN,
    BasicGAN,
])
def test_torchscript_properties(modelclass):
    """ Test that scripted LightningModule has unnecessary methods removed. """
    model = modelclass()
    model.datamodule = TrialMNISTDataModule()
    script = model.to_torchscript()

    # The attached datamodule must never survive scripting.
    assert not hasattr(script, "datamodule")

    # Hyperparameter attributes present on the model should carry over
    # to the exported script.
    if hasattr(model, "batch_size"):
        assert hasattr(script, "batch_size")
    if hasattr(model, "learning_rate"):
        assert hasattr(script, "learning_rate")

    if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        # only on torch >= 1.4 do these unused methods get removed
        assert not callable(getattr(script, "training_step", None))
|
||
|
||
@pytest.mark.parametrize("modelclass", [
    EvalModelTemplate,
    ParityModuleRNN,
    BasicGAN,
])
@pytest.mark.skipif(
    LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
    reason="torch.save/load has bug loading script modules on torch <= 1.4",
)
def test_torchscript_save_load(tmpdir, modelclass):
    """ Test that scripted LightningModules is correctly saved and can be loaded. """
    model = modelclass()
    target_path = str(tmpdir / "model.pt")

    # Passing file_path both returns the script and writes it to disk.
    exported = model.to_torchscript(file_path=target_path)
    reloaded = torch.jit.load(target_path)

    # Round-trip check: the first parameter tensor of the in-memory
    # script matches the one reloaded from disk.
    assert torch.allclose(next(exported.parameters()), next(reloaded.parameters()))