Feature/sg 1386 granular control over export in ptq and qat #1879

Merged
3,622 changes: 444 additions & 3,178 deletions notebooks/yolo_nas_custom_dataset_fine_tuning_with_qat.ipynb

Large diffs are not rendered by default.

17 changes: 13 additions & 4 deletions src/super_gradients/common/deprecate.py
@@ -109,7 +109,12 @@ def deprecated_parameter(parameter_name: str, deprecated_since: str, removed_fro

def decorator(func: callable) -> callable:
argspec = inspect.getfullargspec(func)
argument_index = argspec.args.index(parameter_name)
# This check is necessary for methods with star-signature foo(*, a,b,c)
# For such methods argspec.args is actually empty and argspec.kwonlyargs contains the parameter names
if "parameter_name" in argspec.args:
argument_index = argspec.args.index(parameter_name)
else:
argument_index = None

default_value = None
sig = inspect.signature(func)
@@ -126,9 +131,13 @@ def wrapper(*args, **kwargs):

# Try to get the actual value from the arguments
# Have to check both positional and keyword arguments
try:
value = args[argument_index]
except IndexError:
if argument_index is not None:
try:
value = args[argument_index]
except IndexError:
if parameter_name in kwargs:
value = kwargs[parameter_name]
else:
if parameter_name in kwargs:
value = kwargs[parameter_name]

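For context, a minimal standard-library sketch of the case this hunk guards against: for a function declared with a star-signature, inspect.getfullargspec() reports its parameters under kwonlyargs rather than args, so the old unconditional argspec.args.index(parameter_name) call raised ValueError. The two functions below are illustrative only.

import inspect


def positional_style(model, batch_size=1):
    ...


def keyword_only_style(*, model, batch_size=1):
    ...


# Positional-or-keyword parameters show up in .args; keyword-only ones do not.
print(inspect.getfullargspec(positional_style).args)          # ['model', 'batch_size']
print(inspect.getfullargspec(keyword_only_style).args)        # []
print(inspect.getfullargspec(keyword_only_style).kwonlyargs)  # ['model', 'batch_size']

# Mirrors the patched logic: a keyword-only parameter has no positional index,
# so the wrapper must fall back to looking the value up in **kwargs.
parameter_name = "batch_size"
argspec = inspect.getfullargspec(keyword_only_style)
argument_index = argspec.args.index(parameter_name) if parameter_name in argspec.args else None
assert argument_index is None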
3 changes: 2 additions & 1 deletion src/super_gradients/conversion/__init__.py
@@ -1,3 +1,4 @@
from .conversion_enums import ExportTargetBackend, ExportQuantizationMode, DetectionOutputFormatMode
from .export_params import ExportParams

__all__ = ["ExportQuantizationMode", "DetectionOutputFormatMode", "ExportTargetBackend"]
__all__ = ["ExportQuantizationMode", "DetectionOutputFormatMode", "ExportTargetBackend", "ExportParams"]
74 changes: 74 additions & 0 deletions src/super_gradients/conversion/export_params.py
@@ -0,0 +1,74 @@
import dataclasses
from typing import Optional, Tuple

from super_gradients.conversion.conversion_enums import ExportTargetBackend, DetectionOutputFormatMode


@dataclasses.dataclass
class ExportParams:
"""
Parameters for exporting a model to ONNX format in PTQ/QAT methods of Trainer.
Most of the parameters are related to the ExportableObjectDetectionModel.export() method.

:param output_onnx_path: The path to save the ONNX model.
If None, the ONNX file will be saved in the current experiment directory,
and its filename will reflect the model input shape and whether it is a PTQ or QAT model.

:param batch_size: The batch size for the ONNX model. Default is 1.

:param input_image_shape: The input image shape (rows, cols) for the ONNX model.
If None, the input shape will be inferred from the model or validation dataloader.

:param preprocessing: If True, the preprocessing will be included in the ONNX model.
This option is only available for models that support model.export() syntax.

:param postprocessing: If True, the postprocessing will be included in the ONNX model.
This option is only available for models that support model.export() syntax.

:param confidence_threshold: The confidence threshold for object detection models.
This option is only available for models that support model.export() syntax.
If None, the default confidence threshold for a given model will be used.
:param onnx_export_kwargs: (dict) Optional keyword arguments for torch.onnx.export() function.
:param onnx_simplify: (bool) If True, apply onnx-simplifier to the exported model.

:param detection_nms_iou_threshold: (float) An IoU threshold for the NMS step.
Relevant only for object detection models and only if postprocessing is True.
Defaults to None, which means the default value for a given model will be used.

:param detection_max_predictions_per_image: (int) Maximum number of predictions per image.
Relevant only for object detection models and only if postprocessing is True.

:param detection_num_pre_nms_predictions: (int) Number of predictions to keep before NMS.
Relevant only for object detection models and only if postprocessing is True.

:param detection_predictions_format: (DetectionOutputFormatMode) Format of the output predictions of detection models.
Possible values:
DetectionOutputFormatMode.BATCH_FORMAT - A tuple of 4 tensors will be returned
(num_detections, detection_boxes, detection_scores, detection_classes)
- A tensor of [batch_size, 1] containing the number of valid detections for each image.
- A tensor of [batch_size, max_output_boxes, 4] containing the bounding box coordinates for each detection in [x1, y1, x2, y2] format.
- A tensor of [batch_size, max_output_boxes] containing the confidence scores for each detection.
- A tensor of [batch_size, max_output_boxes] containing the class indices for each detection.

DetectionOutputFormatMode.FLAT_FORMAT - Tensor of shape [N, 7], where N is the total number of
predictions in the entire batch.
Each row will contain [image_index, x1, y1, x2, y2, class confidence, class index] values.
Relevant only for object detection models and only if postprocessing is True.
"""

output_onnx_path: Optional[str] = None
engine: Optional[ExportTargetBackend] = None
batch_size: int = 1
input_image_shape: Optional[Tuple[int, int]] = None
preprocessing: bool = True
postprocessing: bool = True
confidence_threshold: Optional[float] = None

onnx_export_kwargs: Optional[dict] = None
onnx_simplify: bool = True

# These are only relevant for object detection and pose estimation models
detection_nms_iou_threshold: Optional[float] = None
detection_max_predictions_per_image: Optional[int] = None
detection_predictions_format: DetectionOutputFormatMode = DetectionOutputFormatMode.BATCH_FORMAT
detection_num_pre_nms_predictions: int = 1000
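An illustrative construction of the new dataclass, using only the fields defined above. The trainer call at the end is an assumption: the exact Trainer.ptq()/qat() signature that consumes these parameters is not part of this diff.

from super_gradients.conversion import DetectionOutputFormatMode, ExportParams

export_params = ExportParams(
    output_onnx_path=None,                  # None -> auto-named file in the experiment dir
    batch_size=1,
    input_image_shape=(640, 640),           # (rows, cols); inferred when left as None
    preprocessing=True,
    postprocessing=True,
    confidence_threshold=0.25,
    onnx_simplify=True,
    detection_nms_iou_threshold=0.65,
    detection_max_predictions_per_image=300,
    detection_num_pre_nms_predictions=1000,
    detection_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
)

# Assumed usage (the keyword name `export_params` is hypothetical here):
# trainer.ptq(model=model, valid_loader=valid_loader, export_params=export_params)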
6 changes: 5 additions & 1 deletion src/super_gradients/module_interfaces/__init__.py
@@ -1,9 +1,11 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses, SupportsReplaceInputChannels, SupportsFineTune
from .exceptions import ModelHasNoPreprocessingParamsException
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, ObjectDetectionModelExportResult
from .exportable_pose_estimation import ExportablePoseEstimationModel, PoseEstimationModelExportResult, AbstractPoseEstimationDecodingModule
from .pose_estimation_post_prediction_callback import AbstractPoseEstimationPostPredictionCallback, PoseEstimationPredictions
from .supports_input_shape_check import SupportsInputShapeCheck
from .quantization_result import QuantizationResult


__all__ = [
"HasPredict",
@@ -20,4 +22,6 @@
"AbstractPoseEstimationDecodingModule",
"SupportsFineTune",
"SupportsInputShapeCheck",
"ObjectDetectionModelExportResult",
"QuantizationResult",
]
13 changes: 11 additions & 2 deletions src/super_gradients/module_interfaces/exportable_detector.py
@@ -22,7 +22,13 @@

logger = get_logger(__name__)

__all__ = ["ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule", "ModelExportResult", "ModelHasNoPreprocessingParamsException"]
__all__ = [
"ExportableObjectDetectionModel",
"AbstractObjectDetectionDecodingModule",
"ModelExportResult",
"ModelHasNoPreprocessingParamsException",
"ObjectDetectionModelExportResult",
]


class ModelHasNoPreprocessingParamsException(Exception):
@@ -92,7 +98,7 @@ def get_num_pre_nms_predictions(self) -> int:


@dataclasses.dataclass
class ModelExportResult:
class ObjectDetectionModelExportResult:
"""
A dataclass that holds the result of model export.
"""
@@ -113,6 +119,9 @@ def __repr__(self):
return self.usage_instructions


ModelExportResult = ObjectDetectionModelExportResult # Alias for backward compatibility


class ExportableObjectDetectionModel:
"""
A mixin class that adds export functionality to the object detection models.
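A short sketch of what the backward-compatibility alias above provides: code that still imports ModelExportResult keeps working, and both names resolve to the same class object.

from super_gradients.module_interfaces.exportable_detector import (
    ModelExportResult,
    ObjectDetectionModelExportResult,
)

# The alias is the same object, so existing isinstance checks and type hints keep working.
assert ModelExportResult is ObjectDetectionModelExportResult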
16 changes: 16 additions & 0 deletions src/super_gradients/module_interfaces/quantization_result.py
@@ -0,0 +1,16 @@
import dataclasses
from typing import Union, Dict

from torch import nn

from super_gradients.module_interfaces import PoseEstimationModelExportResult, ObjectDetectionModelExportResult

__all__ = ["QuantizationResult"]


@dataclasses.dataclass
class QuantizationResult:
quantized_model: nn.Module
output_onnx_path: str
valid_metrics_dict: Dict[str, float]
export_result: Union[None, ObjectDetectionModelExportResult, PoseEstimationModelExportResult] = None
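A hypothetical consumer of the new dataclass, touching only the fields defined above; how the QuantizationResult instance is produced by the Trainer lives outside this file and is assumed here.

from super_gradients.module_interfaces import QuantizationResult


def report_quantization(result: QuantizationResult) -> None:
    """Print a short summary of a PTQ/QAT run (illustrative only)."""
    print(f"Quantized ONNX model written to: {result.output_onnx_path}")
    for metric_name, metric_value in result.valid_metrics_dict.items():
        print(f"  {metric_name}: {metric_value:.4f}")
    if result.export_result is not None:
        # The export result's __repr__ returns its usage instructions (see exportable_detector.py).
        print(result.export_result)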