-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Add TorchScript compatibility for LightningModules #1952
Conversation
I didn't test this enough; I think this isn't complete yet. |
Codecov Report
@@ Coverage Diff @@
## master #1952 +/- ##
======================================
- Coverage 88% 88% -0%
======================================
Files 74 74
Lines 4650 4652 +2
======================================
+ Hits 4070 4071 +1
- Misses 580 581 +1 |
_device: ... | ||
_dtype: Union[str, torch.dtype] | ||
_device: torch.device | ||
_dtype: torch.dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why shouldn't I be omit strings here?
|
||
@property | ||
def dtype(self) -> Union[str, torch.dtype]: | ||
def dtype(self) -> torch.dtype: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same for strings
@@ -17,7 +17,7 @@ def dtype(self, new_dtype: Union[str, torch.dtype]): | |||
raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).') | |||
|
|||
@property | |||
def device(self) -> Union[str, torch.device]: | |||
def device(self) -> torch.device: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again strings :D
@@ -125,6 +125,8 @@ def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module: | |||
Returns: | |||
Module: self | |||
""" | |||
if isinstance(dst_type, str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure this is working for all the types. Even if it is. Can't we maybe find a better way to do this? For devices, you can simply do device = torch.device(device_str)
. I know, this does not work for dtypes, but maybe something similar? But in general: why doesn't strings work with torchscript?
It would be good to have a test that torch script can compile the template we use in tests, because the device properties were added only recently and it seems that it broke the compatibility. |
@neighthan how is it going here? 🐰 |
Could you add the help-wanted label here? The project where I was hoping to use TorchScript didn't benefit from it, so writing the tests that should be in this PR / checking the issues with |
@awaelchli @Borda I'd like to work on this issue. Would require a little help on what would be a good approach to fix this though :] |
@lezwon It's basically adding the missing tests (like you did with onnx) and then checking the dtype-str issue. The rest should be fine from what I saw :) |
@justusschock cool :) I'll get started with it then. 👍 |
Before submitting
What does this PR do?
Fixes #1951.