This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Add torchscript support for hub detection models #51

Merged: 5 commits, Jun 4, 2020
3 changes: 2 additions & 1 deletion .circleci/config.yml
@@ -18,7 +18,8 @@ jobs:
       - checkout
       - run:
           command: |
-            pip install --user --progress-bar off torch torchvision scipy pytest
+            pip install --user --progress-bar off scipy pytest
+            pip install --user --progress-bar off --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
             pytest .

workflows:
15 changes: 9 additions & 6 deletions models/backbone.py
@@ -9,6 +9,7 @@
 import torchvision
 from torch import nn
 from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
 
 from util.misc import NestedTensor

@@ -64,15 +65,17 @@ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int,
         if return_interm_layers:
             return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
         else:
-            return_layers = {'layer4': 0}
+            return_layers = {'layer4': "0"}
Contributor Author:
This oversight took me a while to figure out, as the error messages were pretty cryptic

Error message from TorchScript:
Traceback (most recent call last):
  File "test_all.py", line 64, in test_model_script
    scripted_model = torch.jit.script(model)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1340, in script
    return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 313, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
    init_fn(script_module)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
    init_fn(script_module)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
    init_fn(script_module)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 348, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 367, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1713, in _construct
    init_fn(script_module)
  File "/Users/fmassa/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 333, in init_fn
    cpp_module.setattr(name, orig_value)
RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details)

cc @eellison: it would be great if this error message could be improved.
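A minimal repro sketch of this failure mode (illustrative, not part of the diff): TorchScript type-checks module attributes against their annotations when it builds the C++ module, so an int value stored in an attribute annotated Dict[str, str] fails at cpp_module.setattr with exactly this cast error.

from typing import Dict

import torch
from torch import nn


class Holder(nn.Module):
    return_layers: Dict[str, str]  # TorchScript picks up this annotation

    def __init__(self, return_layers):
        super().__init__()
        self.return_layers = return_layers

    def forward(self) -> Dict[str, str]:
        return self.return_layers


torch.jit.script(Holder({'layer4': "0"}))   # scripts fine: values are str
# torch.jit.script(Holder({'layer4': 0}))   # fails: the int value contradicts Dict[str, str]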

         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
         self.num_channels = num_channels
 
-    def forward(self, tensor_list):
+    def forward(self, tensor_list: NestedTensor):
         xs = self.body(tensor_list.tensors)
-        out = OrderedDict()
+        out: Dict[str, NestedTensor] = {}
         for name, x in xs.items():
-            mask = F.interpolate(tensor_list.mask[None].float(), size=x.shape[-2:]).bool()[0]
+            m = tensor_list.mask
+            assert m is not None
+            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
             out[name] = NestedTensor(x, mask)
         return out
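Both edits above follow from TorchScript defaults: a bare {} is inferred as Dict[str, Tensor] unless annotated otherwise, and an Optional[Tensor] must be narrowed with an assert before it can be used as a plain Tensor. A tiny sketch of the narrowing rule (illustrative only):

from typing import Optional

import torch


@torch.jit.script
def invert_mask(mask: Optional[torch.Tensor]) -> torch.Tensor:
    # without the assert, the compiler rejects `~mask` since mask might be None
    assert mask is not None
    return ~mask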

@@ -94,9 +97,9 @@ class Joiner(nn.Sequential):
     def __init__(self, backbone, position_embedding):
         super().__init__(backbone, position_embedding)
 
-    def forward(self, tensor_list):
+    def forward(self, tensor_list: NestedTensor):
         xs = self[0](tensor_list)
-        out = []
+        out: List[NestedTensor] = []
         pos = []
         for name, x in xs.items():
             out.append(x)
19 changes: 14 additions & 5 deletions models/detr.py
@@ -7,7 +7,8 @@
 from torch import nn
 
 from util import box_ops
-from util.misc import (NestedTensor, accuracy, get_world_size, interpolate,
+from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+                       accuracy, get_world_size, interpolate,
                        is_dist_avail_and_initialized)
 
 from .backbone import build_backbone
@@ -56,20 +57,28 @@ def forward(self, samples: NestedTensor):
                             dictionnaries containing the two above keys for each decoder layer.
         """
         if not isinstance(samples, NestedTensor):
-            samples = NestedTensor.from_tensor_list(samples)
+            samples = nested_tensor_from_tensor_list(samples)
         features, pos = self.backbone(samples)
 
         src, mask = features[-1].decompose()
+        assert mask is not None
         hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
 
         outputs_class = self.class_embed(hs)
         outputs_coord = self.bbox_embed(hs).sigmoid()
         out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
         if self.aux_loss:
-            out['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
-                                  for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
         return out
 
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_coord):
Contributor Author:
aux_loss is not supported for exporting to torchscript, but this should be fine as we can always remove it for inference.

Reviewer:
Does aux_loss only ever have pred_logits and pred_boxes as keys? Could a namedtuple be used, maybe?

Contributor:
This might vary depending on the task. But I guess we could leave some fields of the namedtuple as None?

+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [{'pred_logits': a, 'pred_boxes': b}
+                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
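To illustrate the @torch.jit.unused pattern used here (a hypothetical module, not from this diff): the decorated method is compiled as a stub that raises if called, so eager-only code whose types TorchScript cannot express no longer blocks scripting the rest of the module.

import torch
from torch import nn


class MyModel(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * 2

    @torch.jit.unused
    def eager_only(self, x):
        # a dict mixing a Tensor and a list: fine in eager mode, but not
        # expressible as a single TorchScript dict type
        return {'a': x, 'b': [x]}


scripted = torch.jit.script(MyModel())  # succeeds because eager_only is skipped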


class SetCriterion(nn.Module):
""" This class computes the loss for DETR.
@@ -164,7 +173,7 @@ def loss_masks(self, outputs, targets, indices, num_boxes):
         src_masks = outputs["pred_masks"]
 
         # TODO use valid to mask invalid areas due to padding in loss
-        target_masks, valid = NestedTensor.from_tensor_list([t["masks"] for t in targets]).decompose()
+        target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
         target_masks = target_masks.to(src_masks)
 
         src_masks = src_masks[src_idx]
7 changes: 5 additions & 2 deletions models/position_encoding.py
@@ -6,6 +6,8 @@
 import torch
 from torch import nn
 
+from util.misc import NestedTensor
+

class PositionEmbeddingSine(nn.Module):
"""
@@ -23,9 +25,10 @@ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=N
             scale = 2 * math.pi
         self.scale = scale
 
-    def forward(self, tensor_list):
+    def forward(self, tensor_list: NestedTensor):
         x = tensor_list.tensors
         mask = tensor_list.mask
+        assert mask is not None
         not_mask = ~mask
         y_embed = not_mask.cumsum(1, dtype=torch.float32)
         x_embed = not_mask.cumsum(2, dtype=torch.float32)
@@ -59,7 +62,7 @@ def reset_parameters(self):
         nn.init.uniform_(self.row_embed.weight)
         nn.init.uniform_(self.col_embed.weight)
 
-    def forward(self, tensor_list):
+    def forward(self, tensor_list: NestedTensor):
         x = tensor_list.tensors
         h, w = x.shape[-2:]
         i = torch.arange(w, device=x.device)
4 changes: 2 additions & 2 deletions models/segmentation.py
@@ -11,7 +11,7 @@
 from PIL import Image
 
 import util.box_ops as box_ops
-from util.misc import NestedTensor, interpolate
+from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
 
 try:
     from panopticapi.utils import id2rgb, rgb2id
@@ -34,7 +34,7 @@ def __init__(self, detr, freeze_detr=False):

     def forward(self, samples: NestedTensor):
         if not isinstance(samples, NestedTensor):
-            samples = NestedTensor.from_tensor_list(samples)
+            samples = nested_tensor_from_tensor_list(samples)
         features, pos = self.detr.backbone(samples)
 
         bs = features[-1].tensors.shape[0]
21 changes: 21 additions & 0 deletions test_all.py
@@ -4,7 +4,11 @@
 import torch
 
 from models.matcher import HungarianMatcher
+from models.position_encoding import PositionEmbeddingSine, PositionEmbeddingLearned
+from models.backbone import Backbone, Joiner, BackboneBase
 from util import box_ops
+from util.misc import nested_tensor_from_tensor_list
+from hubconf import detr_resnet50
 
 
 class Tester(unittest.TestCase):
@@ -47,6 +51,23 @@ def test_hungarian(self):
                            'pred_boxes': boxes.repeat(2, 1, 1)}, targets_empty * 2)
         self.assertEqual(len(indices[0][0]), 0)
 
+    def test_position_encoding_script(self):
+        m1, m2 = PositionEmbeddingSine(), PositionEmbeddingLearned()
+        mm1, mm2 = torch.jit.script(m1), torch.jit.script(m2)  # noqa
+
+    def test_backbone_script(self):
+        backbone = Backbone('resnet50', True, False, False)
+        torch.jit.script(backbone)  # noqa
+
+    def test_model_script(self):
+        model = detr_resnet50(pretrained=False).eval()
+        scripted_model = torch.jit.script(model)
+        x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])
+        out = model(x)
+        out_script = scripted_model(x)
+        self.assertTrue(out["pred_logits"].equal(out_script["pred_logits"]))
+        self.assertTrue(out["pred_boxes"].equal(out_script["pred_boxes"]))


 if __name__ == '__main__':
     unittest.main()
69 changes: 42 additions & 27 deletions util/misc.py
@@ -267,45 +267,60 @@ def _run(command):

 def collate_fn(batch):
     batch = list(zip(*batch))
-    batch[0] = NestedTensor.from_tensor_list(batch[0])
+    batch[0] = nested_tensor_from_tensor_list(batch[0])
     return tuple(batch)
 
 
+def _max_by_axis(the_list):
+    # type: (List[List[int]]) -> List[int]
+    maxes = the_list[0]
+    for sublist in the_list[1:]:
+        for index, item in enumerate(sublist):
+            maxes[index] = max(maxes[index], item)
+    return maxes
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+    # TODO make this more general
+    if tensor_list[0].ndim == 3:
+        # TODO make it support different-sized images
+        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+        batch_shape = [len(tensor_list)] + max_size
+        b, c, h, w = batch_shape
+        dtype = tensor_list[0].dtype
+        device = tensor_list[0].device
+        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+        for img, pad_img, m in zip(tensor_list, tensor, mask):
+            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+            m[: img.shape[1], :img.shape[2]] = False
+    else:
+        raise ValueError('not supported')
+    return NestedTensor(tensor, mask)
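For intuition, a quick usage sketch of the new free function: it pads a list of differently sized images to the per-axis maximum and records where padding was added in a boolean mask.

import torch

imgs = [torch.rand(3, 200, 200), torch.rand(3, 200, 250)]
nt = nested_tensor_from_tensor_list(imgs)
print(nt.tensors.shape)  # torch.Size([2, 3, 200, 250]): batch padded to the max size
print(nt.mask.shape)     # torch.Size([2, 200, 250]): True exactly where padding sits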


 class NestedTensor(object):
-    def __init__(self, tensors, mask):
+    def __init__(self, tensors, mask: Optional[Tensor]):
         self.tensors = tensors
         self.mask = mask
 
-    def to(self, *args, **kwargs):
-        cast_tensor = self.tensors.to(*args, **kwargs)
-        cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None
-        return type(self)(cast_tensor, cast_mask)
+    def to(self, device):
+        # type: (Device) -> NestedTensor # noqa
+        cast_tensor = self.tensors.to(device)
+        mask = self.mask
+        if mask is not None:
+            assert mask is not None
+            cast_mask = mask.to(device)
+        else:
+            cast_mask = None
+        return NestedTensor(cast_tensor, cast_mask)
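The narrowed signature is presumably because TorchScript does not compile *args/**kwargs; the Device type comment stands in for torch.device. The same pattern in isolation (an illustration with made-up names, not the repo's API):

from typing import Optional

import torch


class Pair(object):
    def __init__(self, tensors: torch.Tensor, mask: Optional[torch.Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device: torch.device) -> 'Pair':
        # one explicit argument instead of *args/**kwargs, which the
        # TorchScript compiler cannot handle
        mask = self.mask
        cast_mask = mask.to(device) if mask is not None else None
        return Pair(self.tensors.to(device), cast_mask)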

     def decompose(self):
         return self.tensors, self.mask

-    @classmethod
Contributor Author:
classmethod doesn't seem to be supported, so I had to move this out as a separate function (with a few extra changes to make it compatible with torchscript).
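Schematically, the workaround looks like this (a simplified illustration with made-up names): the classmethod factory becomes a module-level function returning an instance, which the compiler can follow.

from typing import List

import torch
from torch import Tensor


class Stack(object):
    def __init__(self, data: Tensor):
        self.data = data


def stack_from_list(xs: List[Tensor]) -> Stack:
    # free-function factory instead of Stack.from_list(cls, ...): @classmethod
    # was not supported by the TorchScript compiler at the time
    return Stack(torch.stack(xs))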

-    def from_tensor_list(cls, tensor_list):
-        # TODO make this more general
-        if tensor_list[0].ndim == 3:
-            # TODO make it support different-sized images
-            max_size = tuple(max(s) for s in zip(*[img.shape for img in tensor_list]))
-            # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
-            batch_shape = (len(tensor_list),) + max_size
-            b, c, h, w = batch_shape
-            dtype = tensor_list[0].dtype
-            device = tensor_list[0].device
-            tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
-            mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
-            for img, pad_img, m in zip(tensor_list, tensor, mask):
-                pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
-                m[: img.shape[1], :img.shape[2]] = False
-        else:
-            raise ValueError('not supported')
-        return cls(tensor, mask)

     def __repr__(self):
-        return repr(self.tensors)
+        return str(self.tensors)


def setup_for_distributed(is_master):