New Export API #1318

Merged 57 commits into master from feature/SG-1001 on Aug 8, 2023

Commits
0414394
Designing export API
BloodAxe Jul 14, 2023
054152d
Export WIP
BloodAxe Jul 17, 2023
1872443
ONNX NMS
BloodAxe Jul 17, 2023
4a76131
Export WIP
BloodAxe Jul 18, 2023
fffcdc5
Refactor test and move benchmark API to function
BloodAxe Jul 19, 2023
55afd70
Export WIP
BloodAxe Jul 19, 2023
ed8730b
Make the top_k a constant and not variable since TRT export does not …
BloodAxe Jul 19, 2023
cdc41d7
Refactor test and move benchmark API to function
BloodAxe Jul 19, 2023
72e7d66
Added option to change the output format
BloodAxe Jul 20, 2023
015357d
Refactor test and move benchmark API to function
BloodAxe Jul 20, 2023
794fc73
Added option to change the output format
BloodAxe Jul 21, 2023
da99bfa
Refactor test and move benchmark API to function
BloodAxe Jul 21, 2023
a18a19d
Merge branch 'master' into feature/SG-1001
BloodAxe Jul 21, 2023
ca6ab9b
Fixing export to make it TRT friendly
BloodAxe Jul 21, 2023
c7f123b
Fixing export to make it TRT friendly
BloodAxe Jul 21, 2023
f08ca47
Fixing export to make it TRT friendly
BloodAxe Jul 24, 2023
8c188ca
Fixing export to make it TRT friendly
BloodAxe Jul 24, 2023
f7e6e93
Remove unused classes
BloodAxe Jul 24, 2023
56daabd
Remove unused classes
BloodAxe Jul 24, 2023
3ad8636
Remove unused classes
BloodAxe Jul 24, 2023
3717316
Remove unused classes
BloodAxe Jul 24, 2023
4ac7912
Fixing export to FP16
BloodAxe Jul 28, 2023
fa6167d
Fixing export to FP16
BloodAxe Jul 28, 2023
fdd6582
Improve output of the benchmark result
BloodAxe Jul 28, 2023
ff1fcff
Improve device handling when exporting NMS
BloodAxe Jul 28, 2023
16521a9
Improve device handling when exporting NMS
BloodAxe Jul 28, 2023
46e5a6e
Fix nms format conversion modules export
BloodAxe Jul 28, 2023
6691b03
Revert unit test
BloodAxe Jul 28, 2023
75bcd46
Improve model device handling
BloodAxe Jul 28, 2023
66ceaa1
Adding docs
BloodAxe Jul 31, 2023
5b57459
Adding docs
BloodAxe Jul 31, 2023
acf2684
Adding docs
BloodAxe Jul 31, 2023
7e1c417
Merge branch 'master' into feature/SG-1001
BloodAxe Jul 31, 2023
e2683e3
Adding docs
BloodAxe Aug 1, 2023
435c927
Merge remote-tracking branch 'origin/feature/SG-1001' into feature/SG…
BloodAxe Aug 1, 2023
e2d4fec
Address TODO's after code review
BloodAxe Aug 1, 2023
57890f3
Added check whether model is already quantized
BloodAxe Aug 2, 2023
efcf93f
Merge branch 'master' into feature/SG-1001
BloodAxe Aug 3, 2023
0746cf8
Install pytorch quantization package
BloodAxe Aug 3, 2023
912b0d8
Merge remote-tracking branch 'origin/feature/SG-1001' into feature/SG…
BloodAxe Aug 3, 2023
2f0ef9b
Added printing of user-friendly description on how to use the exported…
BloodAxe Aug 3, 2023
418ab1b
Update docs
BloodAxe Aug 4, 2023
948f9f6
Update docs
BloodAxe Aug 7, 2023
10b4fd3
Uninstall SG
BloodAxe Aug 7, 2023
9ed8511
Added onnx_graphsurgeon
BloodAxe Aug 7, 2023
2c4434d
Added onnx_graphsurgeon
BloodAxe Aug 7, 2023
c7cba74
Put extra index url at the top
BloodAxe Aug 7, 2023
9dddff5
Put extra index url before the package that requires it
BloodAxe Aug 7, 2023
c23042e
Fix --index-url to --extra-index-url
BloodAxe Aug 7, 2023
da95d55
get_requirements to handle --extra-index-url correctly
BloodAxe Aug 7, 2023
5acb9a9
Made method draw_box_title public
BloodAxe Aug 7, 2023
10add2a
Merge branch 'master' into feature/SG-1001
BloodAxe Aug 7, 2023
761653e
Merge branch 'master' into feature/SG-1001
BloodAxe Aug 8, 2023
c9a1d79
Fix tests
BloodAxe Aug 8, 2023
79f2e75
Fix missing HasPredict for BaseClassifier model
BloodAxe Aug 8, 2023
88ce9a7
Merge branch 'master' into feature/SG-1001
BloodAxe Aug 8, 2023
70cb07e
Make quantization parameters overridable
BloodAxe Aug 8, 2023
4 changes: 4 additions & 0 deletions Makefile
@@ -14,3 +14,7 @@ recipe_accuracy_tests:
python src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
python src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py


examples_to_docs:
jupyter nbconvert --to markdown --output-dir="documentation/source/" --execute src/super_gradients/examples/model_export/models_export.ipynb
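
Presumably this new examples_to_docs target is what regenerates the committed documentation/source/models_export.md (added below) from the example notebook: running make examples_to_docs after editing the notebook refreshes the rendered docs.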
900 changes: 900 additions & 0 deletions documentation/source/models_export.md

Large diffs are not rendered by default.

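Since the 900-line documentation file is not rendered inline, here is a minimal sketch of the workflow it documents, pieced together from the rest of this diff. The exact export() parameter names are assumptions inferred from conversion_enums.py and attach_onnx_nms() below, not a verbatim quote of the docs.

from super_gradients.common.object_names import Models
from super_gradients.conversion.conversion_enums import (
    DetectionOutputFormatMode,
    ExportQuantizationMode,
)
from super_gradients.training import models

# Parameter names below are assumed from this PR's enums and NMS code.
model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
export_result = model.export(
    "yolo_nas_s.onnx",
    confidence_threshold=0.25,
    nms_threshold=0.7,
    quantization_mode=ExportQuantizationMode.FP16,
    output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
)
print(export_result)  # the PR adds a user-friendly description of how to run the exported model
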
24 changes: 24 additions & 0 deletions src/super_gradients/conversion/conversion_enums.py
@@ -0,0 +1,24 @@
from enum import Enum

__all__ = ["DetectionOutputFormatMode", "ExportQuantizationMode", "ExportTargetBackend"]


class ExportTargetBackend(str, Enum):
"""Enum for specifying target backend for exporting a model."""

ONNXRUNTIME = "onnxruntime"
TENSORRT = "tensorrt"


class DetectionOutputFormatMode(str, Enum):
"""Enum for specifying output format for the detection model when postprocessing & NMS is enabled."""

FLAT_FORMAT = "flat"
BATCH_FORMAT = "batch"


class ExportQuantizationMode(str, Enum):
"""Enum for specifying quantization mode."""

FP16 = "fp16"
INT8 = "int8"
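
Because these enums derive from str, each member compares equal to its literal value, so callers can pass either the enum or the raw string. A small illustration (not part of the diff):

from super_gradients.conversion.conversion_enums import (
    DetectionOutputFormatMode,
    ExportQuantizationMode,
    ExportTargetBackend,
)

# str-derived enums compare equal to their plain string values,
# so "tensorrt" and ExportTargetBackend.TENSORRT are interchangeable.
assert ExportTargetBackend.TENSORRT == "tensorrt"
assert DetectionOutputFormatMode.FLAT_FORMAT == "flat"
assert ExportQuantizationMode.INT8 == "int8"
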
30 changes: 30 additions & 0 deletions src/super_gradients/conversion/conversion_utils.py
@@ -0,0 +1,30 @@
import numpy as np
import torch

__all__ = ["torch_dtype_to_numpy_dtype", "numpy_dtype_to_torch_dtype"]

_DTYPE_CORRESPONDENCE = [
(torch.float32, np.float32),
(torch.float64, np.float64),
(torch.float16, np.float16),
(torch.int32, np.int32),
(torch.int64, np.int64),
(torch.int16, np.int16),
(torch.int8, np.int8),
(torch.uint8, np.uint8),
    (torch.bool, np.bool_),  # np.bool was removed in NumPy 1.24; np.bool_ is the proper type
]


def torch_dtype_to_numpy_dtype(dtype: torch.dtype):
for torch_dtype, numpy_dtype in _DTYPE_CORRESPONDENCE:
if dtype == torch_dtype:
return numpy_dtype
raise NotImplementedError(f"Unsupported dtype: {dtype}")


def numpy_dtype_to_torch_dtype(dtype: np.dtype):
for torch_dtype, numpy_dtype in _DTYPE_CORRESPONDENCE:
if dtype == numpy_dtype:
return torch_dtype
raise NotImplementedError(f"Unsupported dtype: {dtype}")
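
A quick round-trip example of these helpers (not part of the diff); dtypes outside the correspondence table raise NotImplementedError rather than silently guessing:

import numpy as np
import torch

from super_gradients.conversion.conversion_utils import (
    numpy_dtype_to_torch_dtype,
    torch_dtype_to_numpy_dtype,
)

assert torch_dtype_to_numpy_dtype(torch.float16) is np.float16
assert numpy_dtype_to_torch_dtype(np.int64) is torch.int64

try:
    torch_dtype_to_numpy_dtype(torch.complex64)  # not in the correspondence table
except NotImplementedError as e:
    print(e)  # Unsupported dtype: torch.complex64
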
301 changes: 301 additions & 0 deletions src/super_gradients/conversion/onnx/nms.py
@@ -0,0 +1,301 @@
import os
import tempfile
from typing import Tuple

import numpy as np
import onnx
import onnx.shape_inference
import onnx_graphsurgeon as gs
import torch
from onnx import TensorProto
from torch import nn, Tensor

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode
from super_gradients.conversion.conversion_utils import numpy_dtype_to_torch_dtype
from super_gradients.conversion.onnx.utils import append_graphs

logger = get_logger(__name__)


class PickNMSPredictionsAndReturnAsBatchedResult(nn.Module):
__constants__ = ("batch_size", "max_predictions_per_image")

def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int):
super().__init__()
self.batch_size = batch_size
self.num_pre_nms_predictions = num_pre_nms_predictions
self.max_predictions_per_image = max_predictions_per_image

def forward(self, pred_boxes: Tensor, pred_scores: Tensor, selected_indexes: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Select the predictions that are output by the NMS plugin.
:param pred_boxes: [B, N, 4] tensor, float32
:param pred_scores: [B, N, C] tensor, float32
:param selected_indexes: [num_selected_indices, 3], int64 - each row is [batch_indexes, label_indexes, boxes_indexes]
:return:

"""
batch_indexes, label_indexes, boxes_indexes = selected_indexes[:, 0], selected_indexes[:, 1], selected_indexes[:, 2]

selected_boxes = pred_boxes[batch_indexes, boxes_indexes]
selected_scores = pred_scores[batch_indexes, boxes_indexes, label_indexes]

predictions = torch.cat([batch_indexes.unsqueeze(1), selected_boxes, selected_scores.unsqueeze(1), label_indexes.unsqueeze(1)], dim=1)

predictions = torch.nn.functional.pad(
predictions, (0, 0, 0, self.max_predictions_per_image * self.batch_size - predictions.size(0)), value=-1, mode="constant"
)

batch_predictions = torch.zeros((self.batch_size, self.max_predictions_per_image, 6), dtype=predictions.dtype, device=predictions.device)

batch_indexes = torch.arange(start=0, end=self.batch_size, step=1, device=predictions.device, dtype=predictions.dtype)
masks = batch_indexes.view(-1, 1).eq(predictions[:, 0].view(1, -1)) # [B, N]

num_predictions = torch.sum(masks, dim=1).long()

for i in range(self.batch_size):
selected_predictions = predictions[masks[i]]
selected_predictions = selected_predictions[:, 1:]
batch_predictions[i] = torch.nn.functional.pad(
selected_predictions, (0, 0, 0, self.max_predictions_per_image - selected_predictions.size(0)), value=0, mode="constant"
)

pred_boxes = batch_predictions[:, :, 0:4]
pred_scores = batch_predictions[:, :, 4]
pred_classes = batch_predictions[:, :, 5].long()

return num_predictions.unsqueeze(1), pred_boxes, pred_scores, pred_classes

@classmethod
    def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device) -> gs.Graph:
with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file = os.path.join(tmpdirname, "PickNMSPredictionsAndReturnAsBatchedResult.onnx")
pred_boxes = torch.zeros((batch_size, num_pre_nms_predictions, 4), dtype=dtype, device=device)
pred_scores = torch.zeros((batch_size, num_pre_nms_predictions, 3), dtype=dtype, device=device)
selected_indexes = torch.zeros((max_predictions_per_image, 3), dtype=torch.int64, device=device)

torch.onnx.export(
PickNMSPredictionsAndReturnAsBatchedResult(
batch_size=batch_size, num_pre_nms_predictions=num_pre_nms_predictions, max_predictions_per_image=max_predictions_per_image
).to(device=device, dtype=dtype),
args=(pred_boxes, pred_scores, selected_indexes),
f=onnx_file,
input_names=["raw_boxes", "raw_scores", "selected_indexes"],
output_names=["num_predictions", "pred_boxes", "pred_scores", "pred_classes"],
dynamic_axes={
"raw_boxes": {
# 0: "batch_size",
# 1: "num_anchors"
},
"raw_scores": {
# 0: "batch_size",
# 1: "num_anchors",
2: "num_classes",
},
"selected_indexes": {0: "num_predictions"},
},
)

convert_format_graph = gs.import_onnx(onnx.load(onnx_file))
return convert_format_graph


class PickNMSPredictionsAndReturnAsFlatResult(nn.Module):
__constants__ = ("batch_size", "num_pre_nms_predictions", "max_predictions_per_image")

def __init__(self, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int):
super().__init__()
self.batch_size = batch_size
self.num_pre_nms_predictions = num_pre_nms_predictions
self.max_predictions_per_image = max_predictions_per_image

def forward(self, pred_boxes: Tensor, pred_scores: Tensor, selected_indexes: Tensor):
"""
Select the predictions that are output by the NMS plugin.
:param pred_boxes: [B, N, 4] tensor
:param pred_scores: [B, N, C] tensor
:param selected_indexes: [num_selected_indices, 3] - each row is [batch_indexes, label_indexes, boxes_indexes]
:return: A single tensor of [Nout, 7] shape, where Nout is the total number of detections across all images in the batch.
Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.

"""
batch_indexes, label_indexes, boxes_indexes = selected_indexes[:, 0], selected_indexes[:, 1], selected_indexes[:, 2]

selected_boxes = pred_boxes[batch_indexes, boxes_indexes]
selected_scores = pred_scores[batch_indexes, boxes_indexes, label_indexes]

return torch.cat(
[
batch_indexes.unsqueeze(1).to(selected_boxes.dtype),
selected_boxes,
selected_scores.unsqueeze(1),
label_indexes.unsqueeze(1).to(selected_boxes.dtype),
],
dim=1,
)

@classmethod
def as_graph(cls, batch_size: int, num_pre_nms_predictions: int, max_predictions_per_image: int, dtype: torch.dtype, device: torch.device) -> gs.Graph:
with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file = os.path.join(tmpdirname, "PickNMSPredictionsAndReturnAsFlatTensor.onnx")
pred_boxes = torch.zeros((batch_size, num_pre_nms_predictions, 4), dtype=dtype, device=device)
pred_scores = torch.zeros((batch_size, num_pre_nms_predictions, 91), dtype=dtype, device=device)
selected_indexes = torch.zeros((max_predictions_per_image // 2, 3), dtype=torch.int64, device=device)

torch.onnx.export(
PickNMSPredictionsAndReturnAsFlatResult(
batch_size=batch_size, num_pre_nms_predictions=num_pre_nms_predictions, max_predictions_per_image=max_predictions_per_image
),
args=(pred_boxes, pred_scores, selected_indexes),
f=onnx_file,
input_names=["pred_boxes", "pred_scores", "selected_indexes"],
output_names=["flat_predictions"],
dynamic_axes={
"pred_boxes": {},
"pred_scores": {2: "num_classes"},
"selected_indexes": {0: "num_predictions"},
"flat_predictions": {0: "num_predictions"},
},
)

convert_format_graph = gs.import_onnx(onnx.load(onnx_file))
return convert_format_graph


def attach_onnx_nms(
onnx_model_path: str,
    output_onnx_model_path: str,
num_pre_nms_predictions: int,
max_predictions_per_image: int,
confidence_threshold: float,
nms_threshold: float,
batch_size: int,
output_predictions_format: DetectionOutputFormatMode,
device: torch.device,
):
"""
Attach ONNX NMS plugin to the detection model.
The model should have exactly two outputs: pred_boxes and pred_scores.
- pred_boxes: [batch_size, num_pre_nms_predictions, 4]
- pred_scores: [batch_size, num_pre_nms_predictions, num_classes]
    This function will add the NMS layer to the model and return predictions in the format defined by output_predictions_format.

:param onnx_model_path: Input ONNX model path
:param output_onnx_model_path: Output ONNX model path. Can be the same as input model path.
    :param num_pre_nms_predictions: Number of predictions the model outputs ahead of the NMS step (the second dimension of pred_boxes/pred_scores above).
:param batch_size: The batch size used for the inference.
:param max_predictions_per_image: Maximum number of predictions per image
:param confidence_threshold: The confidence threshold to use for detections.
:param nms_threshold: The NMS threshold to use for detections.
:param output_predictions_format: The output format of the predictions. Can be "flat" or "batch".

    If output_predictions_format equals "flat":
    A single tensor of [N, 7] will be returned, where N is the total number of detections across all images in the batch.
    Each row will contain [image_index, x1, y1, x2, y2, confidence, class_index].

    If output_predictions_format equals "batch":
    A tuple of 4 tensors (num_detections, detection_boxes, detection_scores, detection_classes) will be returned:
    - A tensor of [batch_size, 1] containing the number of valid detections in each image.
    - A tensor of [batch_size, max_output_boxes, 4] containing the bounding box coordinates for each detection in [x1, y1, x2, y2] format.
    - A tensor of [batch_size, max_output_boxes] containing the confidence scores for each detection.
    - A tensor of [batch_size, max_output_boxes] containing the class indices for each detection.

    :param device: Device used to trace the NMS output-conversion subgraphs (see the as_graph methods above).
    :return: None
"""
graph = gs.import_onnx(onnx.load(onnx_model_path))
graph.fold_constants()

pred_boxes, pred_scores = graph.outputs

graph_output_dtype = pred_scores.dtype

if graph_output_dtype == np.float16:
pred_scores_f32 = gs.Variable(
name="pred_scores_f32",
dtype=np.float32,
shape=pred_scores.shape,
)
pred_boxes_f32 = gs.Variable(
name="pred_boxes_f32",
dtype=np.float32,
shape=pred_boxes.shape,
)
graph.layer(op="Cast", name="cast_boxes_to_fp32", inputs=[pred_boxes], outputs=[pred_boxes_f32], attrs={"to": TensorProto.FLOAT})
graph.layer(op="Cast", name="cast_scores_to_fp32", inputs=[pred_scores], outputs=[pred_scores_f32], attrs={"to": TensorProto.FLOAT})

pred_scores = pred_scores_f32
pred_boxes = pred_boxes_f32
elif graph_output_dtype == np.float32:
pass
else:
raise ValueError(f"Invalid dtype: {graph_output_dtype}")

permute_scores = gs.Variable(
name="permuted_scores",
dtype=np.float32,
)
graph.layer(op="Transpose", name="permute_scores", inputs=[pred_scores], outputs=[permute_scores], attrs={"perm": [0, 2, 1]})

op_inputs = [pred_boxes, permute_scores] + [
gs.Constant(name="max_output_boxes_per_class", values=np.array([max_predictions_per_image], dtype=np.int64)),
gs.Constant(name="iou_threshold", values=np.array([nms_threshold], dtype=np.float32)),
gs.Constant(name="score_threshold", values=np.array([confidence_threshold], dtype=np.float32)),
]
logger.debug(f"op_inputs: {op_inputs}")

# NMS Outputs
# selected indices from the boxes tensor. [num_selected_indices, 3], the selected index format is [batch_index, class_index, box_index].
output_selected_indices = gs.Variable(
name="selected_indices",
dtype=np.int64,
shape=["num_selected_indices", 3],
    )

# Create the NMS Plugin node with the selected inputs. The outputs of the node will also
# become the final outputs of the graph.
graph.layer(
op="NonMaxSuppression",
name="batched_nms",
inputs=op_inputs,
outputs=[output_selected_indices],
attrs={
"center_point_box": 0,
},
)

graph.outputs = [pred_boxes, pred_scores, output_selected_indices]

if output_predictions_format == DetectionOutputFormatMode.BATCH_FORMAT:
convert_format_graph = PickNMSPredictionsAndReturnAsBatchedResult.as_graph(
batch_size=batch_size,
num_pre_nms_predictions=num_pre_nms_predictions,
max_predictions_per_image=max_predictions_per_image,
dtype=numpy_dtype_to_torch_dtype(np.float32),
device=device,
)
graph = append_graphs(graph, convert_format_graph)
elif output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT:
convert_format_graph = PickNMSPredictionsAndReturnAsFlatResult.as_graph(
batch_size=batch_size,
num_pre_nms_predictions=num_pre_nms_predictions,
max_predictions_per_image=max_predictions_per_image,
dtype=numpy_dtype_to_torch_dtype(np.float32),
device=device,
)
graph = append_graphs(graph, convert_format_graph)
else:
raise ValueError(f"Invalid output_predictions_format: {output_predictions_format}")

# Final cleanup & save
graph.cleanup().toposort()

# iteratively_infer_shapes(graph)

model = gs.export_onnx(graph)
    model = onnx.shape_inference.infer_shapes(model)  # infer_shapes returns a new ModelProto; keep the result
onnx.save(model, output_onnx_model_path)
logger.debug(f"Saved ONNX model to {output_onnx_model_path}")

# onnxsim.simplify(output_onnx_model_path)
# logger.debug(f"Ran onnxsim.simplify on {output_onnx_model_path}")
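
For completeness, a hedged sketch (not part of this PR) of consuming a model produced with output_predictions_format="flat", assuming the merged graph exposes the single flat_predictions output built by PickNMSPredictionsAndReturnAsFlatResult above; the input name and preprocessing are placeholders:

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("model_with_nms.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

# Dummy NCHW batch; real code would load, resize and normalize images.
batch = np.zeros((1, 3, 640, 640), dtype=np.float32)

(flat_predictions,) = session.run(None, {input_name: batch})

# Each row: [image_index, x1, y1, x2, y2, confidence, class_index]
for image_index, x1, y1, x2, y2, confidence, class_index in flat_predictions:
    print(int(image_index), (x1, y1, x2, y2), confidence, int(class_index))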