diff --git a/detectron2/modeling/meta_arch/retinanet.py b/detectron2/modeling/meta_arch/retinanet.py
index 0bdc2ee009..103f1ab950 100644
--- a/detectron2/modeling/meta_arch/retinanet.py
+++ b/detectron2/modeling/meta_arch/retinanet.py
@@ -2,15 +2,15 @@
 import logging
 import math
 import numpy as np
-from typing import List
+from typing import Dict, List, Tuple
 import torch
 from fvcore.nn import giou_loss, sigmoid_focal_loss_jit, smooth_l1_loss
-from torch import nn
+from torch import Tensor, nn
 from torch.nn import functional as F
 
 from detectron2.config import configurable
 from detectron2.data.detection_utils import convert_image_to_rgb
-from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm
+from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple
 from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
 from detectron2.utils.events import get_event_storage
 
@@ -24,7 +24,7 @@
 __all__ = ["RetinaNet"]
 
 
-def permute_to_N_HWA_K(tensor, K):
+def permute_to_N_HWA_K(tensor, K: int):
     """
     Transpose/reshape a tensor from (N, (Ai x K), H, W) to (N, (HxWxAi), K)
     """
@@ -224,7 +224,7 @@ def visualize_training(self, batched_inputs, results):
         vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results"
         storage.put_image(vis_name, vis_img)
 
-    def forward(self, batched_inputs):
+    def forward(self, batched_inputs: Tuple[Dict[str, Tensor]]):
         """
         Args:
             batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
@@ -253,6 +253,7 @@ def forward(self, batched_inputs):
         pred_anchor_deltas = [permute_to_N_HWA_K(x, 4) for x in pred_anchor_deltas]
 
         if self.training:
+            assert not torch.jit.is_scripting(), "Not supported"
             assert "instances" in batched_inputs[0], "Instance annotations are missing in training!"
             gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
 
@@ -270,6 +271,8 @@ def forward(self, batched_inputs):
             return losses
         else:
             results = self.inference(anchors, pred_logits, pred_anchor_deltas, images.image_sizes)
+            if torch.jit.is_scripting():
+                return results
             processed_results = []
             for results_per_image, input_per_image, image_size in zip(
                 results, batched_inputs, images.image_sizes
@@ -392,19 +395,25 @@ def label_anchors(self, anchors, gt_instances):
         return gt_labels, matched_gt_boxes
 
-    def inference(self, anchors, pred_logits, pred_anchor_deltas, image_sizes):
+    def inference(
+        self,
+        anchors: List[Boxes],
+        pred_logits: List[Tensor],
+        pred_anchor_deltas: List[Tensor],
+        image_sizes: List[Tuple[int, int]],
+    ):
         """
         Arguments:
             anchors (list[Boxes]): A list of #feature level Boxes.
                 The Boxes contain anchors of this image on the specific feature level.
             pred_logits, pred_anchor_deltas: list[Tensor], one per level. Each
                 has shape (N, Hi * Wi * Ai, K or 4)
-            image_sizes (List[torch.Size]): the input image sizes
+            image_sizes (List[(h, w)]): the input image sizes
 
         Returns:
             results (List[Instances]): a list of #images elements.
""" - results = [] + results: List[Instances] = [] for img_idx, image_size in enumerate(image_sizes): pred_logits_per_image = [x[img_idx] for x in pred_logits] deltas_per_image = [x[img_idx] for x in pred_anchor_deltas] @@ -414,7 +423,13 @@ def inference(self, anchors, pred_logits, pred_anchor_deltas, image_sizes): results.append(results_per_image) return results - def inference_single_image(self, anchors, box_cls, box_delta, image_size): + def inference_single_image( + self, + anchors: List[Boxes], + box_cls: List[Tensor], + box_delta: List[Tensor], + image_size: Tuple[int, int], + ): """ Single-image inference. Return bounding-box detection results by thresholding on scores and applying non-maximum suppression (NMS). @@ -443,7 +458,7 @@ def inference_single_image(self, anchors, box_cls, box_delta, image_size): # 1. Keep boxes with confidence score higher than threshold keep_idxs = predicted_prob > self.test_score_thresh predicted_prob = predicted_prob[keep_idxs] - topk_idxs = torch.nonzero(keep_idxs, as_tuple=True)[0] + topk_idxs = nonzero_tuple(keep_idxs)[0] # 2. Keep top k top scoring boxes only num_topk = min(self.test_topk_candidates, topk_idxs.size(0)) @@ -476,7 +491,7 @@ def inference_single_image(self, anchors, box_cls, box_delta, image_size): result.pred_classes = class_idxs_all[keep] return result - def preprocess_image(self, batched_inputs): + def preprocess_image(self, batched_inputs: Tuple[Dict[str, Tensor]]): """ Normalize, pad and batch the input images. """ @@ -575,7 +590,7 @@ def from_config(cls, cfg, input_shape: List[ShapeSpec]): "num_anchors": num_anchors, } - def forward(self, features): + def forward(self, features: List[Tensor]): """ Arguments: features (list[Tensor]): FPN feature map tensors in high to low resolution. diff --git a/detectron2/modeling/postprocessing.py b/detectron2/modeling/postprocessing.py index 9088e9ff31..f42e77c52f 100644 --- a/detectron2/modeling/postprocessing.py +++ b/detectron2/modeling/postprocessing.py @@ -8,7 +8,9 @@ # perhaps should rename to "resize_instance" -def detector_postprocess(results, output_height, output_width, mask_threshold=0.5): +def detector_postprocess( + results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 +): """ Resize the output instances. The input images are often resized when entering an object detector. @@ -49,6 +51,9 @@ def detector_postprocess(results, output_height, output_width, mask_threshold=0. output_boxes = results.pred_boxes elif results.has("proposal_boxes"): output_boxes = results.proposal_boxes + else: + output_boxes = None + assert output_boxes is not None, "Predictions must contain boxes!" 
 
     output_boxes.scale(scale_x, scale_y)
     output_boxes.clip(results.image_size)
diff --git a/detectron2/structures/boxes.py b/detectron2/structures/boxes.py
index ed47eb0161..7cbe1be9bf 100644
--- a/detectron2/structures/boxes.py
+++ b/detectron2/structures/boxes.py
@@ -213,7 +213,7 @@ def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
         keep = (widths > threshold) & (heights > threshold)
         return keep
 
-    def __getitem__(self, item):
+    def __getitem__(self, item) -> "Boxes":
         """
         Args:
             item: int, slice, or a BoolTensor
diff --git a/tests/test_export_torchscript.py b/tests/test_export_torchscript.py
index 4fca17ffca..f65da2351d 100644
--- a/tests/test_export_torchscript.py
+++ b/tests/test_export_torchscript.py
@@ -11,6 +11,7 @@
 from detectron2.export.torchscript import dump_torchscript_IR, export_torchscript_with_instances
 from detectron2.export.torchscript_patch import patch_builtin_len
 from detectron2.modeling import build_backbone
+from detectron2.modeling.postprocessing import detector_postprocess
 from detectron2.structures import Boxes, Instances
 from detectron2.utils.env import TORCH_VERSION
 from detectron2.utils.testing import assert_instances_allclose, get_sample_coco_image
@@ -22,9 +23,13 @@ class TestScripting(unittest.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def testMaskRCNN(self):
-        self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
+        self._test_rcnn_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
 
-    def _test_model(self, config_path):
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def testRetinaNet(self):
+        self._test_retinanet_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml")
+
+    def _test_rcnn_model(self, config_path):
         model = model_zoo.get(config_path, trained=True)
         model.eval()
@@ -46,6 +51,25 @@ def _test_model(self, config_path):
             ].to_instances()
             assert_instances_allclose(instance, scripted_instance)
 
+    def _test_retinanet_model(self, config_path):
+        model = model_zoo.get(config_path, trained=True)
+        model.eval()
+
+        fields = {
+            "pred_boxes": Boxes,
+            "scores": Tensor,
+            "pred_classes": Tensor,
+        }
+        script_model = export_torchscript_with_instances(model, fields)
+
+        img = get_sample_coco_image()
+        inputs = [{"image": img}]
+        with torch.no_grad():
+            instance = model(inputs)[0]["instances"]
+            scripted_instance = script_model(inputs)[0].to_instances()
+            scripted_instance = detector_postprocess(scripted_instance, img.shape[1], img.shape[2])
+            assert_instances_allclose(instance, scripted_instance)
+
     @unittest.skipIf(
         os.environ.get("CIRCLECI") or TORCH_VERSION < (1, 8), "Insufficient Pytorch version"
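
Usage note (not part of the patch): with this change, a scripted RetinaNet returns raw Instances in network-input coordinates, since forward() returns before post-processing when torch.jit.is_scripting() is true, so the caller runs detector_postprocess itself, exactly as the new _test_retinanet_model does. Below is a minimal sketch along those lines; it assumes a detectron2 build that already contains this patch and reuses the test helpers (model_zoo, get_sample_coco_image) for brevity.

    import torch
    from torch import Tensor

    from detectron2 import model_zoo
    from detectron2.export.torchscript import export_torchscript_with_instances
    from detectron2.modeling.postprocessing import detector_postprocess
    from detectron2.structures import Boxes
    from detectron2.utils.testing import get_sample_coco_image

    model = model_zoo.get("COCO-Detection/retinanet_R_50_FPN_3x.yaml", trained=True)
    model.eval()

    # Fields the scripted Instances must carry for RetinaNet outputs.
    fields = {"pred_boxes": Boxes, "scores": Tensor, "pred_classes": Tensor}
    script_model = export_torchscript_with_instances(model, fields)

    img = get_sample_coco_image()  # CHW image tensor, as in the new test
    with torch.no_grad():
        # The scripted model skips detector_postprocess internally, so rescale
        # the predicted boxes to the input image resolution here.
        raw = script_model([{"image": img}])[0].to_instances()
        results = detector_postprocess(raw, img.shape[1], img.shape[2])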