From 0515496b76db7e53cba771e6d867122c25dfb283 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 23 Oct 2023 16:35:56 +0300
Subject: [PATCH] Bugfix of model.export() to work correctly with bs>1 (#1551)

* Bugfixes: 1) Picking wrong indexes when exporting models with bs>1 in decoding modules (YoloNAS, YoloNAS-Pose, YoloX, PPYolo-E) 2) When using ONNX NonMaxSuppression, the max_predictions_per_image parameter was used incorrectly (Flat/Batch format predictions could contain more than the requested number of detections)

* Fixing NMS postprocessing to make it TRT compatible

* Update pose estimation support matrix for exported models

* Update compatibility matrix for TRT NMS

* Improve ConvertTRTFormatToFlatTensor implementation

* Improve ConvertTRTFormatToFlatTensor implementation

* Improve ConvertTRTFormatToFlatTensor implementation

* Drop use_boolean_gather

* Improve test

* Update compatibility matrix

* Improve test

* Improve test

* Update notebook export
---
 documentation/source/models_export.md | 35 +--
 src/super_gradients/conversion/onnx/nms.py | 207 ++++++++++++-----
 .../conversion/onnx/pose_nms.py | 159 ++++++++-----
 src/super_gradients/conversion/onnx/utils.py | 12 +-
 .../conversion/tensorrt/nms.py | 89 +++++---
 .../examples/model_export/models_export.ipynb | 2 +-
 .../module_interfaces/exportable_detector.py | 3 +-
 .../detection_models/pp_yolo_e/pp_yolo_e.py | 7 +-
 .../models/detection_models/yolo_base.py | 10 +-
 .../yolo_nas/yolo_nas_variants.py | 27 ++-
 .../yolo_nas_pose/yolo_nas_pose_variants.py | 22 +-
 .../unit_tests/export_detection_model_test.py | 208 ++++++++++--------
 12 files changed, 525 insertions(+), 256 deletions(-)

diff --git a/documentation/source/models_export.md b/documentation/source/models_export.md index 33de990a3b..719e1ffbdc 100644 --- a/documentation/source/models_export.md +++ b/documentation/source/models_export.md
@@ -28,6 +28,13 @@ A new export API is introduced in SG 3.2.0. It is aimed to simplify the export p
 - Customising NMS parameters and number of detections per image
 - Customising output format (flat or batched)
+
+```python
+!pip install super_gradients==3.3.1
+```
+
+
 ### Minimalistic export example
 Let's start with the simplest example of exporting a model to ONNX format.
@@ -203,9 +212,9 @@ pred_boxes, pred_boxes.shape
 [ 35.71795, 249.40926, 176.62216, 544.69794], [182.39618, 249.49301, 301.44122, 529.3324 ], ...,
- [ 0. , 0. , 0. , 0. ],
- [ 0. , 0. , 0. , 0. ],
- [ 0. , 0. , 0. , 0. ]]], dtype=float32),
+ [ -1. , -1. , -1. , -1. ],
+ [ -1. , -1. , -1. , -1. ],
+ [ -1. , -1. , -1. , -1. ]]], dtype=float32),
 (1, 1000, 4))
@@ -219,8 +228,8 @@ pred_scores, pred_scores.shape
- (array([[0.9694027, 0.9693378, 0.9665707, 0.9619047, 0.7538769, ...,
- 0. , 0. , 0. , 0. , 0. ]],
+ (array([[ 0.9694027, 0.9693378, 0.9665707, 0.9619047, 0.7538769, ...,
+ -1. , -1. , -1. , -1. , -1.
]], dtype=float32), (1, 1000)) @@ -235,8 +244,8 @@ pred_classes, pred_classes.shape - (array([[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, ..., 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], - dtype=int64), + (array([[ 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, ..., -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1]], dtype=int64), (1, 1000)) @@ -295,7 +304,7 @@ show_predictions_from_batch_format(image, result) -![png](models_export_files/models_export_18_0.png) +![png](models_export_files/models_export_19_0.png) @@ -411,7 +420,7 @@ show_predictions_from_flat_format(image, result) -![png](models_export_files/models_export_24_0.png) +![png](models_export_files/models_export_25_0.png) @@ -447,7 +456,7 @@ show_predictions_from_flat_format(image, result) -![png](models_export_files/models_export_26_0.png) +![png](models_export_files/models_export_27_0.png) @@ -481,7 +490,7 @@ show_predictions_from_flat_format(image, result) -![png](models_export_files/models_export_28_0.png) +![png](models_export_files/models_export_29_0.png) @@ -522,12 +531,12 @@ result = session.run(outputs, {inputs[0]: image_bchw}) show_predictions_from_flat_format(image, result) ``` - 25%|█████████████████████████████████████████████████ | 4/16 [00:11<00:34, 2.90s/it] + 25%|██████████████████████████████▊ | 4/16 [00:11<00:33, 2.79s/it] -![png](models_export_files/models_export_30_1.png) +![png](models_export_files/models_export_31_1.png) diff --git a/src/super_gradients/conversion/onnx/nms.py b/src/super_gradients/conversion/onnx/nms.py index 219d158afb..c3f120acdd 100644 --- a/src/super_gradients/conversion/onnx/nms.py +++ b/src/super_gradients/conversion/onnx/nms.py @@ -1,10 +1,11 @@ import os import tempfile -from typing import Tuple +from typing import Tuple, Optional, Mapping import numpy as np import onnx import onnx.shape_inference +import onnxsim import torch from onnx import TensorProto from torch import nn, Tensor @@ -13,7 +14,7 @@ from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode from super_gradients.conversion.conversion_utils import numpy_dtype_to_torch_dtype from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_fail_with_instructions -from super_gradients.conversion.onnx.utils import append_graphs +from super_gradients.conversion.onnx.utils import append_graphs, iteratively_infer_shapes logger = get_logger(__name__) @@ -24,6 +25,17 @@ class PickNMSPredictionsAndReturnAsBatchedResult(nn.Module): __constants__ = ("batch_size", "max_predictions_per_image") def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int): + """ + Select the predictions from ONNX NMS node and return them in batch format. + + :param batch_size: A fixed batch size for the model + :param num_pre_nms_predictions: The number of predictions before NMS step + :param max_predictions_per_image: Maximum number of predictions per image + """ + if max_predictions_per_image > num_pre_nms_predictions: + raise ValueError( + f"max_predictions_per_image ({max_predictions_per_image}) cannot be greater than num_pre_nms_predictions ({num_pre_nms_predictions})" + ) super().__init__() self.batch_size = batch_size self.num_pre_nms_predictions = num_pre_nms_predictions @@ -32,49 +44,78 @@ def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_prediction def forward(self, pred_boxes: Tensor, pred_scores: Tensor, selected_indexes: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ Select the predictions that are output by the NMS plugin. 
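+ As a toy illustration (made-up values, not part of the patch): a selected_indexes row [0, 1, 2] means image 0, class 1, box 2, so the kept box is pred_boxes[0, 2] and its score is pred_scores[0, 2, 1].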
- :param pred_boxes: [B, N, 4] tensor, float32
- :param pred_scores: [B, N, C] tensor, float32
+ :param pred_boxes: [B, N, 4] tensor, float32 in XYXY format
+ :param pred_scores: [B, N, C] tensor, float32
 :param selected_indexes: [num_selected_indices, 3], int64 - each row is [batch_indexes, label_indexes, boxes_indexes]
- :return: A tuple of 4 tensors (num_detections, detection_boxes, detection_scores, detection_classes) will be returned:
- - A tensor of [batch_size, 1] containing the image indices for each detection.
- - A tensor of [batch_size, max_output_boxes, 4] containing the bounding box coordinates for each detection in [x1, y1, x2, y2] format.
- - A tensor of [batch_size, max_output_boxes] containing the confidence scores for each detection.
- - A tensor of [batch_size, max_output_boxes] containing the class indices for each detection.
+ :return: A tuple of 4 tensors (num_detections, detection_boxes, detection_scores, detection_classes) will be returned:
+ - A tensor of [batch_size, 1] containing the number of detections for each image.
+ - A tensor of [batch_size, max_predictions_per_image, 4] containing the bounding box coordinates
+ for each detection in [x1, y1, x2, y2] format.
+ - A tensor of [batch_size, max_predictions_per_image] containing the confidence scores for each detection.
+ - A tensor of [batch_size, max_predictions_per_image] containing the class indices for each detection.
 """
- batch_indexes, label_indexes, boxes_indexes = selected_indexes[:, 0], selected_indexes[:, 1], selected_indexes[:, 2]
+ batch_indexes = selected_indexes[:, 0]
+ label_indexes = selected_indexes[:, 1]
+ boxes_indexes = selected_indexes[:, 2]
 selected_boxes = pred_boxes[batch_indexes, boxes_indexes]
 selected_scores = pred_scores[batch_indexes, boxes_indexes, label_indexes]
- predictions = torch.cat([batch_indexes.unsqueeze(1), selected_boxes, selected_scores.unsqueeze(1), label_indexes.unsqueeze(1)], dim=1)
+ if self.batch_size == 1:
+ pred_boxes = selected_boxes[: self.max_predictions_per_image]
+ pred_scores = selected_scores[: self.max_predictions_per_image]
+ pred_classes = label_indexes[: self.max_predictions_per_image].long()
+ num_predictions = pred_boxes.size(0).reshape(1, 1)
- predictions = torch.nn.functional.pad(
- predictions, (0, 0, 0, self.max_predictions_per_image * self.batch_size - predictions.size(0)), value=-1, mode="constant"
- )
+ pad_size = self.max_predictions_per_image - pred_boxes.size(0)
+ pred_boxes = torch.nn.functional.pad(pred_boxes, [0, 0, 0, pad_size], value=-1, mode="constant")
+ pred_scores = torch.nn.functional.pad(pred_scores, [0, pad_size], value=-1, mode="constant")
+ pred_classes = torch.nn.functional.pad(pred_classes, [0, pad_size], value=-1, mode="constant")
- batch_predictions = torch.zeros((self.batch_size, self.max_predictions_per_image, 6), dtype=predictions.dtype, device=predictions.device)
+ return num_predictions, pred_boxes.unsqueeze(0), pred_scores.unsqueeze(0), pred_classes.unsqueeze(0)
+ else:
+ predictions = torch.cat([selected_boxes, selected_scores.unsqueeze(1), label_indexes.unsqueeze(1)], dim=1)
- batch_indexes = torch.arange(start=0, end=self.batch_size, step=1, device=predictions.device).to(dtype=predictions.dtype)
- masks = batch_indexes.view(-1, 1).eq(predictions[:, 0].view(1, -1)) # [B, N]
+ batch_predictions = torch.zeros((self.batch_size, self.max_predictions_per_image, 6), dtype=predictions.dtype, device=predictions.device)
- num_predictions = torch.sum(masks, dim=1).long()
+ image_indexes = torch.arange(start=0,
end=self.batch_size, step=1, device=predictions.device)
+ masks = image_indexes.view(self.batch_size, 1) == batch_indexes.view(1, selected_indexes.size(0)) # [B, L]
- for i in range(self.batch_size):
- selected_predictions = predictions[masks[i]]
- selected_predictions = selected_predictions[:, 1:]
- batch_predictions[i] = torch.nn.functional.pad(
- selected_predictions, (0, 0, 0, self.max_predictions_per_image - selected_predictions.size(0)), value=0, mode="constant"
- )
+ # Add dummy row to mask and predictions to ensure that we always have at least one prediction per image
+ # ONNX/TRT deals poorly with tensors that have zero dims, and we need to ensure that we always have at least one prediction per image
+ masks = torch.cat([masks, torch.zeros((self.batch_size, 1), dtype=masks.dtype, device=predictions.device)], dim=1) # [B, L+1]
+ predictions = torch.cat([predictions, torch.zeros((1, 6), dtype=predictions.dtype, device=predictions.device)], dim=0) # [L+1, 6]
+
+ num_predictions = torch.sum(masks, dim=1, keepdim=True).long()
+ num_predictions_capped = torch.clamp_max(num_predictions, self.max_predictions_per_image)
- pred_boxes = batch_predictions[:, :, 0:4]
- pred_scores = batch_predictions[:, :, 4]
- pred_classes = batch_predictions[:, :, 5].long()
+ for i in range(self.batch_size):
+ selected_predictions = predictions[masks[i]]
+ pad_size = self.num_pre_nms_predictions - selected_predictions.size(0)
+ selected_predictions = torch.nn.functional.pad(selected_predictions, [0, 0, 0, pad_size], value=-1, mode="constant")
+ selected_predictions = selected_predictions[0 : self.max_predictions_per_image]
- return num_predictions.unsqueeze(1), pred_boxes, pred_scores, pred_classes
+ batch_predictions[i] = selected_predictions
+
+ pred_boxes = batch_predictions[:, :, 0:4]
+ pred_scores = batch_predictions[:, :, 4]
+ pred_classes = batch_predictions[:, :, 5].long()
+
+ return num_predictions_capped, pred_boxes, pred_scores, pred_classes
 @classmethod
- def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image, dtype: torch.dtype, device: torch.device) -> gs.Graph:
+ def as_graph(
+ cls,
+ batch_size: int,
+ num_pre_nms_predictions: int,
+ max_predictions_per_image: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ onnx_export_kwargs: Optional[Mapping] = None,
+ ) -> gs.Graph:
+ if onnx_export_kwargs is None:
+ onnx_export_kwargs = {}
 with tempfile.TemporaryDirectory() as tmpdirname:
 onnx_file = os.path.join(tmpdirname, "PickNMSPredictionsAndReturnAsBatchedResult.onnx")
 pred_boxes = torch.zeros((batch_size, num_pre_nms_predictions, 4), dtype=dtype, device=device)
@@ -101,13 +142,27 @@ def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions
 },
 "selected_indexes": {0: "num_predictions"},
 },
+ **onnx_export_kwargs,
 )
+ model_opt, check_ok = onnxsim.simplify(onnx_file)
+ if not check_ok:
+ raise RuntimeError(f"Failed to simplify ONNX model {onnx_file}")
+ onnx.save(model_opt, onnx_file)
+
 convert_format_graph = gs.import_onnx(onnx.load(onnx_file))
+ convert_format_graph = convert_format_graph.fold_constants().cleanup().toposort()
+ convert_format_graph = iteratively_infer_shapes(convert_format_graph)
 return convert_format_graph
 class PickNMSPredictionsAndReturnAsFlatResult(nn.Module):
+ """
+ Select the outputs of the ONNX NMS node and return them in flat format.
+
+ This module is NOT compatible with the TensorRT engine (tested on TensorRT 8.4.2, 8.5.3 and 8.6.1) when using batch size > 1.
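+ If flat output with batch size > 1 is required on TensorRT, a workaround consistent with the support matrices further below is to export through the TensorRT NMS path, or to fall back to the batched output format.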
+ """ + __constants__ = ("batch_size", "num_pre_nms_predictions", "max_predictions_per_image") def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int): @@ -116,33 +171,68 @@ def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_prediction self.num_pre_nms_predictions = num_pre_nms_predictions self.max_predictions_per_image = max_predictions_per_image - def forward(self, pred_boxes: Tensor, pred_scores: Tensor, selected_indexes: Tensor): + def forward(self, pred_boxes: Tensor, pred_scores: Tensor, selected_indexes: Tensor) -> Tensor: """ Select the predictions that are output by the NMS plugin. - :param pred_boxes: [B, N, 4] tensor - :param pred_scores: [B, N, C] tensor + :param pred_boxes: [B, N, 4] tensor + :param pred_scores: [B, N, C] tensor :param selected_indexes: [num_selected_indices, 3] - each row is [batch_indexes, label_indexes, boxes_indexes] - :return: A single tensor of [Nout, 7] shape, where Nout is the total number of detections across all images in the batch. - Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values. + Indexes of predictions from same image (same batch_index) corresponds to sorted predictions (Confident first). + :return: A single tensor of [Nout, 7] shape, where Nout is the total number of detections across all images in the batch. + Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values. + Each image will have at most max_predictions_per_image detections. """ - batch_indexes, label_indexes, boxes_indexes = selected_indexes[:, 0], selected_indexes[:, 1], selected_indexes[:, 2] + batch_indexes = selected_indexes[:, 0] + label_indexes = selected_indexes[:, 1] + boxes_indexes = selected_indexes[:, 2] selected_boxes = pred_boxes[batch_indexes, boxes_indexes] selected_scores = pred_scores[batch_indexes, boxes_indexes, label_indexes] + dtype = selected_scores.dtype - return torch.cat( - [ - batch_indexes.unsqueeze(1).to(selected_boxes.dtype), - selected_boxes, - selected_scores.unsqueeze(1), - label_indexes.unsqueeze(1).to(selected_boxes.dtype), - ], - dim=1, - ) + flat_results = torch.cat( + [batch_indexes.unsqueeze(-1).to(dtype), selected_boxes, selected_scores.unsqueeze(-1), label_indexes.unsqueeze(-1).to(dtype)], dim=1 + ) # [N, 7] + + if self.batch_size > 1: + # Compute a mask of shape [N,B] where each row contains True if the corresponding prediction belongs to the corresponding batch index + image_index = torch.arange(self.batch_size, dtype=batch_indexes.dtype, device=batch_indexes.device) + + detections_in_images_mask = image_index.view(1, self.batch_size) == batch_indexes.view(-1, 1) # [num_selected_indices, B] + + # Compute total number of detections per image + num_detections_per_image = torch.sum(detections_in_images_mask, dim=0, keepdim=True) # [1, B] + + # Cap the number of detections per image to max_predictions_per_image + num_detections_per_image = torch.clamp_max(num_detections_per_image, self.max_predictions_per_image) # [1, B] + + # Calculate cumulative count of selected predictions for each batch index + # This will give us a tensor of shape [num_selected_indices, B] where the value at each position + # represents the number of predictions for the corresponding batch index up to that position. 
+ cumulative_counts = detections_in_images_mask.float().cumsum(dim=0) # [num_selected_indices, B] + + # Create a mask to ensure we only keep max_predictions_per_image detections per image + mask = ((cumulative_counts <= num_detections_per_image) & detections_in_images_mask).any(dim=1, keepdim=False) # [N] + + final_results = flat_results[mask > 0] + else: + final_results = flat_results[: self.max_predictions_per_image] + + return final_results @classmethod - def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device) -> gs.Graph: + def as_graph( + cls, + batch_size: int, + num_pre_nms_predictions: int, + max_predictions_per_image: int, + dtype: torch.dtype, + device: torch.device, + onnx_export_kwargs: Optional[Mapping] = None, + ) -> gs.Graph: + if onnx_export_kwargs is None: + onnx_export_kwargs = {} with tempfile.TemporaryDirectory() as tmpdirname: onnx_file = os.path.join(tmpdirname, "PickNMSPredictionsAndReturnAsFlatTensor.onnx") pred_boxes = torch.zeros((batch_size, num_pre_nms_predictions, 4), dtype=dtype, device=device) @@ -161,11 +251,20 @@ def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions "pred_boxes": {}, "pred_scores": {2: "num_classes"}, "selected_indexes": {0: "num_predictions"}, - "flat_predictions": {0: "num_predictions"}, + "flat_predictions": {0: "num_output_predictions"}, }, + **onnx_export_kwargs, ) + model_opt, check_ok = onnxsim.simplify(onnx_file) + if not check_ok: + raise RuntimeError(f"Failed to simplify ONNX model {onnx_file}") + onnx.save(model_opt, onnx_file) + convert_format_graph = gs.import_onnx(onnx.load(onnx_file)) + convert_format_graph = convert_format_graph.fold_constants().cleanup().toposort() + convert_format_graph = iteratively_infer_shapes(convert_format_graph) + return convert_format_graph @@ -179,6 +278,7 @@ def attach_onnx_nms( batch_size: int, output_predictions_format: DetectionOutputFormatMode, device: torch.device, + onnx_export_kwargs: Optional[Mapping] = None, ): """ Attach ONNX NMS plugin to the detection model. 
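Before the next hunk, two notes. First, ONNX `NonMaxSuppression` treats `max_output_boxes_per_class` as a per-class budget, so the hunk below now passes `num_pre_nms_predictions` there and leaves the per-image cap of `max_predictions_per_image` to the postprocessing modules above; passing `max_predictions_per_image` directly was the source of the "more detections than requested" bug. Second, a hedged usage sketch of this helper (file paths and threshold values are illustrative placeholders, and the keyword names are assumed to match the parameters documented in this patch):

```python
import torch

from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode
from super_gradients.conversion.onnx.nms import attach_onnx_nms

# Hypothetical paths and values for illustration only. The input model is expected
# to be an ONNX export without NMS, with (pred_boxes, pred_scores) as its outputs.
attach_onnx_nms(
    onnx_model_path="yolo_nas_raw.onnx",
    output_onnx_model_path="yolo_nas_with_nms.onnx",
    num_pre_nms_predictions=1000,
    max_predictions_per_image=100,
    confidence_threshold=0.25,
    nms_threshold=0.65,
    batch_size=1,
    output_predictions_format=DetectionOutputFormatMode.BATCH_FORMAT,
    device=torch.device("cpu"),
)
```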
@@ -210,7 +310,6 @@ def attach_onnx_nms( :return: None """ graph = gs.import_onnx(onnx.load(onnx_model_path)) - graph.fold_constants() pred_boxes, pred_scores = graph.outputs @@ -244,7 +343,7 @@ def attach_onnx_nms( graph.layer(op="Transpose", name="permute_scores", inputs=[pred_scores], outputs=[permute_scores], attrs={"perm": [0, 2, 1]}) op_inputs = [pred_boxes, permute_scores] + [ - gs.Constant(name="max_output_boxes_per_class", values=np.array([max_predictions_per_image], dtype=np.int64)), + gs.Constant(name="max_output_boxes_per_class", values=np.array([num_pre_nms_predictions], dtype=np.int64)), gs.Constant(name="iou_threshold", values=np.array([nms_threshold], dtype=np.float32)), gs.Constant(name="score_threshold", values=np.array([confidence_threshold], dtype=np.float32)), ] @@ -279,6 +378,7 @@ def attach_onnx_nms( max_predictions_per_image=max_predictions_per_image, dtype=numpy_dtype_to_torch_dtype(np.float32), device=device, + onnx_export_kwargs=onnx_export_kwargs, ) graph = append_graphs(graph, convert_format_graph) elif output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT: @@ -288,20 +388,17 @@ def attach_onnx_nms( max_predictions_per_image=max_predictions_per_image, dtype=numpy_dtype_to_torch_dtype(np.float32), device=device, + onnx_export_kwargs=onnx_export_kwargs, ) graph = append_graphs(graph, convert_format_graph) else: raise ValueError(f"Invalid output_predictions_format: {output_predictions_format}") # Final cleanup & save - graph.cleanup().toposort() - - # iteratively_infer_shapes(graph) + graph = graph.toposort().fold_constants().cleanup() + graph = iteratively_infer_shapes(graph) model = gs.export_onnx(graph) - onnx.shape_inference.infer_shapes(model) + model = onnx.shape_inference.infer_shapes(model) onnx.save(model, output_onnx_model_path) logger.debug(f"Saved ONNX model to {output_onnx_model_path}") - - # onnxsim.simplify(output_onnx_model_path) - # logger.debug(f"Ran onnxsim.simplify on {output_onnx_model_path}") diff --git a/src/super_gradients/conversion/onnx/pose_nms.py b/src/super_gradients/conversion/onnx/pose_nms.py index 986646a9f3..20a90325ad 100644 --- a/src/super_gradients/conversion/onnx/pose_nms.py +++ b/src/super_gradients/conversion/onnx/pose_nms.py @@ -1,10 +1,11 @@ import os import tempfile -from typing import Tuple +from typing import Tuple, Mapping, Optional import numpy as np import onnx import onnx.shape_inference +import onnxsim import torch from onnx import TensorProto from torch import nn, Tensor @@ -13,7 +14,7 @@ from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode from super_gradients.conversion.conversion_utils import numpy_dtype_to_torch_dtype from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_fail_with_instructions -from super_gradients.conversion.onnx.utils import append_graphs +from super_gradients.conversion.onnx.utils import append_graphs, iteratively_infer_shapes logger = get_logger(__name__) @@ -23,7 +24,7 @@ class PoseNMSAndReturnAsBatchedResult(nn.Module): - __constants__ = ("batch_size", "max_predictions_per_image") + __constants__ = ("batch_size", "num_pre_nms_predictions", "max_predictions_per_image") def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int): """ @@ -37,6 +38,10 @@ def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_prediction max_predictions_per_image predictions are left, the rest of the predictions will be padded with 0. 
""" + if max_predictions_per_image > num_pre_nms_predictions: + raise ValueError( + f"max_predictions_per_image ({max_predictions_per_image}) must be less than or equal to num_pre_nms_predictions ({num_pre_nms_predictions})" + ) super().__init__() self.batch_size = batch_size self.num_pre_nms_predictions = num_pre_nms_predictions @@ -46,55 +51,67 @@ def forward(self, pred_boxes: Tensor, pred_scores: Tensor, pred_joints: Tensor, """ Select the predictions that are output by the NMS plugin. - :param pred_boxes: [B, N, 4] tensor, float32 in XYXY format - :param pred_scores: [B, N, 1] tensor, float32 - :param pred_joints: [B, N, Num Joints, 3] tensor, float32 + Since pose estimation it is a single-class detection task, we do not need to select the label indexes. + We also already get at most max_predictions_per_image in selected_indexes per image, so there is no need to + do any additional filtering. + + :param pred_boxes: [B, N, 4] tensor, float32 in XYXY format + :param pred_scores: [B, N, 1] tensor, float32 + :param pred_joints: [B, N, Num Joints, 3] tensor, float32 :param selected_indexes: [num_selected_indices, 3], int64 - each row is [batch_indexes, label_indexes, boxes_indexes] - :return: A tuple of 4 tensors (num_detections, boxes, scores, joints) will be returned: - - A tensor of [batch_size, 1] containing the image indices for each detection. - - A tensor of [batch_size, max_output_boxes, 4] containing the bounding box coordinates for each detection in [x1, y1, x2, y2] format. - - A tensor of [batch_size, max_output_boxes, Num Joints, 3] + :return: A tuple of 4 tensors (num_detections, boxes, scores, joints) will be returned: + - A tensor of [batch_size, 1] containing the image indices for each detection. + - A tensor of [batch_size, max_output_boxes, 4] containing the bounding box coordinates + for each detection in [x1, y1, x2, y2] format. 
+ - A tensor of [batch_size, max_output_boxes] containing the confidence scores for each detection.
+ - A tensor of [batch_size, max_output_boxes, Num Joints, 3] containing the pose keypoints for each detection.
 """
- batch_indexes, label_indexes, boxes_indexes = selected_indexes[:, 0], selected_indexes[:, 1], selected_indexes[:, 2]
+ # Adding a dummy row to the beginning of the tensor to make sure that we have at least one row
+ # Pytorch & ONNX have a hard time dealing with zero-sized tensors (Can't do torch.nn.functional.pad & squeeze or reshape to get rid of zero-sized axis)
+ # A dummy row does not affect the result, since it is not matched to any [B, N] index
+ selected_indexes = torch.cat([torch.tensor([[-1, -1, -1]], device=selected_indexes.device, dtype=selected_indexes.dtype), selected_indexes], dim=0)
- selected_boxes = pred_boxes[batch_indexes, boxes_indexes] # [num_detections, 4]
- selected_scores = pred_scores[batch_indexes, boxes_indexes, label_indexes] # [num_detections]
- selected_poses = pred_joints[batch_indexes, boxes_indexes] # [num_detections, Num Joints, 3]
+ batch_indexes = selected_indexes[:, 0] # [L], N >= L
+ # label_indexes = selected_indexes[:, 1] Not used because we always have 1 label
+ boxes_indexes = selected_indexes[:, 2] # [L], N >= L
- predictions = torch.cat([batch_indexes.unsqueeze(1), selected_boxes, selected_scores.unsqueeze(1), selected_poses.flatten(1)], dim=1)
+ pre_nms_indexes = torch.arange(start=0, end=self.num_pre_nms_predictions, step=1, device=pred_boxes.device).to(dtype=pred_boxes.dtype) # [N]
- predictions = torch.nn.functional.pad(
- predictions, (0, 0, 0, self.max_predictions_per_image * self.batch_size - predictions.size(0)), value=-1, mode="constant"
- )
+ # pre_nms_vs_predictions_mask contains True if the corresponding detection index is equal to the corresponding pre_nms index
+ pre_nms_vs_predictions_mask = pre_nms_indexes.view(-1, 1) == boxes_indexes.view(1, -1) # [N, L]
- batch_predictions = torch.zeros(
- (self.batch_size, self.max_predictions_per_image, 4 + 1 + selected_poses.size(1) * selected_poses.size(2)),
- dtype=predictions.dtype,
- device=predictions.device,
- )
+ image_indexes = torch.arange(start=0, end=self.batch_size, step=1, device=pred_boxes.device)
+ batch_indexes_mask = image_indexes.view(-1, 1).eq(batch_indexes.view(1, -1)) # [B, L]
- batch_indexes = torch.arange(start=0, end=self.batch_size, step=1, device=predictions.device).to(dtype=predictions.dtype)
- masks = batch_indexes.view(-1, 1).eq(predictions[:, 0].view(1, -1)) # [B, N]
+ final_mask = batch_indexes_mask.unsqueeze(1) & pre_nms_vs_predictions_mask.unsqueeze(0) # [B, N, L]
+ final_mask = final_mask.any(dim=2, keepdims=False) # [B, N]
- num_predictions = torch.sum(masks, dim=1).long()
+ pred_scores = pred_scores[:, :, 0] # [B, N]
+ scores_for_topk = pred_scores * final_mask # [B, N]
- for i in range(self.batch_size):
- selected_predictions = predictions[masks[i]]
- selected_predictions = selected_predictions[:, 1:]
- batch_predictions[i] = torch.nn.functional.pad(
- selected_predictions, (0, 0, 0, self.max_predictions_per_image - selected_predictions.size(0)), value=0, mode="constant"
- )
+ order = torch.topk(scores_for_topk, dim=1, k=self.max_predictions_per_image, largest=True, sorted=True)
- pred_boxes = batch_predictions[:, :, 0:4]
- pred_scores = batch_predictions[:, :, 4]
- pred_joints = batch_predictions[:, :, 5:].reshape(self.batch_size, self.max_predictions_per_image, -1, 3)
+ final_boxes = pred_boxes[image_indexes[:, None], order.indices] # [B, max_predictions_per_image, 4]
+ final_scores = pred_scores[image_indexes[:, None], order.indices] # [B, max_predictions_per_image]
+ final_poses = pred_joints[image_indexes[:, None], order.indices] # [B, max_predictions_per_image, Num Joints, 3]
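+ # Toy illustration of the selection above (made-up values): with B=1, N=4, masked scores
+ # [0.9, 0.0, 0.8, 0.0] and max_predictions_per_image=2, torch.topk returns indices [0, 2],
+ # so the two real detections are gathered best-first; images with fewer real detections
+ # than k are padded by zero-scored rows, and the capped count computed below tells
+ # consumers how many leading rows are valid.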
- return num_predictions.unsqueeze(1), pred_boxes, pred_scores, pred_joints
+ # Count number of predictions for each image in batch
+ num_predictions = torch.sum(batch_indexes_mask.float(), dim=1, keepdim=True).long() # [B, 1]
+ num_predictions = torch.clamp_max(num_predictions, self.max_predictions_per_image)
+
+ return num_predictions, final_boxes, final_scores, final_poses
 @classmethod
- def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device) -> gs.Graph:
+ def as_graph(
+ cls,
+ batch_size: int,
+ num_pre_nms_predictions: int,
+ max_predictions_per_image: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ onnx_export_kwargs: Optional[Mapping] = None,
+ ) -> gs.Graph:
 """
 Convert this module to a separate ONNX graph in order to attach it to the main model.
@@ -107,8 +124,11 @@ def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions
 max_predictions_per_image predictions are left, the rest of the predictions will be padded with 0.
 :param dtype: The target dtype for the graph. If the user asked for an FP16 model, we should create the underlying graph with FP16 tensors.
 :param device: The target device for exporting the graph.
+ :param onnx_export_kwargs: (dict) Optional keyword arguments for torch.onnx.export() function.
 :return: An instance of GraphSurgeon graph that can be attached to the main model.
 """
+ if onnx_export_kwargs is None:
+ onnx_export_kwargs = {}
 with tempfile.TemporaryDirectory() as tmpdirname:
 onnx_file = os.path.join(tmpdirname, "PoseNMSAndReturnAsBatchedResult.onnx")
 pre_nms_boxes = torch.zeros((batch_size, num_pre_nms_predictions, 4), dtype=dtype, device=device)
@@ -141,9 +161,17 @@ def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions
 },
 "selected_indexes": {0: "num_predictions"},
 },
+ **onnx_export_kwargs,
 )
+ model_opt, check_ok = onnxsim.simplify(onnx_file)
+ if not check_ok:
+ raise RuntimeError(f"Failed to simplify ONNX model {onnx_file}")
+ onnx.save(model_opt, onnx_file)
+
 convert_format_graph = gs.import_onnx(onnx.load(onnx_file))
+ convert_format_graph = convert_format_graph.fold_constants().cleanup().toposort()
+ convert_format_graph = iteratively_infer_shapes(convert_format_graph)
 return convert_format_graph
@@ -165,37 +193,52 @@ def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_prediction
 self.num_pre_nms_predictions = num_pre_nms_predictions
 self.max_predictions_per_image = max_predictions_per_image
- def forward(self, pred_boxes: Tensor, pred_scores: Tensor, pred_joints: Tensor, selected_indexes: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ def forward(self, pred_boxes: Tensor, pred_scores: Tensor, pred_joints: Tensor, selected_indexes: Tensor) -> Tensor:
 """
 Select the predictions that are output by the NMS plugin.
+ Since pose estimation is a single-class detection task, we do not need to select the label indexes.
+ We also already get at most max_predictions_per_image entries in selected_indexes per image, so there is no need to
+ do any additional filtering.
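+ For example (illustrative, assuming a 17-joint COCO-style skeleton): each output row would
+ hold 6 + 17 * 3 = 57 values, i.e. [image_index, x1, y1, x2, y2, confidence] followed by
+ 17 (x, y, joint score) triplets.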
+
+ :param pred_boxes: [B, N, 4] tensor, float32
+ :param pred_scores: [B, N, 1] tensor, float32
+ :param pred_joints: [B, N, Num Joints, 3] tensor, float32
 :param selected_indexes: [num_selected_indices, 3], int64 - each row is [batch_indexes, label_indexes, boxes_indexes]
- :return: A single tensor of [Nout, 7] shape, where Nout is the total number of detections across all images in the batch.
- Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.
+ :return: A single tensor of [Nout, 6 + Num Joints * 3] shape, where Nout is the total number of detections across all images in the batch.
+ Each row will contain [image_index, x1, y1, x2, y2, confidence] values followed by the flattened joints.
 """
- batch_indexes, label_indexes, boxes_indexes = selected_indexes[:, 0], selected_indexes[:, 1], selected_indexes[:, 2]
+ batch_indexes = selected_indexes[:, 0]
+ label_indexes = selected_indexes[:, 1]
+ boxes_indexes = selected_indexes[:, 2]
 selected_boxes = pred_boxes[batch_indexes, boxes_indexes] # [num_detections, 4]
- selected_scores = pred_scores[batch_indexes, boxes_indexes, label_indexes].unsqueeze(1) # [num_detections, 1]
+ selected_scores = pred_scores[batch_indexes, boxes_indexes, label_indexes] # [num_detections]
 selected_poses = pred_joints[batch_indexes, boxes_indexes].flatten(start_dim=1) # [num_detections, (Num Joints * 3)]
+ dtype = selected_scores.dtype
 return torch.cat(
 [
- batch_indexes.unsqueeze(1).to(selected_boxes.dtype),
+ batch_indexes.to(dtype).unsqueeze(-1),
 selected_boxes,
- selected_scores,
+ selected_scores.unsqueeze(1),
 selected_poses,
 ],
 dim=1,
 )
 @classmethod
- def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device) -> gs.Graph:
+ def as_graph(
+ cls,
+ batch_size: int,
+ num_pre_nms_predictions: int,
+ max_predictions_per_image: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ onnx_export_kwargs: Optional[Mapping] = None,
+ ) -> gs.Graph:
 """
 Convert this module to a separate ONNX graph in order to attach it to the main model.
@@ -207,8 +250,11 @@ def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions
 :param max_predictions_per_image: Not used, exists for compatibility with PoseNMSAndReturnAsBatchedResult
 :param dtype: The target dtype for the graph. If the user asked for an FP16 model, we should create the underlying graph with FP16 tensors.
 :param device: The target device for exporting the graph.
+ :param onnx_export_kwargs: (dict) Optional keyword arguments for torch.onnx.export() function.
 :return: An instance of GraphSurgeon graph that can be attached to the main model.
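+ Note that, as the body below shows, the exported graph is also run through onnxsim and
+ GraphSurgeon constant folding and shape inference before being returned.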
""" + if onnx_export_kwargs is None: + onnx_export_kwargs = {} with tempfile.TemporaryDirectory() as tmpdirname: onnx_file = os.path.join(tmpdirname, "PoseNMSAndReturnAsFlatResult.onnx") pre_nms_boxes = torch.zeros((batch_size, num_pre_nms_predictions, 4), dtype=dtype, device=device) @@ -231,9 +277,18 @@ def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions "selected_indexes": {0: "num_predictions"}, "flat_predictions": {0: "num_predictions"}, }, + **onnx_export_kwargs, ) + model_opt, check_ok = onnxsim.simplify(onnx_file) + if not check_ok: + raise RuntimeError(f"Failed to simplify ONNX model {onnx_file}") + onnx.save(model_opt, onnx_file) + convert_format_graph = gs.import_onnx(onnx.load(onnx_file)) + convert_format_graph = convert_format_graph.fold_constants().cleanup().toposort() + convert_format_graph = iteratively_infer_shapes(convert_format_graph) + return convert_format_graph @@ -247,6 +302,7 @@ def attach_onnx_pose_nms( batch_size: int, output_predictions_format: DetectionOutputFormatMode, device: torch.device, + onnx_export_kwargs: Optional[Mapping] = None, ): """ Attach ONNX NMS stage to the pose estimation predictions. @@ -357,6 +413,7 @@ def attach_onnx_pose_nms( max_predictions_per_image=max_predictions_per_image, dtype=numpy_dtype_to_torch_dtype(np.float32), device=device, + onnx_export_kwargs=onnx_export_kwargs, ) graph = append_graphs(graph, convert_format_graph) elif output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT: @@ -366,15 +423,17 @@ def attach_onnx_pose_nms( max_predictions_per_image=max_predictions_per_image, dtype=numpy_dtype_to_torch_dtype(np.float32), device=device, + onnx_export_kwargs=onnx_export_kwargs, ) graph = append_graphs(graph, convert_format_graph) else: raise ValueError(f"Invalid output_predictions_format: {output_predictions_format}") # Final cleanup & save - graph.cleanup().toposort() + graph = graph.toposort().fold_constants().cleanup() + graph = iteratively_infer_shapes(graph) model = gs.export_onnx(graph) - onnx.shape_inference.infer_shapes(model) + model = onnx.shape_inference.infer_shapes(model) onnx.save(model, output_onnx_model_path) logger.debug(f"Saved ONNX model to {output_onnx_model_path}") diff --git a/src/super_gradients/conversion/onnx/utils.py b/src/super_gradients/conversion/onnx/utils.py index ff67d59d4f..01f549b84e 100644 --- a/src/super_gradients/conversion/onnx/utils.py +++ b/src/super_gradients/conversion/onnx/utils.py @@ -49,13 +49,13 @@ def append_graphs(graph1: gs.Graph, graph2: gs.Graph, prefix: str = "graph2_") - merged_graph.outputs.clear() merged_graph.outputs = graph2.outputs - merged_graph.toposort() - # iteratively_infer_shapes(merged_graph) + merged_graph = merged_graph.fold_constants().toposort().cleanup() + merged_graph = iteratively_infer_shapes(merged_graph) return merged_graph -def iteratively_infer_shapes(graph: gs.Graph) -> None: +def iteratively_infer_shapes(graph: gs.Graph) -> gs.Graph: """ Sanitize the graph by cleaning any unconnected nodes, do a topological resort, and fold constant inputs values. 
When possible, run shape inference on the @@ -65,7 +65,7 @@ def iteratively_infer_shapes(graph: gs.Graph) -> None: for _ in range(3): count_before = len(graph.nodes) - graph.cleanup().toposort() + graph = graph.cleanup().toposort() try: # for node in graph.nodes: # for o in node.outputs: @@ -76,7 +76,7 @@ def iteratively_infer_shapes(graph: gs.Graph) -> None: except Exception as e: logger.debug(f"Shape inference could not be performed at this time:\n{e}") try: - graph.fold_constants(fold_shapes=True) + graph = graph.fold_constants(fold_shapes=True) except TypeError as e: logger.error("This version of ONNX GraphSurgeon does not support folding shapes, " f"please upgrade your onnx_graphsurgeon module. Error:\n{e}") raise @@ -86,3 +86,5 @@ def iteratively_infer_shapes(graph: gs.Graph) -> None: # No new folding occurred in this iteration, so we can stop for now. break logger.debug(f"Folded {count_before - count_after} constants.") + + return graph diff --git a/src/super_gradients/conversion/tensorrt/nms.py b/src/super_gradients/conversion/tensorrt/nms.py index 9ad8a1a47d..f1ab3699da 100644 --- a/src/super_gradients/conversion/tensorrt/nms.py +++ b/src/super_gradients/conversion/tensorrt/nms.py @@ -1,5 +1,6 @@ import os import tempfile +from typing import Optional, Mapping import numpy as np import onnx @@ -10,7 +11,7 @@ from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode from super_gradients.conversion.conversion_utils import numpy_dtype_to_torch_dtype from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_fail_with_instructions -from super_gradients.conversion.onnx.utils import append_graphs +from super_gradients.conversion.onnx.utils import append_graphs, iteratively_infer_shapes logger = get_logger(__name__) @@ -18,55 +19,74 @@ class ConvertTRTFormatToFlatTensor(nn.Module): + """ + Convert the predictions from EfficientNMS_TRT node format to flat tensor format. + + This node is supported on TensorRT 8.5.3+ + """ + __constants__ = ["batch_size", "max_predictions_per_image"] def __init__(self, batch_size: int, max_predictions_per_image: int): + """ + Convert the predictions from TensorRT format to flat tensor format. + + :param batch_size: A fixed batch size for the model + :param max_predictions_per_image: Maximum number of predictions per image + """ super().__init__() self.batch_size = batch_size self.max_predictions_per_image = max_predictions_per_image def forward(self, num_predictions: Tensor, pred_boxes: Tensor, pred_scores: Tensor, pred_classes: Tensor) -> Tensor: """ - Convert the predictions from "batch" format to "flat" tensor. + Convert the predictions from "batch" format to "flat" format. + :param num_predictions: [B,1] The number of predictions for each image in the batch. - :param pred_boxes: [B, max_predictions_per_image, 4] The predicted bounding boxes for each image in the batch. - :param pred_scores: [B, max_predictions_per_image] The predicted scores for each image in the batch. - :param pred_classes: [B, max_predictions_per_image] The predicted classes for each image in the batch. - :return: Tensor of shape [N, 7] The predictions in flat tensor format. - N is the total number of predictions in the entire batch. - Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values. + :param pred_boxes: [B, max_predictions_per_image, 4] The predicted bounding boxes for each image in the batch. + :param pred_scores: [B, max_predictions_per_image] The predicted scores for each image in the batch. 
+ :param pred_classes: [B, max_predictions_per_image] The predicted classes for each image in the batch.
+ :return: Tensor of shape [N, 7] The predictions in flat tensor format.
+ N is the total number of predictions in the entire batch.
+ Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.
 """
 batch_indexes = (
 torch.arange(start=0, end=self.batch_size, step=1, device=num_predictions.device).view(-1, 1).repeat(1, pred_scores.shape[1])
 ) # [B, max_predictions_per_image]
- preds_indexes = (
- torch.arange(start=0, end=self.max_predictions_per_image, step=1, device=num_predictions.device).view(1, -1, 1).repeat(self.batch_size, 1, 1)
- ) # [B, max_predictions_per_image, 1]
+ preds_indexes = torch.arange(start=0, end=self.max_predictions_per_image, step=1, device=num_predictions.device).view(
+ 1, -1
+ ) # [1, max_predictions_per_image]
 flat_predictions = torch.cat(
 [
- preds_indexes.to(dtype=pred_scores.dtype),
 batch_indexes.unsqueeze(-1).to(dtype=pred_scores.dtype),
 pred_boxes,
 pred_scores.unsqueeze(dim=-1),
 pred_classes.unsqueeze(dim=-1).to(pred_scores.dtype),
 ],
 dim=-1,
- ) # [B, max_predictions_per_image, 8]
-
- num_predictions = num_predictions.repeat(1, self.max_predictions_per_image) # [B, max_predictions_per_image]
+ ) # [B, max_predictions_per_image, 7]
- mask = (flat_predictions[:, :, 0] < num_predictions) & (flat_predictions[:, :, 1] == batch_indexes) # [B, max_predictions_per_image]
+ mask: Tensor = preds_indexes < num_predictions.view((self.batch_size, 1)) # [B, max_predictions_per_image]
+ # Flatten the mask and predictions so a single boolean gather over one axis selects the kept rows (friendlier to TensorRT than 2D boolean indexing)
+ mask = mask.view(-1)
+ flat_predictions = flat_predictions.view(self.max_predictions_per_image * self.batch_size, 7)
 flat_predictions = flat_predictions[mask] # [N, 7]
- return flat_predictions[:, 1:]
+
+ return flat_predictions
 @classmethod
- def as_graph(cls, batch_size: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device) -> gs.Graph:
+ def as_graph(
+ cls, batch_size: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device, onnx_export_kwargs: Optional[Mapping] = None
+ ) -> gs.Graph:
+ if onnx_export_kwargs is None:
+ onnx_export_kwargs = {}
+
 with tempfile.TemporaryDirectory() as tmpdirname:
- onnx_file = os.path.join(tmpdirname, "ConvertTRTFormatToFlatTensor.onnx")
+ onnx_file = os.path.join(tmpdirname, "ConvertTRTFormatToFlatTensorTMP.onnx")
 num_detections = torch.randint(1, max_predictions_per_image, (batch_size, 1), dtype=torch.int32, device=device)
 pred_boxes = torch.zeros((batch_size, max_predictions_per_image, 4), dtype=dtype, device=device)
@@ -80,9 +100,12 @@ def as_graph(cls, batch_size: int, max_predictions_per_image: int, dtype: torch.
input_names=["num_detections", "pred_boxes", "pred_scores", "pred_classes"],
 output_names=["flat_predictions"],
 dynamic_axes={"flat_predictions": {0: "num_predictions"}},
+ **onnx_export_kwargs,
 )
 convert_format_graph = gs.import_onnx(onnx.load(onnx_file))
+ convert_format_graph = convert_format_graph.fold_constants().cleanup().toposort()
+ convert_format_graph = iteratively_infer_shapes(convert_format_graph)
 return convert_format_graph
@@ -96,17 +119,23 @@ def attach_tensorrt_nms(
 batch_size: int,
 output_predictions_format: DetectionOutputFormatMode,
 device: torch.device,
+ onnx_export_kwargs: Optional[Mapping] = None,
 ):
 """
 Attach TensorRT NMS plugin to the ONNX model
- :param onnx_model_path:
- :param output_onnx_model_path:
- :param max_predictions_per_image: Maximum number of predictions per image
- :param precision:
- :param batch_size:
- :return:
+ :param onnx_model_path: Path to the original model in ONNX format to attach the NMS plugin to.
+ :param output_onnx_model_path: Path to save the new ONNX model with the NMS plugin attached.
+ :param num_pre_nms_predictions: Number of predictions that go into NMS.
+ :param max_predictions_per_image: Maximum number of predictions per image (after NMS).
+ :param batch_size: Batch size of the model.
+ :param confidence_threshold: Confidence threshold for NMS step.
+ :param nms_threshold: NMS IoU threshold.
+ :param output_predictions_format: Output predictions format.
+ :param device: Device to run the model on.
+ :param onnx_export_kwargs: Additional kwargs to pass to torch.onnx.export
 """
+
 graph = gs.import_onnx(onnx.load(onnx_model_path))
 # graph.fold_constants()
@@ -159,7 +188,11 @@ def attach_tensorrt_nms(
 if output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT:
 convert_format_graph = ConvertTRTFormatToFlatTensor.as_graph(
- batch_size=batch_size, max_predictions_per_image=max_predictions_per_image, dtype=numpy_dtype_to_torch_dtype(pred_boxes.dtype), device=device
+ batch_size=batch_size,
+ max_predictions_per_image=max_predictions_per_image,
+ dtype=numpy_dtype_to_torch_dtype(pred_boxes.dtype),
+ device=device,
+ onnx_export_kwargs=onnx_export_kwargs,
 )
 graph = append_graphs(graph, convert_format_graph)
 elif output_predictions_format == DetectionOutputFormatMode.BATCH_FORMAT:
@@ -168,8 +201,8 @@ def attach_tensorrt_nms(
 raise NotImplementedError(f"Currently does not support output_predictions_format: {output_predictions_format}")
 # Final cleanup & save
- graph.cleanup().toposort()
- # iteratively_infer_shapes(graph)
+ graph = graph.cleanup().toposort()
+ graph = iteratively_infer_shapes(graph)
 logger.debug(f"Final graph outputs: {graph.outputs}")
diff --git a/src/super_gradients/examples/model_export/models_export.ipynb b/src/super_gradients/examples/model_export/models_export.ipynb index 6d49e5310e..0e2d3cf87f 100644 --- a/src/super_gradients/examples/model_export/models_export.ipynb +++ b/src/super_gradients/examples/model_export/models_export.ipynb
@@ -49,7 +49,7 @@
 "execution_count": null,
 "outputs": [],
 "source": [
- "!pip install super_gradients==3.2.1"
+ "!pip install super_gradients==3.3.1"
 ],
 "metadata": {
 "collapsed": false
diff --git a/src/super_gradients/module_interfaces/exportable_detector.py b/src/super_gradients/module_interfaces/exportable_detector.py index 2ca8990138..b81ce80c5d 100644 --- a/src/super_gradients/module_interfaces/exportable_detector.py +++ b/src/super_gradients/module_interfaces/exportable_detector.py
@@ -509,7 +509,7 @@ def export(
 if output_predictions_format ==
DetectionOutputFormatMode.FLAT_FORMAT:
 logger.warning(
 "Support of flat predictions format in TensorRT is experimental and may not work on all versions of TensorRT. "
- "We recommend using TensorRT 8.4.1 or newer. On older versions this format will not work. "
+ "We recommend using TensorRT 8.5.3 or newer. On older versions of TensorRT this format will not work. "
 "If you encounter issues loading the exported model in TensorRT, please try upgrading TensorRT to the latest version. "
 "Alternatively, you can export the model to output predictions in batch format by "
 "specifying output_predictions_format=DetectionOutputFormatMode.BATCH_FORMAT. "
@@ -529,6 +529,7 @@ def export(
 batch_size=batch_size,
 output_predictions_format=output_predictions_format,
 device=device,
+ onnx_export_kwargs=onnx_export_kwargs,
 )
 if onnx_simplify:
diff --git a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py index 5932dfa7c9..9814f24537 100644 --- a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py +++ b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py
@@ -69,13 +69,15 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]) -> T
 pred_bboxes, pred_scores = inputs[0]
 nms_top_k = self.num_pre_nms_predictions
+ batch_size, num_anchors, _ = pred_scores.size()
 pred_cls_conf, _ = torch.max(pred_scores, dim=2)
 topk_candidates = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True)
- offsets = nms_top_k * torch.arange(pred_cls_conf.size(0), device=pred_cls_conf.device)
- flat_indices = topk_candidates.indices + offsets.reshape(pred_cls_conf.size(0), 1)
- flat_indices = torch.flatten(flat_indices)
+ # Row offsets must stride by num_anchors (the per-image row length in the flattened tensor); striding by nms_top_k picked wrong rows for batch indexes > 0
+ offsets = num_anchors * torch.arange(batch_size, device=pred_cls_conf.device)
+ indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1)
+ flat_indices = torch.flatten(indices_with_offset)
 output_pred_bboxes = pred_bboxes.reshape(-1, pred_bboxes.size(2))[flat_indices, :].reshape(pred_bboxes.size(0), nms_top_k, pred_bboxes.size(2))
 output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2))
diff --git a/src/super_gradients/training/models/detection_models/yolo_base.py b/src/super_gradients/training/models/detection_models/yolo_base.py index a81854a953..1f058f43f8 100755 --- a/src/super_gradients/training/models/detection_models/yolo_base.py +++ b/src/super_gradients/training/models/detection_models/yolo_base.py
@@ -736,13 +736,15 @@ def forward(self, predictions):
 if self.with_confidence:
 pred_scores = pred_scores * conf
- pred_cls_conf, _ = torch.max(pred_scores, dim=2)
 nms_top_k = self.num_pre_nms_predictions
+ batch_size, num_anchors, _ = pred_scores.size()
+
+ pred_cls_conf, _ = torch.max(pred_scores, dim=2)
 topk_candidates = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True)
- offsets = nms_top_k * torch.arange(pred_cls_conf.size(0), device=pred_cls_conf.device)
- flat_indices = topk_candidates.indices + offsets.reshape(pred_cls_conf.size(0), 1)
- flat_indices = torch.flatten(flat_indices)
+ offsets = num_anchors * torch.arange(batch_size, device=pred_cls_conf.device)
+ indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1)
+ flat_indices = torch.flatten(indices_with_offset)
 output_pred_bboxes = pred_bboxes.reshape(-1, pred_bboxes.size(2))[flat_indices, :].reshape(pred_bboxes.size(0),
nms_top_k, pred_bboxes.size(2)) output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2)) diff --git a/src/super_gradients/training/models/detection_models/yolo_nas/yolo_nas_variants.py b/src/super_gradients/training/models/detection_models/yolo_nas/yolo_nas_variants.py index daa0d34b63..4577d58465 100644 --- a/src/super_gradients/training/models/detection_models/yolo_nas/yolo_nas_variants.py +++ b/src/super_gradients/training/models/detection_models/yolo_nas/yolo_nas_variants.py @@ -52,13 +52,14 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]): pred_bboxes, pred_scores = inputs[0] nms_top_k = self.num_pre_nms_predictions + batch_size, num_anchors, _ = pred_scores.size() - pred_cls_conf, _ = torch.max(pred_scores, dim=2) + pred_cls_conf, _ = torch.max(pred_scores, dim=2) # [B, Anchors] topk_candidates = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True) - offsets = nms_top_k * torch.arange(pred_cls_conf.size(0), device=pred_cls_conf.device) - flat_indices = topk_candidates.indices + offsets.reshape(pred_cls_conf.size(0), 1) - flat_indices = torch.flatten(flat_indices) + offsets = num_anchors * torch.arange(batch_size, device=pred_cls_conf.device) + indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1) + flat_indices = torch.flatten(indices_with_offset) output_pred_bboxes = pred_bboxes.reshape(-1, pred_bboxes.size(2))[flat_indices, :].reshape(pred_bboxes.size(0), nms_top_k, pred_bboxes.size(2)) output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2)) @@ -67,6 +68,24 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]): class YoloNAS(ExportableObjectDetectionModel, CustomizableDetector): + """ + + Export to ONNX/TRT Support matrix + ONNX files generated with PyTorch 2.0.1 for ONNX opset_version=14 + + | Batch Size | Export Engine | Format | OnnxRuntime 1.13.1 | TensorRT 8.4.2 | TensorRT 8.5.3 | TensorRT 8.6.1 | + |------------|---------------|--------|--------------------|----------------|----------------|----------------| + | 1 | ONNX | Flat | Yes | Yes | Yes | Yes | + | >1 | ONNX | Flat | Yes | No | No | No | + | 1 | ONNX | Batch | Yes | No | Yes | Yes | + | >1 | ONNX | Batch | Yes | No | No | Yes | + | 1 | TensorRT | Flat | No | No | Yes | Yes | + | >1 | TensorRT | Flat | No | No | Yes | Yes | + | 1 | TensorRT | Batch | No | Yes | Yes | Yes | + | >1 | TensorRT | Batch | No | Yes | Yes | Yes | + + """ + def __init__( self, backbone: Union[str, dict, HpmStruct, DictConfig], diff --git a/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_variants.py b/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_variants.py index 14f3880423..95e5ee7fac 100644 --- a/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_variants.py +++ b/src/super_gradients/training/models/pose_estimation_models/yolo_nas_pose/yolo_nas_pose_variants.py @@ -67,12 +67,13 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]): pred_bboxes_xyxy, pred_bboxes_conf, pred_pose_coords, pred_pose_scores = inputs[0] nms_top_k = self.num_pre_nms_predictions + batch_size, num_anchors, _ = pred_bboxes_conf.size() topk_candidates = torch.topk(pred_bboxes_conf, dim=1, k=nms_top_k, largest=True, sorted=True) - offsets = nms_top_k * 
torch.arange(pred_bboxes_conf.size(0), device=pred_bboxes_conf.device) - flat_indices = topk_candidates.indices + offsets.reshape(pred_bboxes_conf.size(0), 1, 1) - flat_indices = torch.flatten(flat_indices) + offsets = num_anchors * torch.arange(batch_size, device=pred_bboxes_conf.device) + indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1, 1) + flat_indices = torch.flatten(indices_with_offset) pred_poses_and_scores = torch.cat([pred_pose_coords, pred_pose_scores.unsqueeze(3)], dim=3) @@ -90,6 +91,21 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]): class YoloNASPose(CustomizableDetector, ExportablePoseEstimationModel): + """ + YoloNASPose model + + Exported model support matrix + + | Batch Size | Format | OnnxRuntime 1.13.1 | TensorRT 8.4.2 | TensorRT 8.5.3 | TensorRT 8.6.1 | + |------------|--------|--------------------|----------------|----------------|----------------| + | 1 | Flat | Yes | Yes | Yes | Yes | + | >1 | Flat | Yes | Yes | Yes | Yes | + | 1 | Batch | Yes | No | No | Yes | + | >1 | Batch | Yes | No | No | Yes | + + ONNX files generated with PyTorch 2.0.1 for ONNX opset_version=14 + """ + def __init__( self, backbone: Union[str, dict, HpmStruct, DictConfig], diff --git a/tests/unit_tests/export_detection_model_test.py b/tests/unit_tests/export_detection_model_test.py index 924b40c003..808d3536b6 100644 --- a/tests/unit_tests/export_detection_model_test.py +++ b/tests/unit_tests/export_detection_model_test.py @@ -1,5 +1,6 @@ import logging import os +import random import tempfile import unittest @@ -607,46 +608,48 @@ def test_trt_nms_convert_to_flat_result(self): available_devices = ["cpu"] available_dtypes = [torch.float32] - for device in available_devices: - for dtype in available_dtypes: + for num_predictions_max in [0, max_predictions_per_image // 2, max_predictions_per_image]: + for device in available_devices: + for dtype in available_dtypes: - num_detections = torch.randint(1, max_predictions_per_image, (batch_size, 1), dtype=torch.int32) - detection_boxes = torch.randn((batch_size, max_predictions_per_image, 4), dtype=dtype) - detection_scores = torch.randn((batch_size, max_predictions_per_image), dtype=dtype) - detection_classes = torch.randint(0, 80, (batch_size, max_predictions_per_image), dtype=torch.int32) + num_detections = torch.randint(0, num_predictions_max + 1, (batch_size, 1), dtype=torch.int32) + detection_boxes = torch.randn((batch_size, max_predictions_per_image, 4), dtype=dtype) + detection_scores = torch.randn((batch_size, max_predictions_per_image)).sigmoid().to(dtype) + detection_classes = torch.randint(0, 80, (batch_size, max_predictions_per_image), dtype=torch.int32) - torch_module = ConvertTRTFormatToFlatTensor(batch_size, max_predictions_per_image) - flat_predictions_torch = torch_module(num_detections, detection_boxes, detection_scores, detection_classes) - print(flat_predictions_torch.shape, flat_predictions_torch.dtype, flat_predictions_torch) + torch_module = ConvertTRTFormatToFlatTensor(batch_size, max_predictions_per_image) + flat_predictions_torch = torch_module(num_detections, detection_boxes, detection_scores, detection_classes) + print(flat_predictions_torch.shape, flat_predictions_torch.dtype, flat_predictions_torch) - onnx_file = "ConvertTRTFormatToFlatTensor.onnx" + onnx_file = "ConvertTRTFormatToFlatTensor.onnx" - graph = ConvertTRTFormatToFlatTensor.as_graph( - batch_size=batch_size, max_predictions_per_image=max_predictions_per_image, dtype=dtype, device=device - ) - model 
diff --git a/tests/unit_tests/export_detection_model_test.py b/tests/unit_tests/export_detection_model_test.py
index 924b40c003..808d3536b6 100644
--- a/tests/unit_tests/export_detection_model_test.py
+++ b/tests/unit_tests/export_detection_model_test.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import random
 import tempfile
 import unittest
@@ -607,46 +608,48 @@ def test_trt_nms_convert_to_flat_result(self):
             available_devices = ["cpu"]
             available_dtypes = [torch.float32]
 
-        for device in available_devices:
-            for dtype in available_dtypes:
+        for num_predictions_max in [0, max_predictions_per_image // 2, max_predictions_per_image]:
+            for device in available_devices:
+                for dtype in available_dtypes:
 
-                num_detections = torch.randint(1, max_predictions_per_image, (batch_size, 1), dtype=torch.int32)
-                detection_boxes = torch.randn((batch_size, max_predictions_per_image, 4), dtype=dtype)
-                detection_scores = torch.randn((batch_size, max_predictions_per_image), dtype=dtype)
-                detection_classes = torch.randint(0, 80, (batch_size, max_predictions_per_image), dtype=torch.int32)
+                    num_detections = torch.randint(0, num_predictions_max + 1, (batch_size, 1), dtype=torch.int32)
+                    detection_boxes = torch.randn((batch_size, max_predictions_per_image, 4), dtype=dtype)
+                    detection_scores = torch.randn((batch_size, max_predictions_per_image)).sigmoid().to(dtype)
+                    detection_classes = torch.randint(0, 80, (batch_size, max_predictions_per_image), dtype=torch.int32)
 
-                torch_module = ConvertTRTFormatToFlatTensor(batch_size, max_predictions_per_image)
-                flat_predictions_torch = torch_module(num_detections, detection_boxes, detection_scores, detection_classes)
-                print(flat_predictions_torch.shape, flat_predictions_torch.dtype, flat_predictions_torch)
+                    torch_module = ConvertTRTFormatToFlatTensor(batch_size, max_predictions_per_image)
+                    flat_predictions_torch = torch_module(num_detections, detection_boxes, detection_scores, detection_classes)
+                    print(flat_predictions_torch.shape, flat_predictions_torch.dtype, flat_predictions_torch)
 
-                onnx_file = "ConvertTRTFormatToFlatTensor.onnx"
+                    onnx_file = "ConvertTRTFormatToFlatTensor.onnx"
 
-                graph = ConvertTRTFormatToFlatTensor.as_graph(
-                    batch_size=batch_size, max_predictions_per_image=max_predictions_per_image, dtype=dtype, device=device
-                )
-                model = gs.export_onnx(graph)
-                onnx.checker.check_model(model)
-                onnx.save(model, onnx_file)
-
-                session = onnxruntime.InferenceSession(onnx_file)
-
-                inputs = [o.name for o in session.get_inputs()]
-                outputs = [o.name for o in session.get_outputs()]
-
-                [flat_predictions_onnx] = session.run(
-                    output_names=outputs,
-                    input_feed={
-                        inputs[0]: num_detections.numpy(),
-                        inputs[1]: detection_boxes.numpy(),
-                        inputs[2]: detection_scores.numpy(),
-                        inputs[3]: detection_classes.numpy(),
-                    },
-                )
+                    graph = ConvertTRTFormatToFlatTensor.as_graph(
+                        batch_size=batch_size, max_predictions_per_image=max_predictions_per_image, dtype=dtype, device=device
+                    )
+                    model = gs.export_onnx(graph)
+                    onnx.checker.check_model(model)
+                    onnx.save(model, onnx_file)
+
+                    session = onnxruntime.InferenceSession(onnx_file)
+
+                    inputs = [o.name for o in session.get_inputs()]
+                    outputs = [o.name for o in session.get_outputs()]
 
-                np.testing.assert_allclose(flat_predictions_torch.numpy(), flat_predictions_onnx, rtol=1e-3, atol=1e-3)
+                    [flat_predictions_onnx] = session.run(
+                        output_names=outputs,
+                        input_feed={
+                            inputs[0]: num_detections.numpy(),
+                            inputs[1]: detection_boxes.numpy(),
+                            inputs[2]: detection_scores.numpy(),
+                            inputs[3]: detection_classes.numpy(),
+                        },
+                    )
+
+                    np.testing.assert_allclose(flat_predictions_torch.numpy(), flat_predictions_onnx, rtol=1e-3, atol=1e-3)
 
     def test_onnx_nms_flat_result(self):
-        max_predictions = 100
+        num_pre_nms_predictions = 1024
+        max_predictions_per_image = 128
         batch_size = 7
 
         if torch.cuda.is_available():
@@ -656,41 +659,56 @@ def test_onnx_nms_flat_result(self):
             available_devices = ["cpu"]
             available_dtypes = [torch.float32]
 
-        for device in available_devices:
-            for dtype in available_dtypes:
-
-                # Run a few tests to ensure ONNX model produces the same results as the PyTorch model
-                # And also can handle dynamic shapes input
-                pred_boxes = torch.randn((batch_size, max_predictions, 4), dtype=dtype)
-                pred_scores = torch.randn((batch_size, max_predictions, 40), dtype=dtype)
-                selected_indexes = torch.tensor([[6, 10, 4], [1, 13, 4], [2, 17, 2], [2, 18, 2]], dtype=torch.int64)
-
-                torch_module = PickNMSPredictionsAndReturnAsFlatResult(
-                    batch_size=batch_size, num_pre_nms_predictions=max_predictions, max_predictions_per_image=max_predictions
-                )
-                torch_result = torch_module(pred_boxes, pred_scores, selected_indexes)
-
-                with tempfile.TemporaryDirectory() as temp_dir:
-                    onnx_file = os.path.join(temp_dir, "PickNMSPredictionsAndReturnAsFlatResult.onnx")
-                    graph = PickNMSPredictionsAndReturnAsFlatResult.as_graph(
-                        batch_size=batch_size, num_pre_nms_predictions=max_predictions, max_predictions_per_image=max_predictions, device=device, dtype=dtype
+        for max_detections in [0, num_pre_nms_predictions // 2, num_pre_nms_predictions, num_pre_nms_predictions * 2]:
+            for device in available_devices:
+                for dtype in available_dtypes:
+
+                    # Run a few tests to ensure the ONNX model produces the same results as the PyTorch model,
+                    # and that it can also handle inputs of dynamic shape
+                    pred_boxes = torch.randn((batch_size, num_pre_nms_predictions, 4), dtype=dtype)
+                    pred_scores = torch.randn((batch_size, num_pre_nms_predictions, 40), dtype=dtype)
+
+                    selected_indexes = []
+                    for batch_index in range(batch_size):
+                        # A deterministic number of detections per image keeps the test reproducible
+                        num_detections = max_detections
+                        for _ in range(num_detections):
+                            selected_indexes.append([batch_index, random.randrange(0, 40), random.randrange(0, num_pre_nms_predictions)])
+                    selected_indexes = torch.tensor(selected_indexes, dtype=torch.int64).view(-1, 3)
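# Illustrative sketch, not part of this patch: the synthetic `selected_indexes`
# built above mimics the output of the ONNX NonMaxSuppression operator, which is
# an int64 tensor of shape [num_selected, 3] whose rows are
# [batch_index, class_index, box_index]. A tiny hand-built instance:
import torch

selected_indexes = torch.tensor(
    [
        [0, 10, 4],   # image 0: box 4 kept for class 10
        [0, 13, 17],  # image 0: box 17 kept for class 13
        [2, 10, 9],   # image 2: box 9 kept for class 10
    ],
    dtype=torch.int64,
)
# PickNMSPredictionsAndReturnAsFlatResult gathers boxes and scores with these
# indices and, after this fix, emits at most max_predictions_per_image rows per
# image in a single flat tensor covering the whole batch.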
+
+                    torch_module = PickNMSPredictionsAndReturnAsFlatResult(
+                        batch_size=batch_size, num_pre_nms_predictions=num_pre_nms_predictions, max_predictions_per_image=max_predictions_per_image
                     )
+                    torch_result = torch_module(pred_boxes, pred_scores, selected_indexes)
 
-                    model = gs.export_onnx(graph)
-                    onnx.checker.check_model(model)
-                    onnx.save(model, onnx_file)
+                    with tempfile.TemporaryDirectory() as temp_dir:
+                        onnx_file = os.path.join(temp_dir, "PickNMSPredictionsAndReturnAsFlatResult.onnx")
+                        graph = PickNMSPredictionsAndReturnAsFlatResult.as_graph(
+                            batch_size=batch_size,
+                            num_pre_nms_predictions=num_pre_nms_predictions,
+                            max_predictions_per_image=max_predictions_per_image,
+                            device=device,
+                            dtype=dtype,
+                        )
 
-                    session = onnxruntime.InferenceSession(onnx_file)
+                        model = gs.export_onnx(graph)
+                        onnx.checker.check_model(model)
+                        onnx.save(model, onnx_file)
 
-                    inputs = [o.name for o in session.get_inputs()]
-                    outputs = [o.name for o in session.get_outputs()]
+                        session = onnxruntime.InferenceSession(onnx_file)
 
-                    [onnx_result] = session.run(outputs, {inputs[0]: pred_boxes.numpy(), inputs[1]: pred_scores.numpy(), inputs[2]: selected_indexes.numpy()})
+                        inputs = [o.name for o in session.get_inputs()]
+                        outputs = [o.name for o in session.get_outputs()]
 
-                    np.testing.assert_allclose(torch_result.numpy(), onnx_result, rtol=1e-3, atol=1e-3)
+                        [onnx_result] = session.run(
+                            outputs, {inputs[0]: pred_boxes.numpy(), inputs[1]: pred_scores.numpy(), inputs[2]: selected_indexes.numpy()}
+                        )
+
+                        np.testing.assert_allclose(torch_result.numpy(), onnx_result, rtol=1e-3, atol=1e-3)
 
     def test_onnx_nms_batch_result(self):
-        max_predictions = 100
+        num_pre_nms_predictions = 1024
+        max_predictions_per_image = 128
         batch_size = 7
 
         if torch.cuda.is_available():
@@ -700,41 +718,53 @@ def test_onnx_nms_batch_result(self):
             available_devices = ["cpu"]
             available_dtypes = [torch.float32]
 
-        for device in available_devices:
-            for dtype in available_dtypes:
-
-                # Run a few tests to ensure ONNX model produces the same results as the PyTorch model
-                # And also can handle dynamic shapes input
-                pred_boxes = torch.randn((batch_size, max_predictions, 4), dtype=dtype)
-                pred_scores = torch.randn((batch_size, max_predictions, 40), dtype=dtype)
-                selected_indexes = torch.tensor([[6, 10, 4], [1, 13, 4], [2, 17, 2], [2, 18, 2]], dtype=torch.int64)
-
-                torch_module = PickNMSPredictionsAndReturnAsBatchedResult(
-                    batch_size=batch_size, num_pre_nms_predictions=max_predictions, max_predictions_per_image=max_predictions
-                )
-                torch_result = torch_module(pred_boxes, pred_scores, selected_indexes)
-
-                with tempfile.TemporaryDirectory() as temp_dir:
-                    onnx_file = os.path.join(temp_dir, "PickNMSPredictionsAndReturnAsBatchedResult.onnx")
-                    graph = PickNMSPredictionsAndReturnAsBatchedResult.as_graph(
-                        batch_size=batch_size, num_pre_nms_predictions=max_predictions, max_predictions_per_image=max_predictions, device=device, dtype=dtype
+        for max_detections in [0, num_pre_nms_predictions // 2, num_pre_nms_predictions, num_pre_nms_predictions * 2]:
+            for device in available_devices:
+                for dtype in available_dtypes:
+
+                    # Run a few tests to ensure the ONNX model produces the same results as the PyTorch model,
+                    # and that it can also handle inputs of dynamic shape
+                    pred_boxes = torch.randn((batch_size, num_pre_nms_predictions, 4), dtype=dtype)
+                    pred_scores = torch.randn((batch_size, num_pre_nms_predictions, 40), dtype=dtype)
+
+                    selected_indexes = []
+                    for batch_index in range(batch_size):
+                        # A deterministic number of detections per image keeps the test reproducible
+                        num_detections = max_detections
+                        for _ in range(num_detections):
+                            selected_indexes.append([batch_index, random.randrange(0, 40), random.randrange(0, num_pre_nms_predictions)])
+                    selected_indexes = torch.tensor(selected_indexes, dtype=torch.int64).view(-1, 3)
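# Illustrative sketch, not part of this patch: the batched result validated below
# consists of four fixed-shape tensors: num_predictions [B, 1],
# boxes [B, max_predictions_per_image, 4], scores [B, max_predictions_per_image]
# and classes [B, max_predictions_per_image]; rows beyond each image's detection
# count are padding. The helper name below is hypothetical; it shows how a
# consumer would strip that padding:
def iterate_over_batched_predictions(num_predictions, pred_boxes, pred_scores, pred_classes):
    for image_index in range(num_predictions.shape[0]):
        count = int(num_predictions[image_index, 0])
        # Only the first `count` rows of each per-image tensor are valid detections
        yield pred_boxes[image_index, :count], pred_scores[image_index, :count], pred_classes[image_index, :count]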
+
+                    torch_module = PickNMSPredictionsAndReturnAsBatchedResult(
+                        batch_size=batch_size, num_pre_nms_predictions=num_pre_nms_predictions, max_predictions_per_image=max_predictions_per_image
                     )
+                    torch_result = torch_module(pred_boxes, pred_scores, selected_indexes)
 
-                    model = gs.export_onnx(graph)
-                    onnx.checker.check_model(model)
-                    onnx.save(model, onnx_file)
+                    with tempfile.TemporaryDirectory() as temp_dir:
+                        onnx_file = os.path.join(temp_dir, "PickNMSPredictionsAndReturnAsBatchedResult.onnx")
+                        graph = PickNMSPredictionsAndReturnAsBatchedResult.as_graph(
+                            batch_size=batch_size,
+                            num_pre_nms_predictions=num_pre_nms_predictions,
+                            max_predictions_per_image=max_predictions_per_image,
+                            device=device,
+                            dtype=dtype,
+                        )
 
-                    session = onnxruntime.InferenceSession(onnx_file)
+                        model = gs.export_onnx(graph)
+                        onnx.checker.check_model(model)
+                        onnx.save(model, onnx_file)
 
-                    inputs = [o.name for o in session.get_inputs()]
-                    outputs = [o.name for o in session.get_outputs()]
+                        session = onnxruntime.InferenceSession(onnx_file)
+
+                        inputs = [o.name for o in session.get_inputs()]
+                        outputs = [o.name for o in session.get_outputs()]
 
-                    onnx_result = session.run(outputs, {inputs[0]: pred_boxes.numpy(), inputs[1]: pred_scores.numpy(), inputs[2]: selected_indexes.numpy()})
+                        onnx_result = session.run(outputs, {inputs[0]: pred_boxes.numpy(), inputs[1]: pred_scores.numpy(), inputs[2]: selected_indexes.numpy()})
 
-                    np.testing.assert_allclose(torch_result[0].numpy(), onnx_result[0], rtol=1e-3, atol=1e-3)
-                    np.testing.assert_allclose(torch_result[1].numpy(), onnx_result[1], rtol=1e-3, atol=1e-3)
-                    np.testing.assert_allclose(torch_result[2].numpy(), onnx_result[2], rtol=1e-3, atol=1e-3)
-                    np.testing.assert_allclose(torch_result[3].numpy(), onnx_result[3], rtol=1e-3, atol=1e-3)
+                        np.testing.assert_allclose(torch_result[0].numpy(), onnx_result[0], rtol=1e-3, atol=1e-3)
+                        np.testing.assert_allclose(torch_result[1].numpy(), onnx_result[1], rtol=1e-3, atol=1e-3)
+                        np.testing.assert_allclose(torch_result[2].numpy(), onnx_result[2], rtol=1e-3, atol=1e-3)
+                        np.testing.assert_allclose(torch_result[3].numpy(), onnx_result[3], rtol=1e-3, atol=1e-3)
 
     def _get_image_as_bchw(self, image_shape=(640, 640)):
         """