Bugfix of model.export() to work correctly with bs>1 #1551

Merged · 18 commits · Oct 23, 2023

Changes from 13 commits
207 changes: 152 additions & 55 deletions src/super_gradients/conversion/onnx/nms.py

Large diffs are not rendered by default.

159 changes: 109 additions & 50 deletions src/super_gradients/conversion/onnx/pose_nms.py

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions src/super_gradients/conversion/onnx/utils.py
@@ -49,13 +49,13 @@ def append_graphs(graph1: gs.Graph, graph2: gs.Graph, prefix: str = "graph2_") -
merged_graph.outputs.clear()
merged_graph.outputs = graph2.outputs

merged_graph.toposort()
# iteratively_infer_shapes(merged_graph)
merged_graph = merged_graph.fold_constants().toposort().cleanup()
merged_graph = iteratively_infer_shapes(merged_graph)

return merged_graph


def iteratively_infer_shapes(graph: gs.Graph) -> None:
def iteratively_infer_shapes(graph: gs.Graph) -> gs.Graph:
"""
Sanitize the graph by cleaning any unconnected nodes, do a topological resort,
and fold constant inputs values. When possible, run shape inference on the
@@ -65,7 +65,7 @@ def iteratively_infer_shapes(graph: gs.Graph) -> None:
for _ in range(3):
count_before = len(graph.nodes)

graph.cleanup().toposort()
graph = graph.cleanup().toposort()
try:
# for node in graph.nodes:
# for o in node.outputs:
@@ -76,7 +76,7 @@ def iteratively_infer_shapes(graph: gs.Graph) -> None:
except Exception as e:
logger.debug(f"Shape inference could not be performed at this time:\n{e}")
try:
graph.fold_constants(fold_shapes=True)
graph = graph.fold_constants(fold_shapes=True)
except TypeError as e:
logger.error("This version of ONNX GraphSurgeon does not support folding shapes, " f"please upgrade your onnx_graphsurgeon module. Error:\n{e}")
raise
@@ -86,3 +86,5 @@ def iteratively_infer_shapes(graph: gs.Graph) -> None:
# No new folding occurred in this iteration, so we can stop for now.
break
logger.debug(f"Folded {count_before - count_after} constants.")

return graph
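Since `iteratively_infer_shapes` now returns the processed graph instead of returning None, call sites must capture the result. A minimal sketch of the intended call pattern, assuming the module path from this diff (the model paths are placeholders):

import onnx
import onnx_graphsurgeon as gs

from super_gradients.conversion.onnx.utils import iteratively_infer_shapes

# Load an ONNX model and wrap it in a GraphSurgeon graph.
graph = gs.import_onnx(onnx.load("model.onnx"))  # placeholder path

# Capture the returned, sanitized graph; the old call style that relied
# purely on in-place mutation is gone after this PR.
graph = iteratively_infer_shapes(graph)

onnx.save(gs.export_onnx(graph), "model_sanitized.onnx")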
89 changes: 61 additions & 28 deletions src/super_gradients/conversion/tensorrt/nms.py
@@ -1,5 +1,6 @@
import os
import tempfile
from typing import Optional, Mapping

import numpy as np
import onnx
@@ -10,63 +11,82 @@
from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode
from super_gradients.conversion.conversion_utils import numpy_dtype_to_torch_dtype
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_fail_with_instructions
from super_gradients.conversion.onnx.utils import append_graphs
from super_gradients.conversion.onnx.utils import append_graphs, iteratively_infer_shapes

logger = get_logger(__name__)

gs = import_onnx_graphsurgeon_or_fail_with_instructions()


class ConvertTRTFormatToFlatTensor(nn.Module):
"""
Convert the predictions from EfficientNMS_TRT node format to flat tensor format.
This node is supported on TensorRT 8.5.3+
"""

__constants__ = ["batch_size", "max_predictions_per_image"]

def __init__(self, batch_size: int, max_predictions_per_image: int):
"""
Convert the predictions from TensorRT format to flat tensor format.
:param batch_size: A fixed batch size for the model
:param max_predictions_per_image: Maximum number of predictions per image
"""
super().__init__()
self.batch_size = batch_size
self.max_predictions_per_image = max_predictions_per_image

def forward(self, num_predictions: Tensor, pred_boxes: Tensor, pred_scores: Tensor, pred_classes: Tensor) -> Tensor:
"""
Convert the predictions from "batch" format to "flat" tensor.
Convert the predictions from "batch" format to "flat" format.
:param num_predictions: [B,1] The number of predictions for each image in the batch.
:param pred_boxes: [B, max_predictions_per_image, 4] The predicted bounding boxes for each image in the batch.
:param pred_scores: [B, max_predictions_per_image] The predicted scores for each image in the batch.
:param pred_classes: [B, max_predictions_per_image] The predicted classes for each image in the batch.
:return: Tensor of shape [N, 7] The predictions in flat tensor format.
N is the total number of predictions in the entire batch.
Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.
:param pred_boxes: [B, max_predictions_per_image, 4] The predicted bounding boxes for each image in the batch.
:param pred_scores: [B, max_predictions_per_image] The predicted scores for each image in the batch.
:param pred_classes: [B, max_predictions_per_image] The predicted classes for each image in the batch.
:return: Tensor of shape [N, 7] The predictions in flat tensor format.
N is the total number of predictions in the entire batch.
Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.
"""
batch_indexes = (
torch.arange(start=0, end=self.batch_size, step=1, device=num_predictions.device).view(-1, 1).repeat(1, pred_scores.shape[1])
) # [B, max_predictions_per_image]

preds_indexes = (
torch.arange(start=0, end=self.max_predictions_per_image, step=1, device=num_predictions.device).view(1, -1, 1).repeat(self.batch_size, 1, 1)
) # [B, max_predictions_per_image, 1]
preds_indexes = torch.arange(start=0, end=self.max_predictions_per_image, step=1, device=num_predictions.device).view(
1, -1
) # [1, max_predictions_per_image]

flat_predictions = torch.cat(
[
preds_indexes.to(dtype=pred_scores.dtype),
batch_indexes.unsqueeze(-1).to(dtype=pred_scores.dtype),
pred_boxes,
pred_scores.unsqueeze(dim=-1),
pred_classes.unsqueeze(dim=-1).to(pred_scores.dtype),
],
dim=-1,
) # [B, max_predictions_per_image, 8]

num_predictions = num_predictions.repeat(1, self.max_predictions_per_image) # [B, max_predictions_per_image]
) # [B, max_predictions_per_image, 7]

mask = (flat_predictions[:, :, 0] < num_predictions) & (flat_predictions[:, :, 1] == batch_indexes) # [B, max_predictions_per_image]
mask: Tensor = preds_indexes < num_predictions.view((self.batch_size, 1)) # [B, max_predictions_per_image]

# Flatten before masking; 1D boolean indexing keeps the exported graph compatible
mask = mask.view(-1)
flat_predictions = flat_predictions.view(self.max_predictions_per_image * self.batch_size, 7)
flat_predictions = flat_predictions[mask] # [N, 7]
return flat_predictions[:, 1:]

return flat_predictions

@classmethod
def as_graph(cls, batch_size: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device) -> gs.Graph:
def as_graph(
cls, batch_size: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device, onnx_export_kwargs: Optional[Mapping] = None
) -> gs.Graph:
if onnx_export_kwargs is None:
onnx_export_kwargs = {}

with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file = os.path.join(tmpdirname, "ConvertTRTFormatToFlatTensor.onnx")
onnx_file = os.path.join(tmpdirname, "ConvertTRTFormatToFlatTensorTMP.onnx")

num_detections = torch.randint(1, max_predictions_per_image, (batch_size, 1), dtype=torch.int32, device=device)
pred_boxes = torch.zeros((batch_size, max_predictions_per_image, 4), dtype=dtype, device=device)
@@ -80,9 +100,12 @@ def as_graph(cls, batch_size: int, max_predictions_per_image: int, dtype: torch.
input_names=["num_detections", "pred_boxes", "pred_scores", "pred_classes"],
output_names=["flat_predictions"],
dynamic_axes={"flat_predictions": {0: "num_predictions"}},
**onnx_export_kwargs,
)

convert_format_graph = gs.import_onnx(onnx.load(onnx_file))
convert_format_graph = convert_format_graph.fold_constants().cleanup().toposort()
convert_format_graph = iteratively_infer_shapes(convert_format_graph)
return convert_format_graph
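As a sanity check of the flat format, here is a small hedged sketch running the module above on dummy CPU tensors (sizes are illustrative only; the import path follows the file header of this diff):

import torch

from super_gradients.conversion.tensorrt.nms import ConvertTRTFormatToFlatTensor

batch_size, max_preds = 2, 4
module = ConvertTRTFormatToFlatTensor(batch_size=batch_size, max_predictions_per_image=max_preds)

# Image 0 keeps 2 predictions, image 1 keeps 1 prediction.
num_predictions = torch.tensor([[2], [1]], dtype=torch.int32)
pred_boxes = torch.rand(batch_size, max_preds, 4)
pred_scores = torch.rand(batch_size, max_preds)
pred_classes = torch.randint(0, 80, (batch_size, max_preds))

flat = module(num_predictions, pred_boxes, pred_scores, pred_classes)
print(flat.shape)  # torch.Size([3, 7]): one row per kept box,
                   # [image_index, x1, y1, x2, y2, confidence, class]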


@@ -96,17 +119,23 @@ def attach_tensorrt_nms(
batch_size: int,
output_predictions_format: DetectionOutputFormatMode,
device: torch.device,
onnx_export_kwargs: Optional[Mapping] = None,
):
"""
Attach TensorRT NMS plugin to the ONNX model
:param onnx_model_path:
:param output_onnx_model_path:
:param max_predictions_per_image: Maximum number of predictions per image
:param precision:
:param batch_size:
:return:
:param onnx_model_path: Path to the original model in ONNX format to attach the NMS plugin to.
:param output_onnx_model_path: Path to save the new ONNX model with the NMS plugin attached.
:param num_pre_nms_predictions: Number of predictions that go into NMS.
:param max_predictions_per_image: Maximum number of predictions per image (after NMS).
:param batch_size: Batch size of the model.
:param confidence_threshold: Confidence threshold for NMS step.
:param nms_threshold: NMS IoU threshold.
:param output_predictions_format: Output predictions format.
:param device: Device to run the model on.
:param onnx_export_kwargs: Additional kwargs to pass to torch.onnx.export
"""

graph = gs.import_onnx(onnx.load(onnx_model_path))
# graph.fold_constants()

@@ -159,7 +188,11 @@ def attach_tensorrt_nms(

if output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT:
convert_format_graph = ConvertTRTFormatToFlatTensor.as_graph(
batch_size=batch_size, max_predictions_per_image=max_predictions_per_image, dtype=numpy_dtype_to_torch_dtype(pred_boxes.dtype), device=device
batch_size=batch_size,
max_predictions_per_image=max_predictions_per_image,
dtype=numpy_dtype_to_torch_dtype(pred_boxes.dtype),
device=device,
onnx_export_kwargs=onnx_export_kwargs,
)
graph = append_graphs(graph, convert_format_graph)
elif output_predictions_format == DetectionOutputFormatMode.BATCH_FORMAT:
@@ -168,8 +201,8 @@
raise NotImplementedError(f"Currently not supports output_predictions_format: {output_predictions_format}")

# Final cleanup & save
graph.cleanup().toposort()
# iteratively_infer_shapes(graph)
graph = graph.cleanup().toposort()
graph = iteratively_infer_shapes(graph)

logger.debug(f"Final graph outputs: {graph.outputs}")

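For reference, a hedged sketch of invoking `attach_tensorrt_nms` end to end. Keyword names follow the docstring above; paths, thresholds, and the opset value are placeholders:

import torch

from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode
from super_gradients.conversion.tensorrt.nms import attach_tensorrt_nms

attach_tensorrt_nms(
    onnx_model_path="detector_bs2.onnx",              # placeholder input model
    output_onnx_model_path="detector_bs2_nms.onnx",   # placeholder output path
    num_pre_nms_predictions=1000,
    max_predictions_per_image=300,
    batch_size=2,
    confidence_threshold=0.25,
    nms_threshold=0.7,
    output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
    device=torch.device("cuda"),
    onnx_export_kwargs={"opset_version": 14},  # forwarded to torch.onnx.export
)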
Expand Up @@ -509,7 +509,7 @@ def export(
if output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT:
logger.warning(
"Support of flat predictions format in TensorRT is experimental and may not work on all versions of TensorRT. "
"We recommend using TensorRT 8.4.1 or newer. On older versions this format will not work. "
"We recommend using TensorRT 8.5.3 or newer. On older versions of TensorRT this format will not work. "
"If you encountering issues loading exported model in TensorRT, please try upgrading TensorRT to latest version. "
"Alternatively, you can export the model to output predictions in batch format by "
"specifying output_predictions_format=DetectionOutputFormatMode.BATCH_FORMAT. "
@@ -529,6 +529,7 @@
batch_size=batch_size,
output_predictions_format=output_predictions_format,
device=device,
onnx_export_kwargs=onnx_export_kwargs,
)

if onnx_simplify:
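The user-facing repro for this PR is exporting with bs>1. A hedged sketch, assuming the documented export API of super-gradients (`models.get`, `Models.YOLO_NAS_S`, and `ExportTargetBackend` come from the library docs, not this diff):

from super_gradients.common.object_names import Models
from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode, ExportTargetBackend
from super_gradients.training import models

model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")

# batch_size > 1 is the case fixed by this PR; flat format on the TensorRT
# engine reportedly needs TensorRT 8.5.3+ (see the warning above).
export_result = model.export(
    "yolo_nas_s_bs2.onnx",
    batch_size=2,
    engine=ExportTargetBackend.TENSORRT,
    output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
)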
@@ -69,13 +69,14 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]) -> T
pred_bboxes, pred_scores = inputs[0]

nms_top_k = self.num_pre_nms_predictions
batch_size, num_anchors, _ = pred_scores.size()

pred_cls_conf, _ = torch.max(pred_scores, dim=2)
topk_candidates = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True)

offsets = nms_top_k * torch.arange(pred_cls_conf.size(0), device=pred_cls_conf.device)
flat_indices = topk_candidates.indices + offsets.reshape(pred_cls_conf.size(0), 1)
flat_indices = torch.flatten(flat_indices)
offsets = num_anchors * torch.arange(batch_size, device=pred_cls_conf.device)
indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1)
flat_indices = torch.flatten(indices_with_offset)

output_pred_bboxes = pred_bboxes.reshape(-1, pred_bboxes.size(2))[flat_indices, :].reshape(pred_bboxes.size(0), nms_top_k, pred_bboxes.size(2))
output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2))
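The root cause of the bs>1 bug is the offsets line above: indices into the flattened [B * Anchors, ...] tensor were offset by nms_top_k instead of num_anchors, so every image after the first gathered from the wrong rows. A tiny self-contained sketch (hypothetical sizes) demonstrating the difference:

import torch

batch_size, num_anchors, nms_top_k = 2, 6, 3
pred_cls_conf = torch.rand(batch_size, num_anchors)  # [B, Anchors]
topk = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True)

bad_offsets = nms_top_k * torch.arange(batch_size)     # buggy: steps of 3
good_offsets = num_anchors * torch.arange(batch_size)  # fixed: steps of 6, one full row per image

flat_conf = pred_cls_conf.reshape(-1)  # [B * Anchors]
bad = flat_conf[(topk.indices + bad_offsets.view(-1, 1)).flatten()]
good = flat_conf[(topk.indices + good_offsets.view(-1, 1)).flatten()]

assert torch.equal(good.view(batch_size, nms_top_k), topk.values)  # correct gather
# `bad` matches only for image 0 (offset 0); image 1 reads a mix of both images' rows.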
@@ -736,13 +736,15 @@ def forward(self, predictions):
if self.with_confidence:
pred_scores = pred_scores * conf

pred_cls_conf, _ = torch.max(pred_scores, dim=2)
nms_top_k = self.num_pre_nms_predictions
batch_size, num_anchors, _ = pred_scores.size()

pred_cls_conf, _ = torch.max(pred_scores, dim=2)
topk_candidates = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True)

offsets = nms_top_k * torch.arange(pred_cls_conf.size(0), device=pred_cls_conf.device)
flat_indices = topk_candidates.indices + offsets.reshape(pred_cls_conf.size(0), 1)
flat_indices = torch.flatten(flat_indices)
offsets = num_anchors * torch.arange(batch_size, device=pred_cls_conf.device)
indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1)
flat_indices = torch.flatten(indices_with_offset)

output_pred_bboxes = pred_bboxes.reshape(-1, pred_bboxes.size(2))[flat_indices, :].reshape(pred_bboxes.size(0), nms_top_k, pred_bboxes.size(2))
output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2))
@@ -52,13 +52,14 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]):
pred_bboxes, pred_scores = inputs[0]

nms_top_k = self.num_pre_nms_predictions
batch_size, num_anchors, _ = pred_scores.size()

pred_cls_conf, _ = torch.max(pred_scores, dim=2)
pred_cls_conf, _ = torch.max(pred_scores, dim=2) # [B, Anchors]
topk_candidates = torch.topk(pred_cls_conf, dim=1, k=nms_top_k, largest=True, sorted=True)

offsets = nms_top_k * torch.arange(pred_cls_conf.size(0), device=pred_cls_conf.device)
flat_indices = topk_candidates.indices + offsets.reshape(pred_cls_conf.size(0), 1)
flat_indices = torch.flatten(flat_indices)
offsets = num_anchors * torch.arange(batch_size, device=pred_cls_conf.device)
indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1)
flat_indices = torch.flatten(indices_with_offset)

output_pred_bboxes = pred_bboxes.reshape(-1, pred_bboxes.size(2))[flat_indices, :].reshape(pred_bboxes.size(0), nms_top_k, pred_bboxes.size(2))
output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2))
@@ -67,6 +68,24 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]):


class YoloNAS(ExportableObjectDetectionModel, CustomizableDetector):
"""

Export to ONNX/TRT Support matrix
ONNX files generated with PyTorch 2.0.1 for ONNX opset_version=14

| Batch Size | Export Engine | Format | OnnxRuntime 1.13.1 | TensorRT 8.4.2 | TensorRT 8.5.3 | TensorRT 8.6.1 |
|------------|---------------|--------|--------------------|----------------|----------------|----------------|
| 1 | ONNX | Flat | Yes | Yes | Yes | Yes |
| >1 | ONNX | Flat | Yes | No | No | No |
| 1 | ONNX | Batch | Yes | No | Yes | Yes |
| >1 | ONNX | Batch | Yes | No | No | Yes |
| 1 | TensorRT | Flat | No | No | Yes | Yes |
| >1 | TensorRT | Flat | No | No | Yes | Yes |
| 1 | TensorRT | Batch | No | Yes | Yes | Yes |
| >1 | TensorRT | Batch | No | Yes | Yes | Yes |

"""

def __init__(
self,
backbone: Union[str, dict, HpmStruct, DictConfig],
@@ -67,12 +67,13 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]):
pred_bboxes_xyxy, pred_bboxes_conf, pred_pose_coords, pred_pose_scores = inputs[0]

nms_top_k = self.num_pre_nms_predictions
batch_size, num_anchors, _ = pred_bboxes_conf.size()

topk_candidates = torch.topk(pred_bboxes_conf, dim=1, k=nms_top_k, largest=True, sorted=True)

offsets = nms_top_k * torch.arange(pred_bboxes_conf.size(0), device=pred_bboxes_conf.device)
flat_indices = topk_candidates.indices + offsets.reshape(pred_bboxes_conf.size(0), 1, 1)
flat_indices = torch.flatten(flat_indices)
offsets = num_anchors * torch.arange(batch_size, device=pred_bboxes_conf.device)
indices_with_offset = topk_candidates.indices + offsets.reshape(batch_size, 1, 1)
flat_indices = torch.flatten(indices_with_offset)

pred_poses_and_scores = torch.cat([pred_pose_coords, pred_pose_scores.unsqueeze(3)], dim=3)

@@ -90,6 +91,21 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]):


class YoloNASPose(CustomizableDetector, ExportablePoseEstimationModel):
"""
YoloNASPose model

Exported model support matrix

| Batch Size | Format | OnnxRuntime 1.13.1 | TensorRT 8.4.2 | TensorRT 8.5.3 | TensorRT 8.6.1 |
|------------|--------|--------------------|----------------|----------------|----------------|
| 1 | Flat | Yes | Yes | Yes | Yes |
| >1 | Flat | Yes | Yes | Yes | Yes |
| 1 | Batch | Yes | No | No | Yes |
| >1 | Batch | Yes | No | No | Yes |

ONNX files generated with PyTorch 2.0.1 for ONNX opset_version=14
"""

def __init__(
self,
backbone: Union[str, dict, HpmStruct, DictConfig],