Skip to content

Commit

Permalink
Feature/sg 1386 granular control over export in ptq and qat (#1879)
Browse files Browse the repository at this point in the history
* (Work in progress) Adding a granular yet uniform control on exporting a model after PTQ/QAT

* Adding granular control on model.export during PTQ & QAT

* Intermediate fixes for ptq()

* Added additional properties to export params

* Update trainer

* Improve signature of ptq() and qat() methods

* Moving dataclasses to separate files

* Update notebook

* Move export params to super_gradients.conversion
  • Loading branch information
BloodAxe committed Mar 6, 2024
1 parent e7caf6c commit eaa0c21
Show file tree
Hide file tree
Showing 8 changed files with 729 additions and 3,261 deletions.
3,622 changes: 444 additions & 3,178 deletions notebooks/yolo_nas_custom_dataset_fine_tuning_with_qat.ipynb

Large diffs are not rendered by default.

17 changes: 13 additions & 4 deletions src/super_gradients/common/deprecate.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,12 @@ def deprecated_parameter(parameter_name: str, deprecated_since: str, removed_fro

def decorator(func: callable) -> callable:
argspec = inspect.getfullargspec(func)
argument_index = argspec.args.index(parameter_name)
# This check is necessary for methods with star-signature foo(*, a,b,c)
# For such methods argspec.args is actually empty and argspec.kwonlyargs contains the parameter names
if "parameter_name" in argspec.args:
argument_index = argspec.args.index(parameter_name)
else:
argument_index = None

default_value = None
sig = inspect.signature(func)
Expand All @@ -126,9 +131,13 @@ def wrapper(*args, **kwargs):

# Try to get the actual value from the arguments
# Have to check both positional and keyword arguments
try:
value = args[argument_index]
except IndexError:
if argument_index is not None:
try:
value = args[argument_index]
except IndexError:
if parameter_name in kwargs:
value = kwargs[parameter_name]
else:
if parameter_name in kwargs:
value = kwargs[parameter_name]

Expand Down
3 changes: 2 additions & 1 deletion src/super_gradients/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .conversion_enums import ExportTargetBackend, ExportQuantizationMode, DetectionOutputFormatMode
from .export_params import ExportParams

__all__ = ["ExportQuantizationMode", "DetectionOutputFormatMode", "ExportTargetBackend"]
__all__ = ["ExportQuantizationMode", "DetectionOutputFormatMode", "ExportTargetBackend", "ExportParams"]
74 changes: 74 additions & 0 deletions src/super_gradients/conversion/export_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import dataclasses
from typing import Optional, Tuple

from super_gradients.conversion.conversion_enums import ExportTargetBackend, DetectionOutputFormatMode


@dataclasses.dataclass
class ExportParams:
    """
    Parameters for exporting a model to ONNX format in PTQ/QAT methods of Trainer.
    Most of the parameters are related to ExportableObjectDetectionModel.export method.

    :param output_onnx_path: The path to save the ONNX model.
           If None, the ONNX filename will use current experiment dir folder
           and the output filename will reflect model input shape & whether it's a PTQ or QAT model.
    :param engine: (ExportTargetBackend) Target inference backend for the exported model.
           If None, the default backend for a given model will be used.
    :param batch_size: The batch size for the ONNX model. Default is 1.
    :param input_image_shape: The input image shape (rows, cols) for the ONNX model.
           If None, the input shape will be inferred from the model or validation dataloader.
    :param preprocessing: If True, the preprocessing will be included in the ONNX model.
           This option is only available for models that support model.export() syntax.
    :param postprocessing: If True, the postprocessing will be included in the ONNX model.
           This option is only available for models that support model.export() syntax.
    :param confidence_threshold: The confidence threshold for object detection models.
           This option is only available for models that support model.export() syntax.
           If None, the default confidence threshold for a given model will be used.
    :param onnx_export_kwargs: (dict) Optional keyword arguments for torch.onnx.export() function.
    :param onnx_simplify: (bool) If True, apply onnx-simplifier to the exported model.
    :param detection_nms_iou_threshold: (float) A IoU threshold for the NMS step.
           Relevant only for object detection models and only if postprocessing is True.
           Default to None, which means the default value for a given model will be used.
    :param detection_max_predictions_per_image: (int) Maximum number of predictions per image.
           Relevant only for object detection models and only if postprocessing is True.
    :param detection_num_pre_nms_predictions: (int) Number of predictions to keep before NMS.
           Relevant only for object detection models and only if postprocessing is True.
    :param detection_predictions_format: (DetectionOutputFormatMode) Format of the output predictions of detection models.
           Possible values:
           DetectionOutputFormatMode.BATCH_FORMAT - A tuple of 4 tensors will be returned
           (num_detections, detection_boxes, detection_scores, detection_classes)
           - A tensor of [batch_size, 1] containing the image indices for each detection.
           - A tensor of [batch_size, max_output_boxes, 4] containing the bounding box coordinates for each detection in [x1, y1, x2, y2] format.
           - A tensor of [batch_size, max_output_boxes] containing the confidence scores for each detection.
           - A tensor of [batch_size, max_output_boxes] containing the class indices for each detection.
           DetectionOutputFormatMode.FLAT_FORMAT - Tensor of shape [N, 7], where N is the total number of
           predictions in the entire batch.
           Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.
           Relevant only for object detection models and only if postprocessing is True.
    """

    output_onnx_path: Optional[str] = None
    engine: Optional[ExportTargetBackend] = None
    batch_size: int = 1
    input_image_shape: Optional[Tuple[int, int]] = None
    preprocessing: bool = True
    postprocessing: bool = True
    confidence_threshold: Optional[float] = None

    # ONNX-export-specific knobs (forwarded to torch.onnx.export / onnx-simplifier)
    onnx_export_kwargs: Optional[dict] = None
    onnx_simplify: bool = True

    # These are only relevant for object detection and pose estimation models
    detection_nms_iou_threshold: Optional[float] = None
    detection_max_predictions_per_image: Optional[int] = None
    detection_predictions_format: DetectionOutputFormatMode = DetectionOutputFormatMode.BATCH_FORMAT
    detection_num_pre_nms_predictions: int = 1000
6 changes: 5 additions & 1 deletion src/super_gradients/module_interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses, SupportsReplaceInputChannels, SupportsFineTune
from .exceptions import ModelHasNoPreprocessingParamsException
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, ObjectDetectionModelExportResult
from .exportable_pose_estimation import ExportablePoseEstimationModel, PoseEstimationModelExportResult, AbstractPoseEstimationDecodingModule
from .pose_estimation_post_prediction_callback import AbstractPoseEstimationPostPredictionCallback, PoseEstimationPredictions
from .supports_input_shape_check import SupportsInputShapeCheck
from .quantization_result import QuantizationResult


__all__ = [
"HasPredict",
Expand All @@ -20,4 +22,6 @@
"AbstractPoseEstimationDecodingModule",
"SupportsFineTune",
"SupportsInputShapeCheck",
"ObjectDetectionModelExportResult",
"QuantizationResult",
]
13 changes: 11 additions & 2 deletions src/super_gradients/module_interfaces/exportable_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@

logger = get_logger(__name__)

__all__ = ["ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule", "ModelExportResult", "ModelHasNoPreprocessingParamsException"]
__all__ = [
"ExportableObjectDetectionModel",
"AbstractObjectDetectionDecodingModule",
"ModelExportResult",
"ModelHasNoPreprocessingParamsException",
"ObjectDetectionModelExportResult",
]


class ModelHasNoPreprocessingParamsException(Exception):
Expand Down Expand Up @@ -92,7 +98,7 @@ def get_num_pre_nms_predictions(self) -> int:


@dataclasses.dataclass
class ModelExportResult:
class ObjectDetectionModelExportResult:
"""
A dataclass that holds the result of model export.
"""
Expand All @@ -113,6 +119,9 @@ def __repr__(self):
return self.usage_instructions


ModelExportResult = ObjectDetectionModelExportResult # Alias for backward compatibility


class ExportableObjectDetectionModel:
"""
A mixin class that adds export functionality to the object detection models.
Expand Down
16 changes: 16 additions & 0 deletions src/super_gradients/module_interfaces/quantization_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import dataclasses
from typing import Union, Dict

from torch import nn

from super_gradients.module_interfaces import PoseEstimationModelExportResult, ObjectDetectionModelExportResult

__all__ = ["QuantizationResult"]


@dataclasses.dataclass
class QuantizationResult:
    """
    Result of a PTQ/QAT run returned by Trainer quantization methods.

    Bundles the quantized model itself, the location of the exported ONNX file,
    the validation metrics computed for the quantized model, and (optionally)
    the detailed export result object produced by model.export().
    """

    # The quantized model instance (in-memory torch module).
    quantized_model: nn.Module
    # Path to the exported ONNX file for the quantized model.
    output_onnx_path: str
    # Validation metrics of the quantized model, keyed by metric name.
    valid_metrics_dict: Dict[str, float]
    # Export result from model.export(); None when export was skipped or the
    # model type has no dedicated export-result class.
    export_result: Union[None, ObjectDetectionModelExportResult, PoseEstimationModelExportResult] = None
Loading

0 comments on commit eaa0c21

Please sign in to comment.