Support a to_torchscript function on the LightningModule #3080

Closed · ananthsub opened this issue Aug 20, 2020 · 17 comments · Fixed by #3258

Assignees: awaelchli
Labels: design (Includes a design discussion), feature (Is an improvement or enhancement), help wanted (Open to be worked on), let's do it! (approved to implement)
Milestone: 1.0.0

Comments

ananthsub (Contributor) commented Aug 20, 2020

🚀 Feature

Support a conversion function to PyTorch JIT similar to what's available for ONNX.

Motivation

TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency. TorchScript is a method by which users can serve PyTorch models efficiently.

Pitch

By default, we can use TorchScript to script the LightningModule. Users can override this in their own LightningModules to use tracing instead, or to script specific nn.Modules inside their LightningModule. This could then be extended to other Lightning utilities such as model checkpointing, so that TorchScript- or ONNX-converted models are saved alongside the best model checkpoints, making the path to serving even easier.

    def to_torchscript(self):
        """Returns the model as a TorchScript (JIT) module.
        This can be overridden to support custom TorchScript module export.

        Example:
            >>> import os, tempfile
            >>> import torch
            >>> from pytorch_lightning import LightningModule
            >>> class SimpleModel(LightningModule):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            ...
            ...     def forward(self, x):
            ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
            >>> with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as tmpfile:
            ...     model = SimpleModel()
            ...     torch.jit.save(model.to_torchscript(), tmpfile.name)
            ...     os.path.isfile(tmpfile.name)
            True
        """
        return torch.jit.script(self.eval())
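
For context, a module saved this way can later be loaded and run without the original Python class definition (a minimal sketch; the file name is illustrative):

    import torch

    # Load the serialized TorchScript module; the SimpleModel class
    # definition is not needed at load time.
    loaded = torch.jit.load("model.pt")
    loaded.eval()

    # Run inference exactly as with the original module.
    out = loaded(torch.randn(2, 64))
    print(out.shape)  # torch.Size([2, 4])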
ananthsub added the feature (Is an improvement or enhancement) and help wanted (Open to be worked on) labels on Aug 20, 2020
Borda added the Important and let's do it! (approved to implement) labels on Aug 25, 2020
Borda (Member) commented Aug 25, 2020

I like this proposal and it should be quite straightforward. Mind sending a PR? 🐰
cc: @PyTorchLightning/core-contributors

Borda added this to the 1.0.0 milestone on Aug 25, 2020
awaelchli self-assigned this on Aug 25, 2020
ananthsub (Contributor, Author) commented:

@Borda regarding the API: what if someone wants to expose multiple TorchScript modules from their LightningModule? E.g. their LightningModule does GAN training, and they want to export the generator and discriminator separately. Should the interface allow multiple scripted modules to be returned?

awaelchli (Contributor) commented Aug 26, 2020

I think that would be super useful in these cases! How about something like

model.to_torchscript()  # exports whole module
model.to_torchscript(children=["discriminator", "generator"])  # returns the list of script modules

or should that just be part of the override the user can do, without any argument passed to the method?
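
For illustration, a minimal sketch of how the children argument could behave (hypothetical, not an agreed API):

    from typing import List, Optional

    import torch

    def to_torchscript(self, children: Optional[List[str]] = None):
        """Hypothetical: script the whole module, or only named submodules."""
        self.eval()
        if children is None:
            # Default: script the entire LightningModule.
            return torch.jit.script(self)
        # Script each named child attribute, e.g. self.generator, self.discriminator.
        return [torch.jit.script(getattr(self, name)) for name in children]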

williamFalcon (Contributor) commented Aug 26, 2020

> I think that would be super useful in these cases! How about something like
>
> model.to_torchscript()  # exports whole module
> model.to_torchscript(children=["discriminator", "generator"])  # returns the list of script modules
>
> or should that just be part of the override the user can do, without any argument passed to the method?

So the PL module would need attrs named discriminator and generator, no? I think this makes sense as long as we specify that.

@ananthsub what about transforms, or a data pipeline the user might also need when exporting to TorchScript/ONNX? How are you thinking about that?

Borda (Member) commented Aug 26, 2020

Yeah, the children parameter is good =)

Borda added the design (Includes a design discussion) label on Aug 26, 2020
ananthsub (Contributor, Author) commented Aug 26, 2020

> @ananthsub what about transforms, or a data pipeline the user might also need when exporting to TorchScript/ONNX? How are you thinking about that?

Yes, combining with transforms is an important need. In order to serve models, we will also extract the transforms from the data module and "glue" them into the forward function (assuming the transforms are also TorchScriptable).
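
For example, one way to "glue" a scriptable transform onto the model is a small wrapper module (a sketch; ServableModule is an illustrative name, and both pieces must themselves be TorchScriptable):

    import torch
    import torch.nn as nn

    class ServableModule(nn.Module):
        """Illustrative wrapper: applies a scriptable transform before the model."""

        def __init__(self, transform: nn.Module, model: nn.Module):
            super().__init__()
            self.transform = transform
            self.model = model

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.model(self.transform(x))

    # e.g. scripted = torch.jit.script(ServableModule(my_transform, my_model))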

> So the PL module would need attrs named discriminator and generator, no? I think this makes sense as long as we specify that.

I think it'd be easier to define these on the pl_module outside of this function. I'm hesitant to make this take arbitrary args and add magic inside of the base call here. What do you think about this proposal?

I think an API like this would be extensible. We are incubating this via a Lightning callback which people can extend.

    def to_torchscript(self) -> Dict[str, nn.Module]:
        # By default, scripts the whole LightningModule and wraps it in a
        # dictionary mapping a name to the scripted model. Override this method
        # if you want to customize which modules are scripted, or to use tracing.
        self.eval()
        with torch.no_grad():
            scripted_model = torch.jit.script(self)
        return {DEFAULT_MODEL_FILENAME: scripted_model}
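
Under this proposal, a GAN user could then override the method to export submodules separately, e.g. (a sketch assuming generator/discriminator attributes on the module):

    from typing import Dict

    import torch
    import torch.nn as nn
    from pytorch_lightning import LightningModule

    class GAN(LightningModule):
        # ... self.generator and self.discriminator defined in __init__ ...

        def to_torchscript(self) -> Dict[str, nn.Module]:
            # Export the two submodules as separate script modules
            # instead of scripting the whole LightningModule.
            self.eval()
            with torch.no_grad():
                return {
                    "generator.pt": torch.jit.script(self.generator),
                    "discriminator.pt": torch.jit.script(self.discriminator),
                }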

snisarg commented Aug 26, 2020

For consistency's sake, pointing out that to_onnx() takes a file path, though I would much prefer this proposed solution of returning Dict[str, nn.Module].

hudeven commented Aug 26, 2020

> Yes, combining with transforms is an important need. In order to serve models, we will also extract the transforms from the data module and "glue" them into the forward function (assuming the transforms are also TorchScriptable).

There are train_transforms, val_transforms, and test_transforms in the data module. Could we assume test_transforms is always the one to TorchScript with the model? Sometimes the transform used in training/testing differs from the one used in inference, so we might need additional logic to patch the transform during export.

How about adding to_torchscript() to the data module, so that the model doesn't need to know the data module's members and any special transform patching happens inside the data module?
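
A rough sketch of that idea (hypothetical; assumes test_transforms is itself an nn.Module and therefore scriptable):

    import torch
    from pytorch_lightning import LightningDataModule

    class MyDataModule(LightningDataModule):
        # ... self.test_transforms assigned elsewhere ...

        def to_torchscript(self) -> torch.jit.ScriptModule:
            # Any inference-specific patching of the transform would live
            # here, hidden from the model and the serving code.
            return torch.jit.script(self.test_transforms)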

Borda (Member) commented Aug 27, 2020

@hudeven I think the datamodule should not be part of the exported model, as we move to data-agnostic models...
thoughts @nateraw @williamFalcon

ananthsub (Contributor, Author) commented Aug 27, 2020

> @hudeven I think the datamodule should not be part of the exported model, as we move to data-agnostic models...
> thoughts @nateraw @williamFalcon

That makes sense. I think we can go with the simplest approach, relying on just the LightningModule.

It'd be great if we could also bake bundled inputs into the TorchScript export, as this would be broadly useful: https://github.com/pytorch/pytorch/blob/master/torch/utils/bundled_inputs.py#L17
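
For reference, a sketch of how bundled inputs could be attached to a scripted module (helper names from torch.utils.bundled_inputs; availability depends on the PyTorch version):

    import torch
    from torch.utils.bundled_inputs import (
        augment_model_with_bundled_inputs,
        bundle_randn,
    )

    scripted = torch.jit.script(model.eval())  # model: any scriptable nn.Module

    # Attach a lazily-inflated random example input to the script module itself,
    # so consumers can exercise the model without shipping a separate tensor.
    augment_model_with_bundled_inputs(scripted, [(bundle_randn(1, 64),)])

    # Consumers retrieve the bundled example via the generated method:
    (example,) = scripted.get_all_bundled_inputs()[0]
    out = scripted(example)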

awaelchli (Contributor) commented:

@ananthsub we also have the option to use torch.jit.trace. Should we support both in the same function, or better not?
I was thinking to_torchscript(trace=False)?

Borda (Member) commented Aug 29, 2020

> @ananthsub we also have the option to use torch.jit.trace. Should we support both in the same function, or better not?
> I was thinking to_torchscript(trace=False)?

I would prefer the parameter way...

ananthsub (Contributor, Author) commented:

Tracing also requires the user to provide an example input, so the API would become to_torchscript(trace=False, example_input=None)?

The PyTorch JIT team encourages users to use scripting over tracing because tracing has a number of limitations, which is why I didn't want to add a new dependency here. Given that this is new, maybe we could start without tracing and then add it as a feature request if many lightning users are interested. What do you think?
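
If tracing were added later, the two code paths would differ roughly like this (a sketch, not a settled API):

    import torch

    def to_torchscript(self, trace: bool = False, example_input: torch.Tensor = None):
        """Sketch of the dual-path API discussed above."""
        self.eval()
        with torch.no_grad():
            if trace:
                if example_input is None:
                    raise ValueError("Tracing requires an example_input.")
                # Tracing records the ops executed for this particular input;
                # data-dependent control flow is not captured.
                return torch.jit.trace(self, example_input)
            # Scripting compiles the module's code, including control flow.
            return torch.jit.script(self)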

awaelchli (Contributor) commented:

Yes, perfect. I was wondering the same, but script seems to be very versatile.

Borda (Member) commented Aug 29, 2020

> Tracing also requires the user to provide an example input, so the API would become to_torchscript(trace=False, example_input=None)?

Let's follow the same parameter order as we already have for ONNX export:
https://github.com/PyTorchLightning/pytorch-lightning/blob/d9ea25590e95ca9e70401123a0f1f59de711e2ff/pytorch_lightning/core/lightning.py#L1689
so that the params that differ from ONNX are placed last...

awaelchli (Contributor) commented Aug 30, 2020

I've now implemented and tested the basic version as proposed by @ananthsub.

@Borda an example input is not required for scripting a module, but we can add a file_path=None argument and, if provided, save to file directly.

@ananthsub what is the benefit of returning a dict instead of the script module directly? Do you plan to call this on submodules and collect them into a single dict, or let the user return multiple entries by overriding?

Furthermore, it seems bundled inputs are not quite there yet; I couldn't get them to work with PyTorch 1.7, so I'm not sure people even know about them :)

ananthsub (Contributor, Author) commented:

> @ananthsub what is the benefit of returning a dict instead of the script module directly? Do you plan to call this on submodules and collect them into a single dict, or let the user return multiple entries by overriding?

I initially proposed this to keep the output parsing simpler. But I agree the most common case would be returning the script module directly. Maybe the type signature could be -> Union[ScriptModule, Dict[str, ScriptModule]].

I suppose I'm jumping ahead with bundled inputs, then.
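
Putting the pieces discussed here together, the signature would look roughly like this (a sketch, not necessarily what gets merged):

    from typing import Dict, Optional, Union

    import torch
    from torch.jit import ScriptModule

    def to_torchscript(
        self, file_path: Optional[str] = None
    ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
        """Sketch: script the module and optionally save it if file_path is given."""
        self.eval()
        with torch.no_grad():
            scripted = torch.jit.script(self)
        if file_path is not None:
            torch.jit.save(scripted, file_path)
        return scripted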
