
Feature/sg 1106 Minor fixes in export api (#1412)
* Fix wrong type annotation for quantization_mode

* Fix initialization of example_input_image

* Added providers argument to stay compatible with newer onnxruntime versions that require providers to be present

* Fixed edge case when exporting detection models at small input sizes where num_pre_nms_predictions exceeds the total number of predictions

* Changed initialization of grid in yolox/yolonas/ppyoloe to ensure that constants in the ONNX file are fp32 and not fp64.
#1392
BloodAxe committed Aug 24, 2023
1 parent fc6ce60 commit 737c5a5
Showing 11 changed files with 133 additions and 43 deletions.
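For context, the two user-facing fixes reduce to the following sketch (`model.onnx` is an illustrative file name, not from the diff):

```python
import numpy as np
import onnxruntime

# Since onnxruntime 1.9, builds that ship more than one execution provider
# require an explicit providers list; omitting it raises a ValueError.
session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# np.zeros takes the shape as a single tuple; the old docs called
# np.zeros(1, 3, 640, 640), which raises a TypeError at runtime.
example_input_image = np.zeros((1, 3, 640, 640), dtype=np.uint8)
```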
20 changes: 10 additions & 10 deletions documentation/source/models_export.md
@@ -85,10 +85,10 @@ export_result

import onnxruntime
import numpy as np
-session = onnxruntime.InferenceSession("yolo_nas_s.onnx")
+session = onnxruntime.InferenceSession("yolo_nas_s.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
-example_input_image = np.zeros(1, 3, 640, 640).astype(np.uint8)
+example_input_image = np.zeros((1, 3, 640, 640)).astype(np.uint8)
predictions = session.run(outputs, {inputs[0]: example_input_image})

Exported model has predictions in batch format:
@@ -117,7 +117,7 @@ image = load_image("https://deci-pretrained-models.s3.amazonaws.com/sample_image
image = cv2.resize(image, (export_result.input_image_shape[1], export_result.input_image_shape[0]))
image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))

-session = onnxruntime.InferenceSession(export_result.output)
+session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
@@ -337,10 +337,10 @@ export_result

import onnxruntime
import numpy as np
-session = onnxruntime.InferenceSession("yolo_nas_s.onnx")
+session = onnxruntime.InferenceSession("yolo_nas_s.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
-example_input_image = np.zeros(1, 3, 640, 640).astype(np.uint8)
+example_input_image = np.zeros((1, 3, 640, 640)).astype(np.uint8)
predictions = session.run(outputs, {inputs[0]: example_input_image})

Exported model has predictions in flat format:
@@ -359,7 +359,7 @@ Now we exported a model that produces predictions in `flat` format. Let's run th


```python
-session = onnxruntime.InferenceSession(export_result.output)
+session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
@@ -437,7 +437,7 @@ export_result = model.export(
output_predictions_format = DetectionOutputFormatMode.FLAT_FORMAT
)

-session = onnxruntime.InferenceSession(export_result.output)
+session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
@@ -471,7 +471,7 @@ export_result = model.export(
quantization_mode=ExportQuantizationMode.INT8 # or ExportQuantizationMode.FP16
)

-session = onnxruntime.InferenceSession(export_result.output)
+session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
@@ -514,15 +514,15 @@ export_result = model.export(
calibration_loader=dummy_calibration_loader
)

-session = onnxruntime.InferenceSession(export_result.output)
+session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})

show_predictions_from_flat_format(image, result)
```

-25%|█████████████████████████████████████████████████          | 4/16 [00:11<00:34,  2.87s/it]
+25%|█████████████████████████████████████████████████          | 4/16 [00:11<00:34,  2.90s/it]



Binary file modified documentation/source/models_export_files/models_export_28_0.png
Binary file modified documentation/source/models_export_files/models_export_30_1.png
10 changes: 5 additions & 5 deletions src/super_gradients/examples/model_export/models_export.ipynb
@@ -158,7 +158,7 @@
"image = cv2.resize(image, (export_result.input_image_shape[1], export_result.input_image_shape[0]))\n",
"image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))\n",
"\n",
-"session = onnxruntime.InferenceSession(export_result.output)\n",
+"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
@@ -486,7 +486,7 @@
}
],
"source": [
-"session = onnxruntime.InferenceSession(export_result.output)\n",
+"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
@@ -605,7 +605,7 @@
" output_predictions_format = DetectionOutputFormatMode.FLAT_FORMAT\n",
")\n",
"\n",
-"session = onnxruntime.InferenceSession(export_result.output)\n",
+"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
@@ -659,7 +659,7 @@
" quantization_mode=ExportQuantizationMode.INT8 # or ExportQuantizationMode.FP16\n",
")\n",
"\n",
-"session = onnxruntime.InferenceSession(export_result.output)\n",
+"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
@@ -729,7 +729,7 @@
" calibration_loader=dummy_calibration_loader\n",
")\n",
"\n",
-"session = onnxruntime.InferenceSession(export_result.output)\n",
+"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
59 changes: 41 additions & 18 deletions src/super_gradients/module_interfaces/exportable_detector.py
@@ -16,8 +16,7 @@
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install
from super_gradients.training.utils.export_utils import infer_format_from_file_name, infer_image_shape_from_model, infer_image_input_channels
from super_gradients.training.utils.quantization.fix_pytorch_quantization_modules import patch_pytorch_quantization_modules_if_needed
-from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules
-
+from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules, infer_model_dtype

logger = get_logger(__name__)

@@ -58,6 +57,19 @@ def forward(self, predictions: Any) -> Tuple[Tensor, Tensor]:
"""
raise NotImplementedError(f"forward() method is not implemented for class {self.__class__.__name__}. ")

+@torch.jit.ignore
+def infer_total_number_of_predictions(self, predictions: Any) -> int:
+    """
+    This method is used to infer the total number of predictions for a given input resolution.
+    The function takes raw predictions from the model and returns the total number of predictions.
+    It is needed to check whether max_predictions_per_image and num_pre_nms_predictions are not greater than
+    the total number of predictions for a given resolution.
+    :param predictions: Predictions from the model itself.
+    :return: A total number of predictions for a given resolution
+    """
+    raise NotImplementedError(f"infer_total_number_of_predictions() method is not implemented for class {self.__class__.__name__}. ")

def get_output_names(self) -> List[str]:
"""
Returns the names of the outputs of the module.
@@ -130,7 +142,7 @@ def export(
confidence_threshold: Optional[float] = None,
nms_threshold: Optional[float] = None,
engine: Optional[ExportTargetBackend] = None,
-quantization_mode: ExportQuantizationMode = Optional[None],
+quantization_mode: Optional[ExportQuantizationMode] = None,
selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa
calibration_loader: Optional[DataLoader] = None,
calibration_method: str = "percentile",
@@ -344,6 +356,27 @@
num_pre_nms_predictions = postprocessing_module.num_pre_nms_predictions
max_predictions_per_image = max_predictions_per_image or num_pre_nms_predictions

+dummy_input = torch.randn(input_shape).to(device=infer_model_device(model), dtype=infer_model_dtype(model))
+with torch.no_grad():
+    number_of_predictions = postprocessing_module.infer_total_number_of_predictions(model.eval()(dummy_input))
+
+if num_pre_nms_predictions > number_of_predictions:
+    logger.warning(
+        f"num_pre_nms_predictions ({num_pre_nms_predictions}) is greater than the total number of predictions ({number_of_predictions}) for input "
+        f"shape {input_shape}. Setting num_pre_nms_predictions to {number_of_predictions}"
+    )
+    num_pre_nms_predictions = number_of_predictions
+    # We have to re-create the postprocessing_module with the new value of num_pre_nms_predictions
+    postprocessing_kwargs["num_pre_nms_predictions"] = num_pre_nms_predictions
+    postprocessing_module: AbstractObjectDetectionDecodingModule = model.get_decoding_module(**postprocessing_kwargs)
+
+if max_predictions_per_image > num_pre_nms_predictions:
+    logger.warning(
+        f"max_predictions_per_image ({max_predictions_per_image}) is greater than num_pre_nms_predictions ({num_pre_nms_predictions}). "
+        f"Setting max_predictions_per_image to {num_pre_nms_predictions}"
+    )
+    max_predictions_per_image = num_pre_nms_predictions

nms_threshold = nms_threshold or getattr(model, "_default_nms_iou", None)
if nms_threshold is None:
raise ValueError(
@@ -358,12 +391,6 @@
"Please specify the confidence_threshold explicitly: model.export(..., confidence_threshold=0.5)"
)

-if max_predictions_per_image > num_pre_nms_predictions:
-    raise ValueError(
-        f"max_predictions_per_image={max_predictions_per_image} is greater than "
-        f"num_pre_nms_predictions={num_pre_nms_predictions}. "
-        f"Please specify max_predictions_per_image <= {num_pre_nms_predictions}."
-    )
else:
attach_nms_postprocessing = False
postprocessing_module = None
@@ -542,19 +569,15 @@ def export(
usage_instructions.append("")
usage_instructions.append(" import onnxruntime")
usage_instructions.append(" import numpy as np")
-usage_instructions.append(f' session = onnxruntime.InferenceSession("{output}")')
+usage_instructions.append(f' session = onnxruntime.InferenceSession("{output}", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])')
usage_instructions.append(" inputs = [o.name for o in session.get_inputs()]")
usage_instructions.append(" outputs = [o.name for o in session.get_outputs()]")

dtype_name = np.dtype(torch_dtype_to_numpy_dtype(input_image_dtype)).name
-if preprocessing:
-    usage_instructions.append(
-        f" example_input_image = np.zeros({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}).astype(np.{dtype_name})"  # noqa
-    )
-else:
-    usage_instructions.append(
-        f" example_input_image = np.zeros({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}).astype(np.{dtype_name})"  # noqa
-    )
+usage_instructions.append(
+    f" example_input_image = np.zeros(({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]})).astype(np.{dtype_name})"  # noqa
+)

usage_instructions.append(" predictions = session.run(outputs, {inputs[0]: example_input_image})")
usage_instructions.append("")

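In isolation, the new validation amounts to clamping with a warning instead of raising (a condensed sketch with illustrative numbers; `clamp_nms_limits` is a hypothetical helper, not part of the diff):

```python
import logging

logger = logging.getLogger(__name__)


def clamp_nms_limits(num_pre_nms_predictions: int, max_predictions_per_image: int, number_of_predictions: int):
    """Mirror of the export-time checks above: clamp rather than raise."""
    if num_pre_nms_predictions > number_of_predictions:
        logger.warning("num_pre_nms_predictions (%d) > total predictions (%d); clamping", num_pre_nms_predictions, number_of_predictions)
        num_pre_nms_predictions = number_of_predictions
    if max_predictions_per_image > num_pre_nms_predictions:
        logger.warning("max_predictions_per_image (%d) > num_pre_nms_predictions (%d); clamping", max_predictions_per_image, num_pre_nms_predictions)
        max_predictions_per_image = num_pre_nms_predictions
    return num_pre_nms_predictions, max_predictions_per_image


# E.g. a 64x64 input yields only a few hundred candidate boxes (84 is illustrative):
print(clamp_nms_limits(2000, 1000, 84))  # -> (84, 84)
```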
@@ -1,5 +1,5 @@
from functools import lru_cache
-from typing import Union, Optional, List, Tuple
+from typing import Union, Optional, List, Tuple, Any

import torch
from torch import Tensor
@@ -82,6 +82,20 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]) -> T

return output_pred_bboxes, output_pred_scores

+@torch.jit.ignore
+def infer_total_number_of_predictions(self, predictions: Any) -> int:
+    """
+    :param predictions: Model predictions: a (pred_bboxes, pred_scores) tuple under tracing, or the raw model output otherwise.
+    :return: Total number of predictions (the anchor dimension of pred_bboxes).
+    """
+    if torch.jit.is_tracing():
+        pred_bboxes, pred_scores = predictions
+    else:
+        pred_bboxes, pred_scores = predictions[0]
+
+    return pred_bboxes.size(1)


class PPYoloE(SgModule, ExportableObjectDetectionModel, HasPredict):
def __init__(self, arch_params):
@@ -270,8 +270,8 @@ def _generate_anchors(self, feats=None, dtype=None, device=None):
else:
h = int(self.eval_size[0] / stride)
w = int(self.eval_size[1] / stride)
-shift_x = torch.arange(end=w) + self.grid_cell_offset
-shift_y = torch.arange(end=h) + self.grid_cell_offset
+shift_x = torch.arange(end=w, dtype=dtype) + self.grid_cell_offset
+shift_y = torch.arange(end=h, dtype=dtype) + self.grid_cell_offset
if torch_version_is_greater_or_equal(1, 10):
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
else:
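Why pinning the dtype at `arange` time matters, sketched standalone (the int64 default is standard PyTorch behavior; the fp64-constants claim comes from the commit message):

```python
import torch

# Without an explicit dtype, arange returns int64 and the float conversion
# happens later in the graph; creating the grid as fp32 up front keeps the
# constants baked into the traced ONNX graph at fp32.
print(torch.arange(end=4).dtype)                       # torch.int64
print((torch.arange(end=4) + 0.5).dtype)               # torch.float32 (scalar promotion)
print(torch.arange(end=4, dtype=torch.float32).dtype)  # torch.float32
```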
21 changes: 18 additions & 3 deletions src/super_gradients/training/models/detection_models/yolo_base.py
@@ -1,7 +1,7 @@
import collections
import math
import warnings
-from typing import Union, Type, List, Tuple, Optional
+from typing import Union, Type, List, Tuple, Optional, Any
from functools import lru_cache

import numpy as np
@@ -282,9 +282,9 @@ def forward(self, inputs):
def _make_grid(nx: int, ny: int, dtype: torch.dtype):
if torch_version_is_greater_or_equal(1, 10):
# https://github.com/pytorch/pytorch/issues/50276
-    yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij")
+    yv, xv = torch.meshgrid([torch.arange(ny, dtype=dtype), torch.arange(nx, dtype=dtype)], indexing="ij")
else:
-    yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+    yv, xv = torch.meshgrid([torch.arange(ny, dtype=dtype), torch.arange(nx, dtype=dtype)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).to(dtype)


@@ -748,3 +748,18 @@ def forward(self, predictions):
output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2))

return output_pred_bboxes, output_pred_scores

+def get_num_pre_nms_predictions(self) -> int:
+    return self.num_pre_nms_predictions

+@torch.jit.ignore
+def infer_total_number_of_predictions(self, predictions: Any) -> int:
+    """
+    :param predictions: Model predictions: either the raw output tensor or a (predictions, extras) tuple/list.
+    :return: Total number of predictions (the anchor dimension of the output).
+    """
+    if isinstance(predictions, (tuple, list)):
+        predictions = predictions[0]
+
+    return predictions.size(1)
@@ -281,14 +281,16 @@ def _generate_anchors(self, feats=None, dtype=None, device=None):
else:
h = int(self.eval_size[0] / stride)
w = int(self.eval_size[1] / stride)
-shift_x = torch.arange(end=w) + self.grid_cell_offset
-shift_y = torch.arange(end=h) + self.grid_cell_offset
+
+shift_x = torch.arange(end=w, dtype=dtype) + self.grid_cell_offset
+shift_y = torch.arange(end=h, dtype=dtype) + self.grid_cell_offset
+
if torch_version_is_greater_or_equal(1, 10):
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
else:
shift_y, shift_x = torch.meshgrid(shift_y, shift_x)

-anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype)
+anchor_point = torch.stack([shift_x, shift_y], dim=-1)
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype))
anchor_points = torch.cat(anchor_points)
@@ -1,5 +1,5 @@
import copy
-from typing import Union, Tuple, Optional
+from typing import Union, Tuple, Optional, Any

import torch
from omegaconf import DictConfig
@@ -28,6 +28,23 @@ def __init__(
super().__init__()
self.num_pre_nms_predictions = num_pre_nms_predictions

+@torch.jit.ignore
+def infer_total_number_of_predictions(self, predictions: Any) -> int:
+    """
+    :param predictions: Model predictions: a (pred_bboxes, pred_scores) tuple under tracing, or the raw model output otherwise.
+    :return: Total number of predictions (the anchor dimension of pred_bboxes).
+    """
+    if torch.jit.is_tracing():
+        pred_bboxes, pred_scores = predictions
+    else:
+        pred_bboxes, pred_scores = predictions[0]
+
+    return pred_bboxes.size(1)

+def get_num_pre_nms_predictions(self) -> int:
+    return self.num_pre_nms_predictions

def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]):
if torch.jit.is_tracing():
pred_bboxes, pred_scores = inputs
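Note the two code paths in `infer_total_number_of_predictions` above: under tracing the decoding module receives the `(pred_bboxes, pred_scores)` tuple directly, while in eager mode the raw model output wraps it one level deeper. A small sketch of that contract (hypothetical shapes):

```python
import torch

pred_bboxes = torch.zeros(1, 1000, 4)   # [batch, anchors, xyxy]
pred_scores = torch.zeros(1, 1000, 80)  # [batch, anchors, classes]

# Eager-mode output is ((pred_bboxes, pred_scores), extras), so unwrap once.
predictions = ((pred_bboxes, pred_scores), None)
boxes, _scores = predictions if torch.jit.is_tracing() else predictions[0]
assert boxes.size(1) == 1000  # the total number of predictions
```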
19 changes: 19 additions & 0 deletions tests/unit_tests/export_detection_model_test.py
@@ -38,6 +38,25 @@ def setUp(self) -> None:
this_dir = os.path.dirname(__file__)
self.test_image_path = os.path.join(this_dir, "../data/tinycoco/images/val2017/000000444010.jpg")

+def test_export_model_on_small_size(self):
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        for model_type in [
+            Models.YOLO_NAS_S,
+            Models.PP_YOLOE_S,
+            Models.YOLOX_S,
+        ]:
+            out_path = os.path.join(tmpdirname, model_type + ".onnx")
+            ppyolo_e: ExportableObjectDetectionModel = models.get(model_type, pretrained_weights="coco")
+            result = ppyolo_e.export(
+                out_path,
+                input_image_shape=(64, 64),
+                num_pre_nms_predictions=2000,
+                max_predictions_per_image=1000,
+                output_predictions_format=DetectionOutputFormatMode.BATCH_FORMAT,
+            )
+            assert result.input_image_dtype == torch.uint8
+            assert result.input_image_shape == (64, 64)

def test_the_most_common_export_use_case(self):
"""
Test the most common export use case - export to ONNX with all default parameters
Expand Down
