
Feature/sg 1386 granular control over export in ptq and qat #1879

Merged
181 changes: 141 additions & 40 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -1,4 +1,5 @@
import copy
import dataclasses
import inspect
import os
import typing
@@ -104,6 +105,8 @@
from super_gradients.training.params import TrainingParams
from super_gradients.module_interfaces import ExportableObjectDetectionModel, SupportsInputShapeCheck
from super_gradients.conversion import ExportQuantizationMode
from super_gradients.common.deprecate import deprecated_parameter
from super_gradients.training.utils.export_utils import infer_image_shape_from_model, infer_image_input_channels

logger = get_logger(__name__)

@@ -121,6 +124,47 @@
_imported_pytorch_quantization_failure = import_err


@dataclasses.dataclass
class PTQResult:
quantized_model: nn.Module
output_onnx_path: str
valid_metrics_dict: Dict[str, float]


@dataclasses.dataclass
class QATResult:
quantized_model: nn.Module
output_onnx_path: str
valid_metrics_dict: Dict[str, float]


@dataclasses.dataclass
class ExportParams:
"""
Parameters for the export function.

:param output_onnx_path: The path to save the ONNX model.
If None, the ONNX file will be saved in the current experiment directory,
and its filename will reflect the model input shape and whether it is a PTQ or QAT model.

:param preprocessing: If True, the preprocessing will be included in the ONNX model.
This option is only available for models that support model.export() syntax.

:param postprocessing: If True, the postprocessing will be included in the ONNX model.
This option is only available for models that support model.export() syntax.

:param confidence_threshold: The confidence threshold for object detection models.
This option is only available for models that support model.export() syntax.
If None, the default confidence threshold for a given model will be used.

:param batch_size: The batch size for the exported ONNX model.

:param input_shape: The input image shape (rows, cols) to use for export.
If None, the shape is inferred from the model or, as a fallback, from the validation dataloader.
"""

output_onnx_path: Optional[str] = None
batch_size: int = 1
input_shape: Optional[Tuple[int, int]] = None
preprocessing: bool = True
postprocessing: bool = True
confidence_threshold: Optional[float] = None
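A minimal usage sketch (not part of this diff; the ONNX path below is hypothetical):

export_params = ExportParams(
    output_onnx_path="checkpoints/my_experiment/model_int8.onnx",  # hypothetical; None -> auto-named file in the experiment dir
    batch_size=1,
    preprocessing=True,  # include preprocessing in the ONNX graph (model.export()-capable models only)
    postprocessing=True,  # include postprocessing in the ONNX graph (model.export()-capable models only)
    confidence_threshold=0.25,  # detection models only; None keeps the model's default threshold
)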


class Trainer:
"""
SuperGradient Model - Base Class for Sg Models
@@ -2363,7 +2407,7 @@ def _init_loss_logging_names(self, loss_logging_items):
self.loss_logging_items_names = [criterion_name]

@classmethod
def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Union[PTQResult, QATResult]:
"""
Perform quantization aware training (QAT) according to a recipe configuration.

@@ -2378,7 +2422,8 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module,
a train data loader with the validation transforms is used for calibration.

:param cfg: The parsed DictConfig object from yaml recipe files or a dictionary.
:return: A tuple containing the quantized model and the output of trainer.train() method.
:return: An instance of PTQResult or QATResult that contains the quantized model instance, the output ONNX path
and other relevant information.

:raises ValueError: If the recipe does not have the required key `quantization_params` or
`checkpoint_params.checkpoint_path` in it.
@@ -2396,12 +2441,14 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module,
cfg = cls._trigger_cfg_modifying_callbacks(cfg)

quantization_params = get_param(cfg, "quantization_params")

if quantization_params is None:
logger.warning("Your recipe does not include quantization_params. Using default quantization params.")
quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
cfg.quantization_params = quantization_params

export_params = get_param(cfg, "export_params", {})
export_params = ExportParams(**export_params)

if get_param(cfg.checkpoint_params, "checkpoint_path") is None and get_param(cfg.checkpoint_params, "pretrained_weights") is None:
raise ValueError("Starting checkpoint / pretrained weights are a must for QAT finetuning.")

@@ -2474,6 +2521,7 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module,
quantization_params=quantization_params,
valid_loader=val_dataloader,
valid_metrics_list=cfg.training_hyperparams.valid_metrics_list,
export_params=export_params,
)
else:
res = trainer.qat(
@@ -2484,9 +2532,10 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module,
train_loader=train_dataloader,
training_params=cfg.training_hyperparams,
additional_qat_configs_to_log=recipe_logged_cfg,
export_params=export_params,
)

return model, res
return res
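A hedged sketch of driving this entry point from code rather than a recipe CLI (the recipe path and the export settings are assumptions, not part of this diff):

from omegaconf import OmegaConf
from super_gradients import Trainer

cfg = OmegaConf.load("recipes/my_quantization_recipe.yaml")  # hypothetical recipe with quantization_params, checkpoint_params, etc.
cfg.export_params = {"batch_size": 1, "confidence_threshold": 0.25}  # optional; default ExportParams are used when omitted
result = Trainer.quantize_from_config(cfg)  # PTQResult when quantization_params.ptq_only is set, otherwise QATResult
print(result.output_onnx_path, result.valid_metrics_dict)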

def qat(
self,
@@ -2498,7 +2547,8 @@ def qat(
quantization_params: Mapping = None,
additional_qat_configs_to_log: Dict = None,
valid_metrics_list: List[Metric] = None,
):
export_params: ExportParams = None,
) -> QATResult:
"""
Performs post-training quantization (PTQ), and then quantization-aware training (QAT).
Exports the ONNX models (ckpt_best.pth of QAT and the calibrated model) to the checkpoints directory.
@@ -2547,24 +2597,23 @@ def qat(
:param valid_metrics_list: (list(torchmetrics.Metric)) metrics list for evaluation of the calibrated model.
When None, the validation metrics from training_params are used (default=None).

:return: Validation results of the QAT model in case quantization_params['ptq_only']=False and of the PTQ
model otherwise.
:return: An instance of QATResult containing the quantized model, the ONNX path and other relevant information.
"""

if quantization_params is None:
quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
logger.info(f"Using default quantization params: {quantization_params}")
valid_metrics_list = valid_metrics_list or get_param(training_params, "valid_metrics_list")

_ = self.ptq(
ptq_result = self.ptq(
calib_loader=calib_loader,
model=model,
quantization_params=quantization_params,
valid_loader=valid_loader,
valid_metrics_list=valid_metrics_list,
deepcopy_model_for_export=True,
export_params=None, # Do not export PTQ model
)
# TRAIN
model = ptq_result.quantized_model
model.train()
torch.cuda.empty_cache()

@@ -2577,30 +2626,39 @@ def qat(
)

# EXPORT QUANTIZED MODEL TO ONNX
input_shape = next(iter(valid_loader))[0].shape
os.makedirs(self.checkpoints_dir_path, exist_ok=True)
qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_qat.onnx")

# TODO: modify SG's convert_to_onnx for quantized models and use it instead
export_quantized_module_to_onnx(
model=model.cpu(),
onnx_filename=qdq_onnx_path,
input_shape=input_shape,
input_size=input_shape,
train=False,
)
logger.info(f"Exported QAT ONNX to {qdq_onnx_path}")
return res
if export_params is not None:
input_shape_from_loader = tuple(map(int, next(iter(valid_loader))[0].shape))
input_shape_with_batch_size_one = (1,) + input_shape_from_loader[1:]

if export_params.output_onnx_path is None:
export_params.output_onnx_path = os.path.join(
self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape_with_batch_size_one))}_qat.onnx"
)
self._export_quantized_model(model, export_params, input_shape_from_loader)
output_onnx_path = export_params.output_onnx_path
logger.info(f"Exported QAT ONNX to {output_onnx_path}")
else:
output_onnx_path = None

return QATResult(quantized_model=model, output_onnx_path=output_onnx_path, valid_metrics_dict=res)
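A sketch of the QAT entry point under the new signature (assuming the model, loaders and training_params are already prepared; not part of this diff):

qat_result = trainer.qat(
    calib_loader=calib_loader,
    model=model,
    valid_loader=valid_loader,
    train_loader=train_loader,
    training_params=training_params,
    export_params=ExportParams(),  # output_onnx_path=None -> auto-named "<experiment>_<shape>_qat.onnx" in the checkpoints dir
)
print(qat_result.output_onnx_path, qat_result.valid_metrics_dict)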

@deprecated_parameter(
"deepcopy_model_for_export",
deprecated_since="3.6.1",
removed_from="3.8.0",
reason="This parameter is no longer used. A ptq() method will always make a deepcopy of the model.",
)
def ptq(
self,
calib_loader: DataLoader,
model: nn.Module,
valid_loader: DataLoader,
valid_metrics_list: List[torchmetrics.Metric],
quantization_params: Dict = None,
deepcopy_model_for_export: bool = False,
):
export_params: ExportParams = None,
deepcopy_model_for_export=None,
) -> PTQResult:
"""
Performs post-training quantization (calibration of the model).

@@ -2643,6 +2701,12 @@ def ptq(

:return: An instance of PTQResult containing the quantized model, the output ONNX path and the validation results of the calibrated model.
"""
if deepcopy_model_for_export is False:
raise RuntimeError(
"deepcopy_model_for_export=False is not supported. "
"deepcopy_model_for_export is always treated as True and the input model is no longer modified in-place. "
"If you need access to the quantized model object, use the `quantized_model` attribute of the value returned by the ptq() call."
)

logger.debug("Performing post-training quantization (PTQ)...")
logger.debug(f"Experiment name {self.experiment_name}")
@@ -2664,6 +2728,7 @@ def ptq(
logger.info(f"Using default quantization params: {quantization_params}")

model = unwrap_model(model) # Unwrap model in case it is wrapped with DataParallel or DistributedDataParallel
model = copy.deepcopy(model) # Deepcopy model to avoid modifying the original model
model = model.to(device_config.device).eval()

selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params")
@@ -2700,32 +2765,68 @@ def ptq(
results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
logger.info("\n".join(results))

input_shape = next(iter(valid_loader))[0].shape
input_shape_with_batch_size_one = tuple([1] + list(input_shape[1:]))
qdq_onnx_path = os.path.join(
self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape_with_batch_size_one))}_ptq.onnx"
if export_params is not None:
input_shape_from_loader = tuple(map(int, next(iter(valid_loader))[0].shape))
input_shape_with_batch_size_one = (1,) + input_shape_from_loader[1:]

if export_params.output_onnx_path is None:
export_params.output_onnx_path = os.path.join(
self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape_with_batch_size_one))}_ptq.onnx"
)
logger.debug(f"Output ONNX file path {export_params.output_onnx_path}")
self._export_quantized_model(model, export_params, input_shape_from_loader)
output_onnx_path = export_params.output_onnx_path
else:
output_onnx_path = None

return PTQResult(
quantized_model=model,
output_onnx_path=output_onnx_path,
valid_metrics_dict=valid_metrics_dict,
)
logger.debug(f"Output ONNX file path {qdq_onnx_path}")
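Correspondingly, a PTQ-only run might be consumed roughly like this (a sketch assuming the model, loaders and metrics list are already built; not part of this diff):

ptq_result = trainer.ptq(
    calib_loader=calib_loader,
    model=model,
    valid_loader=valid_loader,
    valid_metrics_list=valid_metrics_list,
    export_params=ExportParams(batch_size=1),  # pass export_params=None to skip the ONNX export entirely
)
quantized_model = ptq_result.quantized_model  # the input model is no longer modified in-place
print(ptq_result.output_onnx_path, ptq_result.valid_metrics_dict)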

@staticmethod
def _export_quantized_model(model: nn.Module, export_params: ExportParams, input_shape_from_dataloader: Tuple[int, int, int, int]):
"""
Internal method to export a quantized model to ONNX. It is used by both the PTQ and QAT steps.

:param model: Quantized model
:param export_params: Parameters controlling the export process.
:param input_shape_from_dataloader: Shape of an example batch taken from the validation DataLoader.
It may be used as a fallback input shape during ONNX export.
:return: None
"""
input_shape = export_params.input_shape
if input_shape is None:
input_shape = infer_image_shape_from_model(model)
if input_shape is None:
input_shape = input_shape_from_dataloader[2:]

input_channels = infer_image_input_channels(model)
if input_channels is not None and input_channels != input_shape_from_dataloader[1]:
logger.warning("Inferred number of input channels does not match the number of channels from the dataloader")
if input_channels is None:
input_channels = input_shape_from_dataloader[1]

input_shape_with_explicit_batch = tuple([export_params.batch_size, input_channels] + list(input_shape))

if isinstance(model, ExportableObjectDetectionModel):
model: ExportableObjectDetectionModel = typing.cast(ExportableObjectDetectionModel, model)
export_result = model.export(
output=qdq_onnx_path,
output=export_params.output_onnx_path,
quantization_mode=ExportQuantizationMode.INT8,
input_image_shape=(input_shape_with_batch_size_one[2], input_shape_with_batch_size_one[3]),
preprocessing=False,
postprocessing=True,
input_image_shape=input_shape,
preprocessing=export_params.preprocessing,
postprocessing=export_params.postprocessing,
confidence_threshold=export_params.confidence_threshold,
# TODO Add more parameters to the export_params
)
logger.info(repr(export_result))
else:
# TODO: modify SG's convert_to_onnx for quantized models and use it instead
export_quantized_module_to_onnx(
model=model.cpu(),
onnx_filename=qdq_onnx_path,
input_shape=input_shape_with_batch_size_one,
input_size=input_shape_with_batch_size_one,
onnx_filename=export_params.output_onnx_path,
input_shape=input_shape_with_explicit_batch,
input_size=input_shape_with_explicit_batch,
train=False,
deepcopy_model=deepcopy_model_for_export,
deepcopy_model=False,
)

return valid_metrics_dict