Bugfix of model.export() to work correctly with bs>1 (#1551)
* Bugfixes:

 1) Wrong indexes were picked when exporting models with bs>1 in the decoding modules (YoloNAS, YoloNAS-Pose, YoloX, PPYolo-E); see the sketch after this list

 2) When using ONNX NonMaxSuppression, the max_predictions_per_image parameter was applied incorrectly (Flat/Batch format predictions could contain more than the requested number of detections)

* Fixing NMS postprocessing to make it TRT compatible

* Update pose estimation support matrix for exported models

* Update compatibility matrix for TRT NMS

* Improve ConvertTRTFormatToFlatTensor implementation

* Improve ConvertTRTFormatToFlatTensor implementation

* Improve ConvertTRTFormatToFlatTensor implementation

* Drop use_boolean_gather

* Improve test

* Update compatibility matrix

* Improve test

* Improve test

* Update notebook export
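
A minimal sketch of bugfix (1), with toy shapes that are illustrative assumptions rather than the library's actual module code. The decoding modules gather per-image top-k rows from a flattened [B * num_anchors, ...] tensor, so the per-image offset must step by num_anchors, not by nms_top_k. With bs=1 both offsets are zero, which is why the bug only surfaced for bs>1:

```python
import torch

# Toy shapes; names mirror the diffs below, but this is not the repo's code.
batch_size, num_anchors, nms_top_k = 2, 8, 3
pred_bboxes = torch.arange(batch_size * num_anchors * 4, dtype=torch.float32).reshape(batch_size, num_anchors, 4)
pred_cls_conf = torch.rand(batch_size, num_anchors)

topk = torch.topk(pred_cls_conf, k=nms_top_k, dim=1, largest=True, sorted=True)
flat = pred_bboxes.reshape(-1, 4)  # each image owns num_anchors consecutive rows

bad_offsets = nms_top_k * torch.arange(batch_size).reshape(-1, 1)     # old, buggy stride
good_offsets = num_anchors * torch.arange(batch_size).reshape(-1, 1)  # fixed stride

bad = flat[(topk.indices + bad_offsets).flatten()]
good = flat[(topk.indices + good_offsets).flatten()]
print(torch.equal(bad[:nms_top_k], good[:nms_top_k]))  # True: image 0 is unaffected
print(torch.equal(bad[nms_top_k:], good[nms_top_k:]))  # False: image 1 picked wrong rows
```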
BloodAxe committed Oct 23, 2023
1 parent 24798b0 commit 0515496
Showing 12 changed files with 525 additions and 256 deletions.
35 changes: 22 additions & 13 deletions documentation/source/models_export.md
@@ -28,6 +28,15 @@ A new export API is introduced in SG 3.2.0. It is aimed to simplify the export p
- Customising NMS parameters and number of detections per image
- Customising output format (flat or batched)


```python
!pip install super_gradients==3.3.1
```

ERROR: Could not find a version that satisfies the requirement super_gradients==3.3.1 (from versions: 1.3.0, 1.3.1, 1.4.0, 1.5.0, 1.5.1, 1.5.2, 1.6.0, 1.7.1, 1.7.2, 1.7.3, 1.7.4, 1.7.5, 2.0.0, 2.0.1, 2.1.0, 2.2.0, 2.5.0, 2.6.0, 3.0.0, 3.0.1, 3.0.2, 3.0.3, 3.0.4, 3.0.5, 3.0.6, 3.0.7, 3.0.8, 3.0.9, 3.1.0, 3.1.1, 3.1.2, 3.1.3, 3.2.0, 3.2.1, 3.3.0)
ERROR: No matching distribution found for super_gradients==3.3.1


### Minimalistic export example

Let's start with the simplest example of exporting a model to ONNX format.
@@ -203,9 +212,9 @@ pred_boxes, pred_boxes.shape
[ 35.71795, 249.40926, 176.62216, 544.69794],
[182.39618, 249.49301, 301.44122, 529.3324 ],
...,
[ 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ]]], dtype=float32),
[ -1. , -1. , -1. , -1. ],
[ -1. , -1. , -1. , -1. ],
[ -1. , -1. , -1. , -1. ]]], dtype=float32),
(1, 1000, 4))


@@ -219,8 +228,8 @@ pred_scores, pred_scores.shape



(array([[0.9694027, 0.9693378, 0.9665707, 0.9619047, 0.7538769, ...,
0. , 0. , 0. , 0. , 0. ]],
(array([[ 0.9694027, 0.9693378, 0.9665707, 0.9619047, 0.7538769, ...,
-1. , -1. , -1. , -1. , -1. ]],
dtype=float32),
(1, 1000))

@@ -235,8 +244,8 @@ pred_classes, pred_classes.shape



(array([[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, ..., 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
dtype=int64),
(array([[ 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, ..., -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1]], dtype=int64),
(1, 1000))
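
Note that padding rows are now filled with -1 instead of 0, so a real detection with zero coordinates, score, or class id can no longer be mistaken for padding. A short sketch of trimming the padded rows, under the assumption that session.run returns the four batch-format outputs in the order (num_predictions, pred_boxes, pred_scores, pred_classes):

```python
# Sketch only: `result` is assumed to be the output list of session.run
# for a batch-format model.
num_predictions, pred_boxes, pred_scores, pred_classes = result
for i in range(pred_boxes.shape[0]):
    n = int(num_predictions[i])  # number of valid detections for image i
    boxes, scores, classes = pred_boxes[i, :n], pred_scores[i, :n], pred_classes[i, :n]
    print(f"image {i}: kept {n} of {pred_boxes.shape[1]} rows")
```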


@@ -295,7 +304,7 @@ show_predictions_from_batch_format(image, result)



![png](models_export_files/models_export_18_0.png)
![png](models_export_files/models_export_19_0.png)



@@ -411,7 +420,7 @@ show_predictions_from_flat_format(image, result)



![png](models_export_files/models_export_24_0.png)
![png](models_export_files/models_export_25_0.png)



@@ -447,7 +456,7 @@ show_predictions_from_flat_format(image, result)



![png](models_export_files/models_export_26_0.png)
![png](models_export_files/models_export_27_0.png)



@@ -481,7 +490,7 @@ show_predictions_from_flat_format(image, result)



![png](models_export_files/models_export_28_0.png)
![png](models_export_files/models_export_29_0.png)



@@ -522,12 +531,12 @@ result = session.run(outputs, {inputs[0]: image_bchw})
show_predictions_from_flat_format(image, result)
```

25%|█████████████████████████████████████████████████ | 4/16 [00:11<00:34, 2.90s/it]
25%|██████████████████████████████ | 4/16 [00:11<00:33, 2.79s/it]




![png](models_export_files/models_export_30_1.png)
![png](models_export_files/models_export_31_1.png)



207 changes: 152 additions & 55 deletions src/super_gradients/conversion/onnx/nms.py

Large diffs are not rendered by default.

159 changes: 109 additions & 50 deletions src/super_gradients/conversion/onnx/pose_nms.py

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions src/super_gradients/conversion/onnx/utils.py
@@ -49,13 +49,13 @@ def append_graphs(graph1: gs.Graph, graph2: gs.Graph, prefix: str = "graph2_") -
merged_graph.outputs.clear()
merged_graph.outputs = graph2.outputs

merged_graph.toposort()
# iteratively_infer_shapes(merged_graph)
merged_graph = merged_graph.fold_constants().toposort().cleanup()
merged_graph = iteratively_infer_shapes(merged_graph)

return merged_graph


def iteratively_infer_shapes(graph: gs.Graph) -> None:
def iteratively_infer_shapes(graph: gs.Graph) -> gs.Graph:
"""
Sanitize the graph by cleaning any unconnected nodes, doing a topological resort,
and folding constant input values. When possible, run shape inference on the
@@ -65,7 +65,7 @@ for _ in range(3):
for _ in range(3):
count_before = len(graph.nodes)

graph.cleanup().toposort()
graph = graph.cleanup().toposort()
try:
# for node in graph.nodes:
# for o in node.outputs:
@@ -76,7 +76,7 @@ except Exception as e:
except Exception as e:
logger.debug(f"Shape inference could not be performed at this time:\n{e}")
try:
graph.fold_constants(fold_shapes=True)
graph = graph.fold_constants(fold_shapes=True)
except TypeError as e:
logger.error("This version of ONNX GraphSurgeon does not support folding shapes, " f"please upgrade your onnx_graphsurgeon module. Error:\n{e}")
raise
@@ -86,3 +86,5 @@
# No new folding occurred in this iteration, so we can stop for now.
break
logger.debug(f"Folded {count_before - count_after} constants.")

return graph
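
Since the function now returns gs.Graph instead of None, callers must rebind the result: cleanup() and fold_constants() return graph objects rather than reliably mutating in place. A minimal usage sketch under that assumption, with placeholder paths:

```python
import onnx
import onnx_graphsurgeon as gs  # the repo imports this through a helper; shown directly here

graph = gs.import_onnx(onnx.load("model.onnx"))  # placeholder path
graph = iteratively_infer_shapes(graph)          # rebind: the graph is now returned
onnx.save(gs.export_onnx(graph), "model_inferred.onnx")
```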
89 changes: 61 additions & 28 deletions src/super_gradients/conversion/tensorrt/nms.py
@@ -1,5 +1,6 @@
import os
import tempfile
from typing import Optional, Mapping

import numpy as np
import onnx
@@ -10,63 +11,82 @@
from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode
from super_gradients.conversion.conversion_utils import numpy_dtype_to_torch_dtype
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_fail_with_instructions
from super_gradients.conversion.onnx.utils import append_graphs
from super_gradients.conversion.onnx.utils import append_graphs, iteratively_infer_shapes

logger = get_logger(__name__)

gs = import_onnx_graphsurgeon_or_fail_with_instructions()


class ConvertTRTFormatToFlatTensor(nn.Module):
"""
Convert the predictions from EfficientNMS_TRT node format to flat tensor format.
This node is supported on TensorRT 8.5.3+
"""

__constants__ = ["batch_size", "max_predictions_per_image"]

def __init__(self, batch_size: int, max_predictions_per_image: int):
"""
Convert the predictions from TensorRT format to flat tensor format.
:param batch_size: A fixed batch size for the model
:param max_predictions_per_image: Maximum number of predictions per image
"""
super().__init__()
self.batch_size = batch_size
self.max_predictions_per_image = max_predictions_per_image

def forward(self, num_predictions: Tensor, pred_boxes: Tensor, pred_scores: Tensor, pred_classes: Tensor) -> Tensor:
"""
Convert the predictions from "batch" format to "flat" tensor.
Convert the predictions from "batch" format to "flat" format.
:param num_predictions: [B,1] The number of predictions for each image in the batch.
:param pred_boxes: [B, max_predictions_per_image, 4] The predicted bounding boxes for each image in the batch.
:param pred_scores: [B, max_predictions_per_image] The predicted scores for each image in the batch.
:param pred_classes: [B, max_predictions_per_image] The predicted classes for each image in the batch.
:return: Tensor of shape [N, 7] The predictions in flat tensor format.
N is the total number of predictions in the entire batch.
Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.
:param pred_boxes: [B, max_predictions_per_image, 4] The predicted bounding boxes for each image in the batch.
:param pred_scores: [B, max_predictions_per_image] The predicted scores for each image in the batch.
:param pred_classes: [B, max_predictions_per_image] The predicted classes for each image in the batch.
:return: Tensor of shape [N, 7] The predictions in flat tensor format.
N is the total number of predictions in the entire batch.
Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.
"""
batch_indexes = (
torch.arange(start=0, end=self.batch_size, step=1, device=num_predictions.device).view(-1, 1).repeat(1, pred_scores.shape[1])
) # [B, max_predictions_per_image]

preds_indexes = (
torch.arange(start=0, end=self.max_predictions_per_image, step=1, device=num_predictions.device).view(1, -1, 1).repeat(self.batch_size, 1, 1)
) # [B, max_predictions_per_image, 1]
preds_indexes = torch.arange(start=0, end=self.max_predictions_per_image, step=1, device=num_predictions.device).view(
1, -1
) # [1, max_predictions_per_image]

flat_predictions = torch.cat(
[
preds_indexes.to(dtype=pred_scores.dtype),
batch_indexes.unsqueeze(-1).to(dtype=pred_scores.dtype),
pred_boxes,
pred_scores.unsqueeze(dim=-1),
pred_classes.unsqueeze(dim=-1).to(pred_scores.dtype),
],
dim=-1,
) # [B, max_predictions_per_image, 8]

num_predictions = num_predictions.repeat(1, self.max_predictions_per_image) # [B, max_predictions_per_image]
) # [B, max_predictions_per_image, 7]

mask = (flat_predictions[:, :, 0] < num_predictions) & (flat_predictions[:, :, 1] == batch_indexes) # [B, max_predictions_per_image]
mask: Tensor = preds_indexes < num_predictions.view((self.batch_size, 1)) # [B, max_predictions_per_image]

# Flatten before boolean indexing to keep this operation TensorRT-compatible
mask = mask.view(-1)
flat_predictions = flat_predictions.view(self.max_predictions_per_image * self.batch_size, 7)
flat_predictions = flat_predictions[mask] # [N, 7]
return flat_predictions[:, 1:]

return flat_predictions
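
An illustrative check of the rewritten conversion with toy values (assumed here, not taken from the repo's tests): with batch_size=2 and max_predictions_per_image=3, only the first num_predictions[i] rows of each image survive, each row holding [image_index, x1, y1, x2, y2, confidence, class]:

```python
import torch

module = ConvertTRTFormatToFlatTensor(batch_size=2, max_predictions_per_image=3)
num_predictions = torch.tensor([[2], [1]], dtype=torch.int32)  # per-image detection counts
pred_boxes = torch.rand(2, 3, 4)
pred_scores = torch.rand(2, 3)
pred_classes = torch.randint(0, 80, (2, 3))

flat = module(num_predictions, pred_boxes, pred_scores, pred_classes)
print(flat.shape)  # torch.Size([3, 7]): 2 rows for image 0 + 1 row for image 1
```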

@classmethod
def as_graph(cls, batch_size: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device) -> gs.Graph:
def as_graph(
cls, batch_size: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device, onnx_export_kwargs: Optional[Mapping] = None
) -> gs.Graph:
if onnx_export_kwargs is None:
onnx_export_kwargs = {}

with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file = os.path.join(tmpdirname, "ConvertTRTFormatToFlatTensor.onnx")
onnx_file = os.path.join(tmpdirname, "ConvertTRTFormatToFlatTensorTMP.onnx")

num_detections = torch.randint(1, max_predictions_per_image, (batch_size, 1), dtype=torch.int32, device=device)
pred_boxes = torch.zeros((batch_size, max_predictions_per_image, 4), dtype=dtype, device=device)
@@ -80,9 +100,12 @@ def as_graph(cls, batch_size: int, max_predictions_per_image: int, dtype: torch.
input_names=["num_detections", "pred_boxes", "pred_scores", "pred_classes"],
output_names=["flat_predictions"],
dynamic_axes={"flat_predictions": {0: "num_predictions"}},
**onnx_export_kwargs,
)

convert_format_graph = gs.import_onnx(onnx.load(onnx_file))
convert_format_graph = convert_format_graph.fold_constants().cleanup().toposort()
convert_format_graph = iteratively_infer_shapes(convert_format_graph)
return convert_format_graph


@@ -96,17 +119,23 @@ def attach_tensorrt_nms(
batch_size: int,
output_predictions_format: DetectionOutputFormatMode,
device: torch.device,
onnx_export_kwargs: Optional[Mapping] = None,
):
"""
Attach TensorRT NMS plugin to the ONNX model
:param onnx_model_path:
:param output_onnx_model_path:
:param max_predictions_per_image: Maximum number of predictions per image
:param precision:
:param batch_size:
:return:
:param onnx_model_path: Path to the original model in ONNX format to attach the NMS plugin to.
:param output_onnx_model_path: Path to save the new ONNX model with the NMS plugin attached.
:param num_pre_nms_predictions: Number of predictions that goes into NMS.
:param max_predictions_per_image: Maximum number of predictions per image (after NMS).
:param batch_size: Batch size of the model.
:param confidence_threshold: Confidence threshold for NMS step.
:param nms_threshold: NMS IoU threshold.
:param output_predictions_format: Output predictions format.
:param device: Device to run the model on.
:param onnx_export_kwargs: Additional kwargs to pass to torch.onnx.export
"""

graph = gs.import_onnx(onnx.load(onnx_model_path))
# graph.fold_constants()

@@ -159,7 +188,11 @@

if output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT:
convert_format_graph = ConvertTRTFormatToFlatTensor.as_graph(
batch_size=batch_size, max_predictions_per_image=max_predictions_per_image, dtype=numpy_dtype_to_torch_dtype(pred_boxes.dtype), device=device
batch_size=batch_size,
max_predictions_per_image=max_predictions_per_image,
dtype=numpy_dtype_to_torch_dtype(pred_boxes.dtype),
device=device,
onnx_export_kwargs=onnx_export_kwargs,
)
graph = append_graphs(graph, convert_format_graph)
elif output_predictions_format == DetectionOutputFormatMode.BATCH_FORMAT:
@@ -168,8 +201,8 @@
raise NotImplementedError(f"Currently not supports output_predictions_format: {output_predictions_format}")

# Final cleanup & save
graph.cleanup().toposort()
# iteratively_infer_shapes(graph)
graph = graph.cleanup().toposort()
graph = iteratively_infer_shapes(graph)

logger.debug(f"Final graph outputs: {graph.outputs}")

@@ -49,7 +49,7 @@
"execution_count": null,
"outputs": [],
"source": [
"!pip install super_gradients==3.2.1"
"!pip install super_gradients==3.3.1"
],
"metadata": {
"collapsed": false
3 changes: 2 additions & 1 deletion src/super_gradients/module_interfaces/exportable_detector.py
@@ -509,7 +509,7 @@ def export(
if output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT:
logger.warning(
"Support of flat predictions format in TensorRT is experimental and may not work on all versions of TensorRT. "
"We recommend using TensorRT 8.4.1 or newer. On older versions this format will not work. "
"We recommend using TensorRT 8.5.3 or newer. On older versions of TensorRT this format will not work. "
"If you encountering issues loading exported model in TensorRT, please try upgrading TensorRT to latest version. "
"Alternatively, you can export the model to output predictions in batch format by "
"specifying output_predictions_format=DetectionOutputFormatMode.BATCH_FORMAT. "
@@ -529,6 +529,7 @@
batch_size=batch_size,
output_predictions_format=output_predictions_format,
device=device,
onnx_export_kwargs=onnx_export_kwargs,
)

if onnx_simplify:
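
For context, a hedged end-to-end sketch of the export path this change serves; the model name and keyword values are illustrative assumptions based on the surrounding diff, with batch_size > 1 being exactly the case this commit fixes:

```python
from super_gradients.common.object_names import Models
from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode, ExportTargetBackend
from super_gradients.training import models

model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
model.export(
    "yolo_nas_s_bs2.onnx",
    engine=ExportTargetBackend.TENSORRT,  # assumed enum location/name
    batch_size=2,                         # bs > 1: the case fixed by this commit
    output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
    max_predictions_per_image=100,
)
```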
@@ -69,13 +69,14 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]) -> T
pred_bboxes, pred_scores = inputs[0]

nms_top_k = self.num_pre_nms_predictions
batch_size, num_anchors, _ = pred_scores.size()

pred_cls_conf, _ = torch.max(pred_scores, dim=2)
topk_candidates = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True)

offsets = nms_top_k * torch.arange(pred_cls_conf.size(0), device=pred_cls_conf.device)
flat_indices = topk_candidates.indices + offsets.reshape(pred_cls_conf.size(0), 1)
flat_indices = torch.flatten(flat_indices)
offsets = num_anchors * torch.arange(batch_size, device=pred_cls_conf.device)
indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1)
flat_indices = torch.flatten(indices_with_offset)

output_pred_bboxes = pred_bboxes.reshape(-1, pred_bboxes.size(2))[flat_indices, :].reshape(pred_bboxes.size(0), nms_top_k, pred_bboxes.size(2))
output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2))
@@ -736,13 +736,15 @@ def forward(self, predictions):
if self.with_confidence:
pred_scores = pred_scores * conf

pred_cls_conf, _ = torch.max(pred_scores, dim=2)
nms_top_k = self.num_pre_nms_predictions
batch_size, num_anchors, _ = pred_scores.size()

pred_cls_conf, _ = torch.max(pred_scores, dim=2)
topk_candidates = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True)

offsets = nms_top_k * torch.arange(pred_cls_conf.size(0), device=pred_cls_conf.device)
flat_indices = topk_candidates.indices + offsets.reshape(pred_cls_conf.size(0), 1)
flat_indices = torch.flatten(flat_indices)
offsets = num_anchors * torch.arange(batch_size, device=pred_cls_conf.device)
indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1)
flat_indices = torch.flatten(indices_with_offset)

output_pred_bboxes = pred_bboxes.reshape(-1, pred_bboxes.size(2))[flat_indices, :].reshape(pred_bboxes.size(0), nms_top_k, pred_bboxes.size(2))
output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2))