Feature/sg 1144 new export in quantization (#1511)
* PTQ with exportable detector

* PTQ with exportable detector

* Remove tmp recipes

* Revert SSD

* Change logging to DEBUG mode
BloodAxe committed Oct 9, 2023
1 parent 811c0f5 commit 3c342bb
Showing 2 changed files with 48 additions and 14 deletions.
@@ -284,12 +284,12 @@ def modify_params_for_qat(

     # Q/DQ Layers take a lot of space for activations in training mode
     if get_param(quantization_params, "selective_quantizer_params") and get_param(quantization_params["selective_quantizer_params"], "learn_amax"):
-        train_dataloader_params["batch_size"] //= batch_size_divisor
-        val_dataloader_params["batch_size"] //= batch_size_divisor
+        train_dataloader_params["batch_size"] = max(1, train_dataloader_params["batch_size"] // batch_size_divisor)
+        val_dataloader_params["batch_size"] = max(1, val_dataloader_params["batch_size"] // batch_size_divisor)

         logger.warning(f"New dataset_params.train_dataloader_params.batch_size: {train_dataloader_params['batch_size']}")
         logger.warning(f"New dataset_params.val_dataloader_params.batch_size: {val_dataloader_params['batch_size']}")
-    training_hyperparams["max_epochs"] //= max_epochs_divisor
+    training_hyperparams["max_epochs"] = max(1, training_hyperparams["max_epochs"] // max_epochs_divisor)
     logger.warning(f"New number of epochs: {training_hyperparams['max_epochs']}")
     training_hyperparams["initial_lr"] *= lr_decay_factor
     if get_param(training_hyperparams, "warmup_initial_lr") is not None:
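
The guarded division above keeps aggressive divisors from driving the batch size or epoch count to zero. A minimal standalone sketch of the same guard (the `scale_down` helper is hypothetical, not part of this commit):

```python
# Minimal sketch (not SG code): why the max(1, ...) guard matters.
# Plain floor division can collapse a small batch size or epoch count to 0,
# which would produce an invalid DataLoader / training loop.
def scale_down(value: int, divisor: int) -> int:
    """Divide by `divisor` but never return less than 1."""
    return max(1, value // divisor)

assert 8 // 2 == 4            # large values are unaffected either way
assert 1 // 2 == 0            # old behaviour: a batch size of 1 becomes 0
assert scale_down(1, 2) == 1  # new behaviour: clamped to a usable minimum
assert scale_down(8, 2) == 4
```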
56 changes: 45 additions & 11 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -1,6 +1,7 @@
 import copy
 import inspect
 import os
+import typing
 from copy import deepcopy
 from typing import Union, Tuple, Mapping, Dict, Any, List, Optional

@@ -99,6 +100,8 @@
 from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg, load_recipe
 from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory
 from super_gradients.training.params import TrainingParams
+from super_gradients.module_interfaces import ExportableObjectDetectionModel
+from super_gradients.conversion import ExportQuantizationMode

 logger = get_logger(__name__)

@@ -2518,10 +2521,27 @@ def ptq(
         :return: Validation results of the calibrated model.
         """

+        logger.debug("Performing post-training quantization (PTQ)...")
+        logger.debug(f"Experiment name {self.experiment_name}")
+
+        run_id = core_utils.get_param(self.training_params, "run_id", None)
+        logger.debug(f"Experiment run id {run_id}")
+
+        self.checkpoints_dir_path = get_checkpoints_dir_path(ckpt_root_dir=self.ckpt_root_dir, experiment_name=self.experiment_name, run_id=run_id)
+        logger.debug(f"Checkpoints directory {self.checkpoints_dir_path}")
+
+        os.makedirs(self.checkpoints_dir_path, exist_ok=True)
+
+        from super_gradients.training.utils.quantization.fix_pytorch_quantization_modules import patch_pytorch_quantization_modules_if_needed
+
+        patch_pytorch_quantization_modules_if_needed()
+
         if quantization_params is None:
             quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
             logger.info(f"Using default quantization params: {quantization_params}")

+        model = unwrap_model(model)  # Unwrap model in case it is wrapped with DataParallel or DistributedDataParallel
+
         selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params")
         calib_params = get_param(quantization_params, "calib_params")
         model.to(device_config.device)
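
The added setup resolves the checkpoints directory up front, applies the pytorch-quantization patch, falls back to the default quantization recipe when no params are given, and unwraps the model before calibration. A rough sketch of what the unwrap step amounts to, assuming plain PyTorch wrappers (the `unwrap` helper below is illustrative, not SG's `unwrap_model`):

```python
# Illustrative sketch: quantization and ONNX export need the bare nn.Module,
# so any (Distributed)DataParallel wrapper is stripped before calibration.
from torch import nn


def unwrap(model: nn.Module) -> nn.Module:
    """Return the wrapped module if model is a DataParallel/DistributedDataParallel wrapper."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


wrapped = nn.DataParallel(nn.Linear(4, 2))  # construction works even on CPU-only machines
assert isinstance(unwrap(wrapped), nn.Linear)
assert isinstance(unwrap(nn.Linear(4, 2)), nn.Linear)  # plain modules pass through unchanged
```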
@@ -2559,17 +2579,31 @@ def ptq(
         logger.info("\n".join(results))

         input_shape = next(iter(valid_loader))[0].shape
-        os.makedirs(self.checkpoints_dir_path, exist_ok=True)
-        qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_ptq.onnx")
-
-        # TODO: modify SG's convert_to_onnx for quantized models and use it instead
-        export_quantized_module_to_onnx(
-            model=model.cpu(),
-            onnx_filename=qdq_onnx_path,
-            input_shape=input_shape,
-            input_size=input_shape,
-            train=False,
-            deepcopy_model=deepcopy_model_for_export,
-        )
+        input_shape_with_batch_size_one = tuple([1] + list(input_shape[1:]))
+        qdq_onnx_path = os.path.join(
+            self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape_with_batch_size_one))}_ptq.onnx"
+        )
+        logger.debug(f"Output ONNX file path {qdq_onnx_path}")
+
+        if isinstance(model, ExportableObjectDetectionModel):
+            model: ExportableObjectDetectionModel = typing.cast(ExportableObjectDetectionModel, model)
+            export_result = model.export(
+                output=qdq_onnx_path,
+                quantization_mode=ExportQuantizationMode.INT8,
+                input_image_shape=(input_shape_with_batch_size_one[2], input_shape_with_batch_size_one[3]),
+                preprocessing=False,
+                postprocessing=True,
+            )
+            logger.info(repr(export_result))
+        else:
+            # TODO: modify SG's convert_to_onnx for quantized models and use it instead
+            export_quantized_module_to_onnx(
+                model=model.cpu(),
+                onnx_filename=qdq_onnx_path,
+                input_shape=input_shape_with_batch_size_one,
+                input_size=input_shape_with_batch_size_one,
+                train=False,
+                deepcopy_model=deepcopy_model_for_export,
+            )

         return valid_metrics_dict
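
With this branch, models implementing `ExportableObjectDetectionModel` are exported through the unified `export()` API with `ExportQuantizationMode.INT8` at batch size one (the ONNX file name is derived from that shape, e.g. `<experiment_name>_1x3x640x640_ptq.onnx` for a 640x640 detector), while other models keep the legacy `export_quantized_module_to_onnx` fallback. A hedged usage sketch of the same export call outside the trainer; the model name, output file, and image size are placeholders, not taken from this commit:

```python
# Hedged sketch of calling the exportable-detector path directly.
from super_gradients.training import models
from super_gradients.conversion import ExportQuantizationMode

model = models.get("yolo_nas_s", pretrained_weights="coco")  # any ExportableObjectDetectionModel
export_result = model.export(
    "yolo_nas_s_int8_ptq.onnx",
    quantization_mode=ExportQuantizationMode.INT8,  # Q/DQ INT8 export, as in the new branch
    input_image_shape=(640, 640),                   # (H, W), mirrors input_image_shape in the diff
)
print(export_result)  # repr(export_result) describes the exported inputs/outputs, as logged above
```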
