Commit f1d0c05
make RetinaNet scriptable
Summary: mainly just type annotations

Reviewed By: theschnitz

Differential Revision: D24822684

fbshipit-source-id: 3f6f471feba456aa2499aad9b92d73673270c426
ppwwyyxx authored and facebook-github-bot committed Nov 11, 2020
1 parent a1acc7f commit f1d0c05
Showing 4 changed files with 60 additions and 16 deletions.
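
Background on the change: torch.jit.script treats any unannotated argument as a Tensor, so arguments that are really lists, tuples, or dicts must be annotated before the model can be scripted. A minimal sketch of the rule (illustrative only; Wrapper is not part of this commit):

import torch
from typing import Dict, List
from torch import Tensor, nn


class Wrapper(nn.Module):
    # torch.jit.script assumes unannotated arguments are Tensor; indexing
    # a Tensor with the string "image" would not compile, so the container
    # type must be spelled out explicitly.
    def forward(self, batched_inputs: List[Dict[str, Tensor]]) -> Tensor:
        images = [x["image"] for x in batched_inputs]
        return torch.stack(images).float().mean()


scripted = torch.jit.script(Wrapper())  # compiles because of the annotation
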
39 changes: 27 additions & 12 deletions detectron2/modeling/meta_arch/retinanet.py
@@ -2,15 +2,15 @@
 import logging
 import math
 import numpy as np
-from typing import List
+from typing import Dict, List, Tuple
 import torch
 from fvcore.nn import giou_loss, sigmoid_focal_loss_jit, smooth_l1_loss
-from torch import nn
+from torch import Tensor, nn
 from torch.nn import functional as F
 
 from detectron2.config import configurable
 from detectron2.data.detection_utils import convert_image_to_rgb
-from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm
+from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple
 from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
 from detectron2.utils.events import get_event_storage
@@ -24,7 +24,7 @@
 __all__ = ["RetinaNet"]
 
 
-def permute_to_N_HWA_K(tensor, K):
+def permute_to_N_HWA_K(tensor, K: int):
     """
     Transpose/reshape a tensor from (N, (Ai x K), H, W) to (N, (HxWxAi), K)
     """
@@ -224,7 +224,7 @@ def visualize_training(self, batched_inputs, results):
         vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results"
         storage.put_image(vis_name, vis_img)
 
-    def forward(self, batched_inputs):
+    def forward(self, batched_inputs: Tuple[Dict[str, Tensor]]):
         """
         Args:
             batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
@@ -253,6 +253,7 @@ def forward(self, batched_inputs):
         pred_anchor_deltas = [permute_to_N_HWA_K(x, 4) for x in pred_anchor_deltas]
 
         if self.training:
+            assert not torch.jit.is_scripting(), "Not supported"
             assert "instances" in batched_inputs[0], "Instance annotations are missing in training!"
             gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
 
@@ -270,6 +271,8 @@
             return losses
         else:
             results = self.inference(anchors, pred_logits, pred_anchor_deltas, images.image_sizes)
+            if torch.jit.is_scripting():
+                return results
             processed_results = []
             for results_per_image, input_per_image, image_size in zip(
                 results, batched_inputs, images.image_sizes
@@ -392,19 +395,25 @@ def label_anchors(self, anchors, gt_instances):
 
         return gt_labels, matched_gt_boxes
 
-    def inference(self, anchors, pred_logits, pred_anchor_deltas, image_sizes):
+    def inference(
+        self,
+        anchors: List[Boxes],
+        pred_logits: List[Tensor],
+        pred_anchor_deltas: List[Tensor],
+        image_sizes: List[Tuple[int, int]],
+    ):
         """
         Arguments:
             anchors (list[Boxes]): A list of #feature level Boxes.
                 The Boxes contain anchors of this image on the specific feature level.
             pred_logits, pred_anchor_deltas: list[Tensor], one per level. Each
                 has shape (N, Hi * Wi * Ai, K or 4)
-            image_sizes (List[torch.Size]): the input image sizes
+            image_sizes (List[(h, w)]): the input image sizes
         Returns:
             results (List[Instances]): a list of #images elements.
         """
-        results = []
+        results: List[Instances] = []
         for img_idx, image_size in enumerate(image_sizes):
             pred_logits_per_image = [x[img_idx] for x in pred_logits]
             deltas_per_image = [x[img_idx] for x in pred_anchor_deltas]
@@ -414,7 +423,13 @@ def inference(self, anchors, pred_logits, pred_anchor_deltas, image_sizes):
             results.append(results_per_image)
         return results
 
-    def inference_single_image(self, anchors, box_cls, box_delta, image_size):
+    def inference_single_image(
+        self,
+        anchors: List[Boxes],
+        box_cls: List[Tensor],
+        box_delta: List[Tensor],
+        image_size: Tuple[int, int],
+    ):
         """
         Single-image inference. Return bounding-box detection results by thresholding
         on scores and applying non-maximum suppression (NMS).
@@ -443,7 +458,7 @@ def inference_single_image(self, anchors, box_cls, box_delta, image_size):
             # 1. Keep boxes with confidence score higher than threshold
             keep_idxs = predicted_prob > self.test_score_thresh
             predicted_prob = predicted_prob[keep_idxs]
-            topk_idxs = torch.nonzero(keep_idxs, as_tuple=True)[0]
+            topk_idxs = nonzero_tuple(keep_idxs)[0]
 
             # 2. Keep top k top scoring boxes only
             num_topk = min(self.test_topk_candidates, topk_idxs.size(0))
@@ -476,7 +491,7 @@ def inference_single_image(self, anchors, box_cls, box_delta, image_size):
         result.pred_classes = class_idxs_all[keep]
         return result
 
-    def preprocess_image(self, batched_inputs):
+    def preprocess_image(self, batched_inputs: Tuple[Dict[str, Tensor]]):
         """
         Normalize, pad and batch the input images.
         """
@@ -575,7 +590,7 @@ def from_config(cls, cfg, input_shape: List[ShapeSpec]):
            "num_anchors": num_anchors,
        }
 
-    def forward(self, features):
+    def forward(self, features: List[Tensor]):
         """
         Arguments:
             features (list[Tensor]): FPN feature map tensors in high to low resolution.
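
The is_scripting() guards added to RetinaNet.forward let one method serve eager training, eager inference, and scripted inference: the scripted graph returns the raw inference results early, and eager-only code stays behind the guard. A minimal sketch of the pattern, with the eager-only path marked @torch.jit.unused in case it is not itself scriptable (GuardedModel is illustrative, not from this commit):

import torch
from torch import Tensor, nn


class GuardedModel(nn.Module):
    @torch.jit.unused
    def _eager_only_postprocess(self, results: Tensor) -> Tensor:
        # Code TorchScript cannot compile may live here; the decorator
        # replaces it with a stub that raises if a scripted caller reaches it.
        return results.cpu()

    def forward(self, x: Tensor) -> Tensor:
        results = x * 2
        if torch.jit.is_scripting():
            return results  # scripted callers get the raw results
        return self._eager_only_postprocess(results)


scripted = torch.jit.script(GuardedModel())
print(scripted(torch.ones(2)))  # tensor([2., 2.])
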
7 changes: 6 additions & 1 deletion detectron2/modeling/postprocessing.py
@@ -8,7 +8,9 @@
 
 
 # perhaps should rename to "resize_instance"
-def detector_postprocess(results, output_height, output_width, mask_threshold=0.5):
+def detector_postprocess(
+    results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
+):
     """
     Resize the output instances.
     The input images are often resized when entering an object detector.
@@ -49,6 +51,9 @@ def detector_postprocess(results, output_height, output_width, mask_threshold=0.
         output_boxes = results.pred_boxes
     elif results.has("proposal_boxes"):
         output_boxes = results.proposal_boxes
+    else:
+        output_boxes = None
+    assert output_boxes is not None, "Predictions must contain boxes!"
 
     output_boxes.scale(scale_x, scale_y)
     output_boxes.clip(results.image_size)
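
The new else branch and assert in detector_postprocess are primarily a scriptability fix: TorchScript requires a variable to be assigned on every control-flow path, and assert x is not None then refines Optional[T] back to T so later attribute access type-checks. A small standalone illustration of both rules (first_available is illustrative, not from this commit):

import torch
from typing import Optional
from torch import Tensor


@torch.jit.script
def first_available(a: Optional[Tensor], b: Optional[Tensor]) -> Tensor:
    out: Optional[Tensor] = None  # every control-flow path must assign `out`
    if a is not None:
        out = a
    elif b is not None:
        out = b
    assert out is not None, "need at least one input"  # refines Optional[Tensor] to Tensor
    return out + 1  # type-checks: `out` is a plain Tensor after the assert


print(first_available(torch.zeros(1), None))  # tensor([1.])
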
2 changes: 1 addition & 1 deletion detectron2/structures/boxes.py
@@ -213,7 +213,7 @@ def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
         keep = (widths > threshold) & (heights > threshold)
         return keep
 
-    def __getitem__(self, item):
+    def __getitem__(self, item) -> "Boxes":
         """
         Args:
             item: int, slice, or a BoolTensor
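
The quoted "Boxes" return type is a forward reference: the class name is not bound yet while its own body is being evaluated, so the annotation must be a string for annotation consumers (TorchScript included) to resolve later. The same idiom on a stripped-down class (illustrative only):

import torch
from torch import Tensor


class Boxes:
    def __init__(self, tensor: Tensor) -> None:
        self.tensor = tensor

    # "Boxes" must be a string here: the name is only bound after the
    # class body finishes executing.
    def __getitem__(self, item) -> "Boxes":
        return Boxes(self.tensor[item])


b = Boxes(torch.zeros(4, 4))
print(type(b[:2]).__name__)  # Boxes
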
28 changes: 26 additions & 2 deletions tests/test_export_torchscript.py
@@ -11,6 +11,7 @@
 from detectron2.export.torchscript import dump_torchscript_IR, export_torchscript_with_instances
 from detectron2.export.torchscript_patch import patch_builtin_len
 from detectron2.modeling import build_backbone
+from detectron2.modeling.postprocessing import detector_postprocess
 from detectron2.structures import Boxes, Instances
 from detectron2.utils.env import TORCH_VERSION
 from detectron2.utils.testing import assert_instances_allclose, get_sample_coco_image
@@ -22,9 +23,13 @@
 class TestScripting(unittest.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def testMaskRCNN(self):
-        self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
+        self._test_rcnn_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
 
-    def _test_model(self, config_path):
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def testRetinaNet(self):
+        self._test_retinanet_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml")
+
+    def _test_rcnn_model(self, config_path):
         model = model_zoo.get(config_path, trained=True)
         model.eval()
 
@@ -46,6 +51,25 @@ def _test_model(self, config_path):
             ].to_instances()
         assert_instances_allclose(instance, scripted_instance)
 
+    def _test_retinanet_model(self, config_path):
+        model = model_zoo.get(config_path, trained=True)
+        model.eval()
+
+        fields = {
+            "pred_boxes": Boxes,
+            "scores": Tensor,
+            "pred_classes": Tensor,
+        }
+        script_model = export_torchscript_with_instances(model, fields)
+
+        img = get_sample_coco_image()
+        inputs = [{"image": img}]
+        with torch.no_grad():
+            instance = model(inputs)[0]["instances"]
+            scripted_instance = script_model(inputs)[0].to_instances()
+            scripted_instance = detector_postprocess(scripted_instance, img.shape[1], img.shape[2])
+        assert_instances_allclose(instance, scripted_instance)
+
 
 @unittest.skipIf(
     os.environ.get("CIRCLECI") or TORCH_VERSION < (1, 8), "Insufficient Pytorch version"
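
Note that the scripted RetinaNet returns raw inference results (the is_scripting() early return above skips postprocessing), which is why the test applies detector_postprocess itself before comparing against the eager model. The scripted module should also support the usual TorchScript save/load round trip; a sketch under that assumption (the path is hypothetical):

# assuming `script_model` is the ScriptModule from the test above
script_model.save("/tmp/retinanet_ts.pt")  # hypothetical path
reloaded = torch.jit.load("/tmp/retinanet_ts.pt")
outputs = reloaded([{"image": img}])  # same raw outputs as script_model
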
