Skip to content

Commit

Permalink
make paste_mask scriptable
Browse files Browse the repository at this point in the history
Reviewed By: theschnitz

Differential Revision: D24822645

fbshipit-source-id: 2c999d0d819cd9007b126cd88019f48b63376525
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Nov 11, 2020
1 parent b74fb4e commit a1acc7f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 12 deletions.
26 changes: 18 additions & 8 deletions detectron2/layers/mask_ops.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Tuple
import torch
from PIL import Image
from torch.nn import functional as F

from detectron2.structures import Boxes

__all__ = ["paste_masks_in_image"]


Expand All @@ -13,7 +16,7 @@
GPU_MEM_LIMIT = 1024 ** 3 # 1 GB memory limit


def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
"""
Args:
masks: N, 1, H, W
Expand All @@ -33,7 +36,8 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
# Compared to pasting them one by one,
# this has more operations but is faster on COCO-scale dataset.
device = masks.device
if skip_empty:

if skip_empty and not torch.jit.is_scripting():
x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
dtype=torch.int32
)
Expand All @@ -56,17 +60,20 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)

if not masks.dtype.is_floating_point:
masks = masks.float()
if not torch.jit.is_scripting():
if not masks.dtype.is_floating_point:
masks = masks.float()
img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

if skip_empty:
if skip_empty and not torch.jit.is_scripting():
return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
else:
return img_masks[:, 0], ()


def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5):
def paste_masks_in_image(
masks: torch.Tensor, boxes: Boxes, image_shape: Tuple[int, int], threshold: float = 0.5
):
"""
Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
The location, height, and width for pasting each mask is determined by their
Expand Down Expand Up @@ -106,7 +113,7 @@ def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5):

# The actual implementation split the input into chunks,
# and paste them chunk by chunk.
if device.type == "cpu":
if device.type == "cpu" or torch.jit.is_scripting():
# CPU is most efficient when they are pasted one by one with skip_empty=True
# so that it performs minimal number of operations.
num_chunks = N
Expand All @@ -133,7 +140,10 @@ def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5):
# for visualization and debugging
masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

img_masks[(inds,) + spatial_inds] = masks_chunk
if torch.jit.is_scripting(): # Scripting does not use the optimized codepath
img_masks[inds] = masks_chunk
else:
img_masks[(inds,) + spatial_inds] = masks_chunk
return img_masks


Expand Down
2 changes: 1 addition & 1 deletion detectron2/layers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def forward(self, x):

ConvTranspose2d = torch.nn.ConvTranspose2d
BatchNorm2d = torch.nn.BatchNorm2d
interpolate = torch.nn.functional.interpolate
interpolate = F.interpolate


if TORCH_VERSION > (1, 5):
Expand Down
7 changes: 4 additions & 3 deletions detectron2/structures/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import numpy as np
from typing import Any, List, Tuple, Union
import torch

from detectron2.layers import interpolate
from torch.nn import functional as F


class Keypoints:
Expand Down Expand Up @@ -181,7 +180,9 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso

for i in range(num_rois):
outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
roi_map = interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False).squeeze(
roi_map = F.interpolate(
maps[[i]], size=outsize, mode="bicubic", align_corners=False
).squeeze(
0
) # #keypoints x H x W

Expand Down
12 changes: 12 additions & 0 deletions tests/layers/test_mask_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from detectron2.structures import BitMasks, Boxes, BoxMode, PolygonMasks
from detectron2.structures.masks import polygons_to_bitmask
from detectron2.utils.file_io import PathManager
from detectron2.utils.testing import random_boxes


def iou_between_full_image_bit_masks(a, b):
Expand Down Expand Up @@ -152,6 +153,17 @@ def test_polygon_area(self):
target = d ** 2 / 2
self.assertEqual(area, target)

def test_paste_mask_scriptable(self):
scripted_f = torch.jit.script(paste_masks_in_image)
N = 10
masks = torch.rand(N, 28, 28)
boxes = Boxes(random_boxes(N, 100))
image_shape = (150, 150)

out = paste_masks_in_image(masks, boxes, image_shape)
scripted_out = scripted_f(masks, boxes, image_shape)
self.assertTrue(torch.equal(out, scripted_out))


def benchmark_paste():
S = 800
Expand Down

0 comments on commit a1acc7f

Please sign in to comment.