From a1acc7fe9ac713503da8bf2f4380a4ef55b72c91 Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Tue, 10 Nov 2020 23:43:19 -0800 Subject: [PATCH] make paste_mask scriptable Reviewed By: theschnitz Differential Revision: D24822645 fbshipit-source-id: 2c999d0d819cd9007b126cd88019f48b63376525 --- detectron2/layers/mask_ops.py | 26 ++++++++++++++++++-------- detectron2/layers/wrappers.py | 2 +- detectron2/structures/keypoints.py | 7 ++++--- tests/layers/test_mask_ops.py | 12 ++++++++++++ 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/detectron2/layers/mask_ops.py b/detectron2/layers/mask_ops.py index cfdc05ceb9..c698a03c4d 100644 --- a/detectron2/layers/mask_ops.py +++ b/detectron2/layers/mask_ops.py @@ -1,9 +1,12 @@ # Copyright (c) Facebook, Inc. and its affiliates. import numpy as np +from typing import Tuple import torch from PIL import Image from torch.nn import functional as F +from detectron2.structures import Boxes + __all__ = ["paste_masks_in_image"] @@ -13,7 +16,7 @@ GPU_MEM_LIMIT = 1024 ** 3 # 1 GB memory limit -def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): +def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True): """ Args: masks: N, 1, H, W @@ -33,7 +36,8 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): # Compared to pasting them one by one, # this has more operations but is faster on COCO-scale dataset. device = masks.device - if skip_empty: + + if skip_empty and not torch.jit.is_scripting(): x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to( dtype=torch.int32 ) @@ -56,17 +60,20 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) grid = torch.stack([gx, gy], dim=3) - if not masks.dtype.is_floating_point: - masks = masks.float() + if not torch.jit.is_scripting(): + if not masks.dtype.is_floating_point: + masks = masks.float() img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False) - if skip_empty: + if skip_empty and not torch.jit.is_scripting(): return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) else: return img_masks[:, 0], () -def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5): +def paste_masks_in_image( + masks: torch.Tensor, boxes: Boxes, image_shape: Tuple[int, int], threshold: float = 0.5 +): """ Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image. The location, height, and width for pasting each mask is determined by their @@ -106,7 +113,7 @@ def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5): # The actual implementation split the input into chunks, # and paste them chunk by chunk. - if device.type == "cpu": + if device.type == "cpu" or torch.jit.is_scripting(): # CPU is most efficient when they are pasted one by one with skip_empty=True # so that it performs minimal number of operations. num_chunks = N @@ -133,7 +140,10 @@ def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5): # for visualization and debugging masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8) - img_masks[(inds,) + spatial_inds] = masks_chunk + if torch.jit.is_scripting(): # Scripting does not use the optimized codepath + img_masks[inds] = masks_chunk + else: + img_masks[(inds,) + spatial_inds] = masks_chunk return img_masks diff --git a/detectron2/layers/wrappers.py b/detectron2/layers/wrappers.py index b65c0ae993..ae162fd6fc 100644 --- a/detectron2/layers/wrappers.py +++ b/detectron2/layers/wrappers.py @@ -85,7 +85,7 @@ def forward(self, x): ConvTranspose2d = torch.nn.ConvTranspose2d BatchNorm2d = torch.nn.BatchNorm2d -interpolate = torch.nn.functional.interpolate +interpolate = F.interpolate if TORCH_VERSION > (1, 5): diff --git a/detectron2/structures/keypoints.py b/detectron2/structures/keypoints.py index 8dcfe452d7..c07c24faf8 100644 --- a/detectron2/structures/keypoints.py +++ b/detectron2/structures/keypoints.py @@ -2,8 +2,7 @@ import numpy as np from typing import Any, List, Tuple, Union import torch - -from detectron2.layers import interpolate +from torch.nn import functional as F class Keypoints: @@ -181,7 +180,9 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso for i in range(num_rois): outsize = (int(heights_ceil[i]), int(widths_ceil[i])) - roi_map = interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False).squeeze( + roi_map = F.interpolate( + maps[[i]], size=outsize, mode="bicubic", align_corners=False + ).squeeze( 0 ) # #keypoints x H x W diff --git a/tests/layers/test_mask_ops.py b/tests/layers/test_mask_ops.py index 28cc916526..9236ce83e5 100644 --- a/tests/layers/test_mask_ops.py +++ b/tests/layers/test_mask_ops.py @@ -23,6 +23,7 @@ from detectron2.structures import BitMasks, Boxes, BoxMode, PolygonMasks from detectron2.structures.masks import polygons_to_bitmask from detectron2.utils.file_io import PathManager +from detectron2.utils.testing import random_boxes def iou_between_full_image_bit_masks(a, b): @@ -152,6 +153,17 @@ def test_polygon_area(self): target = d ** 2 / 2 self.assertEqual(area, target) + def test_paste_mask_scriptable(self): + scripted_f = torch.jit.script(paste_masks_in_image) + N = 10 + masks = torch.rand(N, 28, 28) + boxes = Boxes(random_boxes(N, 100)) + image_shape = (150, 150) + + out = paste_masks_in_image(masks, boxes, image_shape) + scripted_out = scripted_f(masks, boxes, image_shape) + self.assertTrue(torch.equal(out, scripted_out)) + def benchmark_paste(): S = 800