Support torch.jit.script on LightningModules #1951

Closed
neighthan opened this issue May 26, 2020 · 7 comments
Labels
feature (Is an improvement or enhancement), help wanted (Open to be worked on), won't fix (This will not be worked on)

Comments

@neighthan
Contributor

🚀 Feature

There are a number of advantages to converting a model with TorchScript (e.g. static optimizations, better saving/loading, especially into non-Python environments for deployment). However, LightningModules currently cannot be converted using torch.jit.script. Here's a simple example with the error produced (note that this works as-is if we inherit from nn.Module instead of pl.LightningModule):

import pytorch_lightning as pl
import torch

# class Model(nn.Module): # works fine
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 10)
    
    def forward(self, x):
        return self.layer(x)

torch.jit.script(Model())
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-70-1fe19c1470da> in <module>
     10         return self.layer(x)
     11 
---> 12 torch.jit.script(Model())

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/__init__.py in script(obj, optimize, _frames_up, _rcb)
   1259 
   1260     if isinstance(obj, torch.nn.Module):
-> 1261         return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
   1262 
   1263     qualified_name = _qualified_name(obj)

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    295     if share_types:
    296         # Look into the store of cached JIT types
--> 297         concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
    298     else:
    299         # Get a concrete type directly, without trying to re-use an existing JIT

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in get_or_create_concrete_type(self, nn_module)
    254             return nn_module._concrete_type
    255 
--> 256         concrete_type_builder = infer_concrete_type_builder(nn_module)
    257 
    258         nn_module_type = type(nn_module)

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/jit/_recursive.py in infer_concrete_type_builder(nn_module)
    133     # Constants annotated via `Final[T]` rather than being added to `__constants__`
    134     for name, ann in class_annotations.items():
--> 135         if torch._jit_internal.is_final(ann):
    136             constants_set.add(name)
    137 

~/anaconda/envs/vsrl/lib/python3.7/site-packages/torch/_jit_internal.py in is_final(ann)
    681 
    682     def is_final(ann):
--> 683         return ann.__module__ == 'typing_extensions' and \
    684             (getattr(ann, '__origin__', None) is typing_extensions.Final)
    685 except ImportError:

AttributeError: 'ellipsis' object has no attribute '__module__'

Digging into this a little, we have

print(Model.__annotations__)
# {'_device': Ellipsis, '_dtype': typing.Union[str, torch.dtype]}

and the _device annotation comes from DeviceDtypeModuleMixin (one of the superclasses of LightningModule). Here's the relevant snippet:

class DeviceDtypeModuleMixin(torch.nn.Module):
    _device: ...
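
To reproduce the failure in isolation (a minimal sketch, independent of Lightning): ... is the Ellipsis singleton, not a typing object, so it has no __module__ attribute, which is exactly what torch._jit_internal.is_final trips over:

ann = ...  # what TorchScript sees as the annotation for _device
print(hasattr(ann, "__module__"))  # False
ann.__module__  # AttributeError: 'ellipsis' object has no attribute '__module__'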

This seems to be the only issue because this code works:

import pytorch_lightning as pl
import torch

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 10)
    
    def forward(self, x):
        return self.layer(x)

# Model.__annotations__ = {} # this works too but doesn't seem as nice
Model.__annotations__["_device"] = torch.device
torch.jit.script(Model())
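
With the annotation patched as above, scripting goes through and the result behaves like any other ScriptModule; as a quick check (the file name here is just an example):

scripted = torch.jit.script(Model())
scripted.save("model.pt")  # example path
loaded = torch.jit.load("model.pt")
print(loaded(torch.randn(2, 5)).shape)  # torch.Size([2, 10])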

However, if I try to set the annotation to typing.Union[str, torch.device] (which seems to be the true type based on this line), then torch.jit.script fails with ValueError: Unknown type annotation: 'typing.Union[str, torch.device]'.

Is the str type for _device actually used? I don't see that anywhere, and I actually do see at least one place where there would be an error if self.device returned a string (here). I'll just go ahead and submit a PR to update the annotations, but feel free to comment here or on the PR if there's something I'm missing about the type annotations.

@neighthan added the feature and help wanted labels on May 26, 2020
@neighthan
Contributor Author

The _dtype annotation from the same class also appears to be an issue.

@williamFalcon
Contributor

I thought we made it compatible in September?

@neighthan
Contributor Author

neighthan commented May 26, 2020

Does the code block I shared work for you? Are there any tests that actually try to run torch.jit.script on a LightningModule? Let me check if I'm on the bleeding-edge version.

@neighthan
Contributor Author

Okay, I'm

  • using Python 3.7.7
  • with PyTorch 1.5.0
  • and I just installed PyTorch Lightning from the latest commit (pip install git+https://github.com/PyTorchLightning/pytorch-lightning)

and I still see this issue. Perhaps the annotations were added after September and there weren't any tests that caught it?

@neighthan
Contributor Author

_dtype actually can be a string, so the current type annotation is correct but still a problem for TorchScript. I think the only place a conversion is needed is here, and we should be able to do that with something like

if isinstance(dst_type, str):
    dst_type = getattr(torch, dst_type)
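
As a quick sanity check of that mapping, torch exposes each dtype as a module attribute under its canonical string name, so the lookup round-trips:

import torch

# every public dtype name resolves back to a torch.dtype object
for name in ("float32", "float16", "int64"):
    assert isinstance(getattr(torch, name), torch.dtype)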

@lezwon lezwon mentioned this issue Jul 15, 2020
@lezwon
Contributor

lezwon commented Jul 22, 2020

I think this issue is fixed after #2657. I am able to convert the model to TorchScript format now.

@stale

stale bot commented Jul 25, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
