
TorchScript export incompatible with DeviceDtypeModuleMixin properties #13887

Open
mattm458 opened this issue Jul 27, 2022 · 2 comments
Labels
bug Something isn't working lightningmodule pl.LightningModule priority: 2 Low priority task
Milestone

Comments

@mattm458

mattm458 commented Jul 27, 2022

🐛 Bug

My model code makes extensive use of tensors created during the forward pass. As part of this, I have to specify the device associated with each tensor to avoid conflicts. I want to export my model in TorchScript for deployment elsewhere, but the DeviceDtypeModuleMixin properties do not appear to be compatible with TorchScript.

This appears to affect the dtype property as well, but in my case it is using self.device in a LightningModule that prevents a successful TorchScript export.

To Reproduce

import pytorch_lightning as pl
import torch


class TestModel(pl.LightningModule):
    def forward(self):
        return torch.zeros((2, 2), device=self.device)


model = TestModel()
model.to_torchscript()

results in the following error:

RuntimeError: 
Module 'TestModel' has no attribute 'device' :
  File "<snip>/test.py", line 7
    def forward(self):
        return torch.zeros((2, 2), device=self.device)
                                          ~~~~~~~~~~~ <--- HERE

I am not sure how this can be fixed without changing the way DeviceDtypeModuleMixin works. It relies on the @property decorator, which is implemented via descriptors, and TorchScript is not compatible with descriptors (Python language reference §3.3.2.2, §3.3.2.3).
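For context, a minimal pure-Python illustration of why this is awkward (the class name is illustrative, not Lightning code): @property is a class-level descriptor, so the attribute never appears on the instance, which is what attribute scanning finds:

```python
class M:
    @property
    def device(self):
        return "cpu"


m = M()
# `device` lives on the class as a descriptor, not in the instance dict,
# so scanning instance attributes will not discover it
print("device" in vars(m))
print(isinstance(type(m).__dict__["device"], property))
```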

Additionally, the device property returns a Union[str, torch.device], so even if it worked as a property in TorchScript, the return value of the function is incompatible with the type expected by the device keyword option (Optional[Device]). There appears to have been some discussion of this in this issue, but I'm not sure how it was resolved.

I am happy to try contributing a solution, but I wanted to discuss it here first. A possible fix would be to eliminate the @property decorator and store the current device as a plain attribute on the model object. The work currently done in the device() getter body would then be performed when the device attribute is set, rather than on the way out.
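A rough sketch of that idea (class and method names are illustrative, not the actual DeviceDtypeModuleMixin code): keep device as a plain instance attribute, which TorchScript can capture, and do the bookkeeping when the module is moved rather than in a getter:

```python
class DeviceAttributeSketch:
    """Hypothetical mixin replacement: `device` as a plain attribute."""

    def __init__(self):
        # a plain attribute is visible to an instance-attribute scan,
        # unlike a @property descriptor, which lives on the class
        self.device = "cpu"

    def to(self, device):
        # update the stored device on the way *in*, when the module moves,
        # instead of computing it on the way out in a property getter
        self.device = str(device)
        return self
```

model.to("cuda:0") would then leave self.device set to an ordinary attribute that scripting can see.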

Expected behavior

The code shown above is successfully saved as TorchScript.

Environment

  • Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): LightningModule
  • PyTorch Lightning Version (e.g., 1.5.0): 1.6.5
  • PyTorch Version (e.g., 1.10): 1.12.0
  • Python version (e.g., 3.9): 3.9.13
  • OS (e.g., Linux): Linux
  • CUDA/cuDNN version: 11.6
  • How you installed PyTorch (conda, pip, source): pip

cc @carmocca @justusschock @awaelchli @Borda @ananthsub @ninginthecloud @jjenniferdai @rohitgr7

@mattm458 mattm458 added the needs triage Waiting to be triaged by maintainers label Jul 27, 2022
@carmocca carmocca added bug Something isn't working lightningmodule pl.LightningModule labels Jul 27, 2022
@carmocca carmocca added this to the pl:1.6.x milestone Jul 27, 2022
@carmocca carmocca removed the needs triage Waiting to be triaged by maintainers label Jul 27, 2022
@justusschock
Member

justusschock commented Jul 28, 2022

Hi,
I looked into it and first stumbled upon pytorch/pytorch#37883, which claims that properties are indeed supported. I could also find the related code (here), so I am not sure why it currently fails. My guess is that the linked code would need to be added to ScriptModule as well.

While looking for workarounds, I also found the possibility of adding it as a raw attribute, but we would likely have to guard it with some logic in __getattr__ and __setattr__ (probably 1:1 the logic we have in the property methods right now).
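That raw-attribute idea could look roughly like this (a sketch with illustrative names, not the real mixin; the str() call stands in for whatever validation the current property setter performs):

```python
class GuardedDeviceSketch:
    """Hypothetical: a raw `device` attribute guarded in __setattr__."""

    def __init__(self):
        # bypass the guard for the initial assignment
        object.__setattr__(self, "device", "cpu")

    def __setattr__(self, name, value):
        if name == "device":
            # normalize/validate here, mirroring the property setter's logic
            value = str(value)
        object.__setattr__(self, name, value)
```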

As a workaround

import pytorch_lightning as pl
import torch


class TestModel(pl.LightningModule):
    def forward(self):
        return torch.zeros((2, 2), device=self.device)


model = TestModel()
model.to_torchscript(method="trace", example_inputs=())

does work for me.

If you want to take a stab at this, it would be very welcome :)

@carmocca
Contributor

TorchScript is no longer actively developed in PyTorch (per pytorch/pytorch#67146 (comment)), so this will probably stay open indefinitely unless someone contributes the patch.

@carmocca carmocca added the priority: 2 Low priority task label Oct 13, 2022
@Borda Borda modified the milestones: v1.8.x, v1.9 Jan 6, 2023
@Borda Borda modified the milestones: v1.9, v1.9.x Jan 16, 2023