-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Add torchscript support for hub detection models #51
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,8 @@ | |
from torch import nn | ||
|
||
from util import box_ops | ||
from util.misc import (NestedTensor, accuracy, get_world_size, interpolate, | ||
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, | ||
accuracy, get_world_size, interpolate, | ||
is_dist_avail_and_initialized) | ||
|
||
from .backbone import build_backbone | ||
|
@@ -56,20 +57,28 @@ def forward(self, samples: NestedTensor): | |
dictionaries containing the two above keys for each decoder layer. | ||
""" | ||
if not isinstance(samples, NestedTensor): | ||
samples = NestedTensor.from_tensor_list(samples) | ||
samples = nested_tensor_from_tensor_list(samples) | ||
features, pos = self.backbone(samples) | ||
|
||
src, mask = features[-1].decompose() | ||
assert mask is not None | ||
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] | ||
|
||
outputs_class = self.class_embed(hs) | ||
outputs_coord = self.bbox_embed(hs).sigmoid() | ||
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} | ||
if self.aux_loss: | ||
out['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b} | ||
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] | ||
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) | ||
return out | ||
|
||
@torch.jit.unused | ||
def _set_aux_loss(self, outputs_class, outputs_coord): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. does There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This might vary depending on the task. But I guess we could leave some fields of the namedtuple as |
||
# this is a workaround to make torchscript happy, as torchscript | ||
# doesn't support dictionary with non-homogeneous values, such | ||
# as a dict having both a Tensor and a list. | ||
return [{'pred_logits': a, 'pred_boxes': b} | ||
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] | ||
|
||
|
||
class SetCriterion(nn.Module): | ||
""" This class computes the loss for DETR. | ||
|
@@ -164,7 +173,7 @@ def loss_masks(self, outputs, targets, indices, num_boxes): | |
src_masks = outputs["pred_masks"] | ||
|
||
# TODO use valid to mask invalid areas due to padding in loss | ||
target_masks, valid = NestedTensor.from_tensor_list([t["masks"] for t in targets]).decompose() | ||
target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose() | ||
target_masks = target_masks.to(src_masks) | ||
|
||
src_masks = src_masks[src_idx] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -267,45 +267,60 @@ def _run(command): | |
|
||
def collate_fn(batch): | ||
batch = list(zip(*batch)) | ||
batch[0] = NestedTensor.from_tensor_list(batch[0]) | ||
batch[0] = nested_tensor_from_tensor_list(batch[0]) | ||
return tuple(batch) | ||
|
||
|
||
def _max_by_axis(the_list): | ||
# type: (List[List[int]]) -> List[int] | ||
maxes = the_list[0] | ||
for sublist in the_list[1:]: | ||
for index, item in enumerate(sublist): | ||
maxes[index] = max(maxes[index], item) | ||
return maxes | ||
|
||
|
||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): | ||
# TODO make this more general | ||
if tensor_list[0].ndim == 3: | ||
# TODO make it support different-sized images | ||
max_size = _max_by_axis([list(img.shape) for img in tensor_list]) | ||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) | ||
batch_shape = [len(tensor_list)] + max_size | ||
b, c, h, w = batch_shape | ||
dtype = tensor_list[0].dtype | ||
device = tensor_list[0].device | ||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device) | ||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device) | ||
for img, pad_img, m in zip(tensor_list, tensor, mask): | ||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) | ||
m[: img.shape[1], :img.shape[2]] = False | ||
else: | ||
raise ValueError('not supported') | ||
return NestedTensor(tensor, mask) | ||
|
||
|
||
class NestedTensor(object): | ||
def __init__(self, tensors, mask): | ||
def __init__(self, tensors, mask: Optional[Tensor]): | ||
self.tensors = tensors | ||
self.mask = mask | ||
|
||
def to(self, *args, **kwargs): | ||
cast_tensor = self.tensors.to(*args, **kwargs) | ||
cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None | ||
return type(self)(cast_tensor, cast_mask) | ||
def to(self, device): | ||
# type: (Device) -> NestedTensor # noqa | ||
cast_tensor = self.tensors.to(device) | ||
mask = self.mask | ||
if mask is not None: | ||
assert mask is not None | ||
cast_mask = mask.to(device) | ||
else: | ||
cast_mask = None | ||
return NestedTensor(cast_tensor, cast_mask) | ||
|
||
def decompose(self): | ||
return self.tensors, self.mask | ||
|
||
@classmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
def from_tensor_list(cls, tensor_list): | ||
# TODO make this more general | ||
if tensor_list[0].ndim == 3: | ||
# TODO make it support different-sized images | ||
max_size = tuple(max(s) for s in zip(*[img.shape for img in tensor_list])) | ||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) | ||
batch_shape = (len(tensor_list),) + max_size | ||
b, c, h, w = batch_shape | ||
dtype = tensor_list[0].dtype | ||
device = tensor_list[0].device | ||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device) | ||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device) | ||
for img, pad_img, m in zip(tensor_list, tensor, mask): | ||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) | ||
m[: img.shape[1], :img.shape[2]] = False | ||
else: | ||
raise ValueError('not supported') | ||
return cls(tensor, mask) | ||
|
||
def __repr__(self): | ||
return repr(self.tensors) | ||
return str(self.tensors) | ||
|
||
|
||
def setup_for_distributed(is_master): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This oversight took me a while to figure out, as the error messages were pretty cryptic
Error message from torchscript
cc @eellison if this error message could be improved it would be great.