This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Add torchscript support for hub detection models #51

Merged (5 commits) on Jun 4, 2020

Conversation

fmassa (Contributor) commented on Jun 4, 2020:

This PR adds support for torchscript export to the detection models in DETR torchhub.

It requires a recent enough PyTorch (today's nightly worked for me) and doesn't work with PyTorch 1.5.0.

@fmassa fmassa requested review from szagoruyko and alcinos June 4, 2020 10:27
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 4, 2020
xs = self.body(tensor_list.tensors)
out = OrderedDict()
out : Dict[str, NestedTensor] = {}
fmassa (Contributor, author) commented:

This can be a BC-breaking change for some versions of Python, as we are changing OrderedDict to a plain dict (which is only guaranteed to preserve insertion order from Python 3.7 onward).

If this is a problem, we can change it to be a list.

return out
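The change above can be sketched with a minimal module, using torch.Tensor in place of DETR's NestedTensor for illustration; the key point is that TorchScript cannot infer the value type of an empty dict literal, so the annotation is required:

```python
from typing import Dict

import torch
from torch import nn


class Body(nn.Module):
    """Minimal stand-in for the backbone body; torch.Tensor is used
    here in place of DETR's NestedTensor for illustration."""

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        # TorchScript cannot infer the type of an empty dict literal,
        # so the annotation on `out` is required. On Python 3.7+ a
        # plain dict preserves insertion order like the OrderedDict
        # it replaces.
        out: Dict[str, torch.Tensor] = {}
        out["0"] = x
        return out


scripted = torch.jit.script(Body())
```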

@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
fmassa (Contributor, author) commented:

aux_loss is not supported when exporting to torchscript, but this should be fine as we can always remove it for inference.
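The decorator's effect can be sketched with a hypothetical minimal model: `torch.jit.unused` tells the compiler to replace the method body with a raise, so the helper never has to be scriptable and scripting the module still succeeds:

```python
import torch
from torch import nn


class TinyDETR(nn.Module):
    """Hypothetical minimal model illustrating torch.jit.unused."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # Training-only helper: TorchScript replaces this body with a
        # raise, so the list/dict comprehension here is never compiled.
        return [{"pred_logits": a, "pred_boxes": b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


scripted = torch.jit.script(TinyDETR())  # succeeds despite the helper
```

Calling `_set_aux_loss` from a scripted forward path would raise at runtime; it remains callable only from eager (training) code.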

Reviewer comment:

Does aux_loss only ever have pred_logits and pred_boxes as keys? Could we maybe use a namedtuple?

Contributor reply:

This might vary depending on the task, but I guess we could leave some fields of the namedtuple as None?
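The namedtuple idea could be sketched as follows (a hypothetical `DetectionOutput` type, not part of the PR): TorchScript supports typed NamedTuples, and task-dependent fields can be declared `Optional` and filled with None:

```python
from typing import NamedTuple, Optional

import torch


class DetectionOutput(NamedTuple):
    # Hypothetical output type; field names follow the discussion above.
    pred_logits: torch.Tensor
    pred_boxes: torch.Tensor
    pred_masks: Optional[torch.Tensor]  # task-dependent, may stay None


@torch.jit.script
def make_output(logits: torch.Tensor, boxes: torch.Tensor) -> DetectionOutput:
    # Fields not produced by the current task are filled with None.
    return DetectionOutput(logits, boxes, None)
```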


def decompose(self):
return self.tensors, self.mask

@classmethod
fmassa (Contributor, author) commented:

classmethod doesn't seem to be supported, so I had to move this to a separate function (with a few extra changes to make it compatible with torchscript).
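The pattern can be sketched like this, with a simplified stand-in for DETR's NestedTensor (the real constructor pads tensors of different sizes; here we assume equally sized inputs): the constructor logic moves from a `@classmethod` into a free function, which TorchScript can compile:

```python
from typing import List

import torch


@torch.jit.script
class NestedTensor:
    # Simplified stand-in for DETR's NestedTensor.
    def __init__(self, tensors: torch.Tensor, mask: torch.Tensor):
        self.tensors = tensors
        self.mask = mask

    def decompose(self):
        return self.tensors, self.mask


def nested_tensor_from_tensor_list(tensor_list: List[torch.Tensor]) -> NestedTensor:
    # Constructor logic lives in a free function instead of a
    # @classmethod; simplified to assume equally sized tensors
    # (the real function pads to the largest spatial size).
    tensors = torch.stack(tensor_list)
    mask = torch.zeros(tensors.shape[0], tensors.shape[-2], tensors.shape[-1],
                       dtype=torch.bool)
    return NestedTensor(tensors, mask)


scripted = torch.jit.script(nested_tensor_from_tensor_list)
```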

@@ -64,15 +65,17 @@ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int,
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': 0}
return_layers = {'layer4': "0"}
fmassa (Contributor, author) commented:

This oversight took me a while to figure out, as the error messages were pretty cryptic.

Error message from torchscript:
Traceback (most recent call last):
  File "test_all.py", line 64, in test_model_script
    scripted_model = torch.jit.script(model)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1340, in script
    return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 313, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
    init_fn(script_module)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
    init_fn(script_module)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
    init_fn(script_module)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
    init_fn(script_module)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 333, in init_fn
    cpp_module.setattr(name, orig_value)
RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details)

cc @eellison if this error message could be improved it would be great.
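The failure mode can be reproduced with a minimal sketch (a simplified `Getter` module standing in for torchvision's IntermediateLayerGetter): with an int value such as `{'layer4': 0}`, the module attribute cannot be typed as `Dict[str, str]` and scripting fails with the opaque cast error above; with string values it scripts cleanly:

```python
from typing import Dict

import torch
from torch import nn


class Getter(nn.Module):
    # Simplified stand-in for torchvision's IntermediateLayerGetter.
    def __init__(self, return_layers: Dict[str, str]):
        super().__init__()
        # Passing {'layer4': 0} (an int value) here makes the attribute
        # untypeable as Dict[str, str] and triggers the cryptic
        # "Unable to cast Python instance to C++ type" error on scripting.
        self.return_layers = return_layers

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        out: Dict[str, torch.Tensor] = {}
        for name, out_name in self.return_layers.items():
            out[out_name] = x
        return out


scripted = torch.jit.script(Getter({"layer4": "0"}))
```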

alcinos (Contributor) left a review:

LGTM, thanks @fmassa!

4 participants