-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Add torchscript support for hub detection models #51
Conversation
models/backbone.py
Outdated
xs = self.body(tensor_list.tensors) | ||
out = OrderedDict() | ||
out : Dict[str, NestedTensor] = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be a BC-breaking change for some versions of Python, as we are changing OrderedDict
to dict.
If this is a problem, we can change it to be a list.
return out | ||
|
||
@torch.jit.unused | ||
def _set_aux_loss(self, outputs_class, outputs_coord): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aux_loss
is not supported for exporting to torchscript, but this should be fine as we can always remove it for inference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does aux_loss
only ever have pred_logits
and pred_boxes
as keys ? Could use a namedtuple maybe ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might vary depending on the task. But I guess we could leave some fields of the namedtuple as None
?
|
||
def decompose(self): | ||
return self.tensors, self.mask | ||
|
||
@classmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
classmethod
doesn't seem to be supported, so had to move this as a separate function (with a few extra changes to make it compatible with torchscript)
@@ -64,15 +65,17 @@ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, | |||
if return_interm_layers: | |||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} | |||
else: | |||
return_layers = {'layer4': 0} | |||
return_layers = {'layer4': "0"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This oversight took me a while to figure out, as the error messages were pretty cryptic
Error message from torchscript
Traceback (most recent call last):
File "test_all.py", line 64, in test_model_script
scripted_model = torch.jit.script(model)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1340, in script
return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 313, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
init_fn(script_module)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
init_fn(script_module)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
init_fn(script_module)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
init_fn(script_module)
File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 333, in init_fn
cpp_module.setattr(name, orig_value)
RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details)
cc @eellison if this error message could be improved it would be great.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @fmassa!
This PR adds support for torchscript export to the detection models in DETR torchhub.
It requires a recent-enough PyTorch (nightly from today worked for me), and doesn't work with PyTorch 1.5.0.