Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Exporting Mask Rcnn to ONNX #1461

Merged
merged 6 commits into from
Oct 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredictor

from collections import OrderedDict

Expand Down Expand Up @@ -259,7 +260,7 @@ def forward(self_module, features):
model = RoiHeadsModule(images)
model.eval()
model(features)
self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
self.run_model(model, [(features,), (test_features,)])

def get_image_from_url(self, url):
import requests
Expand Down Expand Up @@ -294,6 +295,45 @@ def test_faster_rcnn(self):
model(images)
self.run_model(model, [(images,), (test_images,)])

# Verify that paste_mask_in_image beahves the same in tracing.
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
# (since jit_trace witll call _onnx_paste_masks_in_image).
def test_paste_mask_in_image(self):
masks = torch.rand(10, 1, 26, 26)
boxes = torch.rand(10, 4)
boxes[:, 2:] += torch.rand(10, 2)
boxes *= 50
o_im_s = (100, 100)
from torchvision.models.detection.roi_heads import paste_masks_in_image
out = paste_masks_in_image(masks, boxes, o_im_s)
jit_trace = torch.jit.trace(paste_masks_in_image,
(masks, boxes,
[torch.tensor(o_im_s[0]),
torch.tensor(o_im_s[1])]))
out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])

assert torch.all(out.eq(out_trace))

masks2 = torch.rand(20, 1, 26, 26)
boxes2 = torch.rand(20, 4)
boxes2[:, 2:] += torch.rand(20, 2)
boxes2 *= 100
o_im_s2 = (200, 200)
from torchvision.models.detection.roi_heads import paste_masks_in_image
out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])])

assert torch.all(out2.eq(out_trace2))

@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
def test_mask_rcnn(self):
images, test_images = self.get_test_images()

model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,)])

lara-hdr marked this conversation as resolved.
Show resolved Hide resolved

if __name__ == '__main__':
unittest.main()
91 changes: 88 additions & 3 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torchvision

import torch.nn.functional as F
from torch import nn
Expand Down Expand Up @@ -73,7 +74,11 @@ def maskrcnn_inference(x, labels):
index = torch.arange(num_masks, device=labels.device)
mask_prob = mask_prob[index, labels][:, None]

mask_prob = mask_prob.split(boxes_per_image, dim=0)
if len(boxes_per_image) == 1:
# TODO : remove when dynamic split supported in ONNX
mask_prob = (mask_prob,)
else:
mask_prob = mask_prob.split(boxes_per_image, dim=0)

return mask_prob

Expand Down Expand Up @@ -250,10 +255,29 @@ def keypointrcnn_inference(x, boxes):
return kp_probs, kp_scores


def _onnx_expand_boxes(boxes, scale):
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
lara-hdr marked this conversation as resolved.
Show resolved Hide resolved
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5

w_half = w_half.to(dtype=torch.float32) * scale
h_half = h_half.to(dtype=torch.float32) * scale

boxes_exp0 = x_c - w_half
boxes_exp1 = y_c - h_half
boxes_exp2 = x_c + w_half
boxes_exp3 = y_c + h_half
boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
return boxes_exp


# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
if torchvision._is_tracing():
return _onnx_expand_boxes(boxes, scale)
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
Expand All @@ -272,7 +296,10 @@ def expand_boxes(boxes, scale):

def expand_masks(mask, padding):
M = mask.shape[-1]
scale = float(M + 2 * padding) / M
if torchvision._is_tracing():
scale = (M + 2 * padding).to(torch.float32) / M.to(torch.float32)
else:
scale = float(M + 2 * padding) / M
padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
return padded_mask, scale

Expand Down Expand Up @@ -303,11 +330,69 @@ def paste_mask_in_image(mask, box, im_h, im_w):
return im_mask


def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
lara-hdr marked this conversation as resolved.
Show resolved Hide resolved
one = torch.ones(1, dtype=torch.int64)
zero = torch.zeros(1, dtype=torch.int64)

w = (box[2] - box[0] + one)
h = (box[3] - box[1] + one)
w = torch.max(torch.cat((w, one)))
h = torch.max(torch.cat((h, one)))

# Set shape to [batchxCxHxW]
mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

# Resize mask
mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
mask = mask[0][0]

x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

unpaded_im_mask = mask[(y_0 - box[1]):(y_1 - box[1]),
(x_0 - box[0]):(x_1 - box[0])]

# TODO : replace below with a dynamic padding when support is added in ONNX

# pad y
zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
concat_0 = torch.cat((zeros_y0,
unpaded_im_mask.to(dtype=torch.float32),
zeros_y1), 0)[0:im_h, :]
# pad x
zeros_x0 = torch.zeros(concat_0.size(0), x_0)
zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
im_mask = torch.cat((zeros_x0,
concat_0,
zeros_x1), 1)[:, :im_w]
return im_mask


@torch.jit.script
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
res_append = torch.zeros(0, im_h, im_w)
for i in range(masks.size(0)):
mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
mask_res = mask_res.unsqueeze(0)
res_append = torch.cat((res_append, mask_res))
lara-hdr marked this conversation as resolved.
Show resolved Hide resolved
return res_append


def paste_masks_in_image(masks, boxes, img_shape, padding=1):
lara-hdr marked this conversation as resolved.
Show resolved Hide resolved
masks, scale = expand_masks(masks, padding=padding)
boxes = expand_boxes(boxes, scale).to(dtype=torch.int64).tolist()
boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
# im_h, im_w = img_shape.tolist()
im_h, im_w = img_shape

if torchvision._is_tracing():
return _onnx_paste_masks_in_image_loop(masks, boxes,
torch.scalar_tensor(im_h, dtype=torch.int64),
torch.scalar_tensor(im_w, dtype=torch.int64))[:, None]

boxes = boxes.tolist()
res = [
paste_mask_in_image(m[0], b, im_h, im_w)
for m, b in zip(masks, boxes)
Expand Down