Feature/sg 1144 new export in quantization #1511

Merged: 6 commits, Oct 9, 2023
@@ -284,12 +284,12 @@ def modify_params_for_qat(

     # Q/DQ Layers take a lot of space for activations in training mode
     if get_param(quantization_params, "selective_quantizer_params") and get_param(quantization_params["selective_quantizer_params"], "learn_amax"):
-        train_dataloader_params["batch_size"] //= batch_size_divisor
-        val_dataloader_params["batch_size"] //= batch_size_divisor
+        train_dataloader_params["batch_size"] = max(1, train_dataloader_params["batch_size"] // batch_size_divisor)
+        val_dataloader_params["batch_size"] = max(1, val_dataloader_params["batch_size"] // batch_size_divisor)

         logger.warning(f"New dataset_params.train_dataloader_params.batch_size: {train_dataloader_params['batch_size']}")
         logger.warning(f"New dataset_params.val_dataloader_params.batch_size: {val_dataloader_params['batch_size']}")
-    training_hyperparams["max_epochs"] //= max_epochs_divisor
+    training_hyperparams["max_epochs"] = max(1, training_hyperparams["max_epochs"] // max_epochs_divisor)
     logger.warning(f"New number of epochs: {training_hyperparams['max_epochs']}")
     training_hyperparams["initial_lr"] *= lr_decay_factor
     if get_param(training_hyperparams, "warmup_initial_lr") is not None:
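The clamp matters for small recipes: with the old floor division, a batch size or epoch count smaller than its divisor silently became 0. A minimal sketch of the arithmetic, using illustrative divisor values rather than the recipe defaults:

# Illustrative sketch of the clamping change above; all values are made up.
batch_size_divisor = 2
max_epochs_divisor = 10

train_batch_size = 1   # already as small as it can get
max_epochs = 5         # fewer epochs than the divisor

# Old behaviour: floor division alone can yield 0 and break the dataloader / training loop.
assert train_batch_size // batch_size_divisor == 0
assert max_epochs // max_epochs_divisor == 0

# New behaviour: never go below 1.
assert max(1, train_batch_size // batch_size_divisor) == 1
assert max(1, max_epochs // max_epochs_divisor) == 1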
56 changes: 45 additions & 11 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -1,6 +1,7 @@
 import copy
 import inspect
 import os
+import typing
 from copy import deepcopy
 from typing import Union, Tuple, Mapping, Dict, Any, List, Optional

@@ -99,6 +100,8 @@
 from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg, load_recipe
 from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory
 from super_gradients.training.params import TrainingParams
+from super_gradients.module_interfaces import ExportableObjectDetectionModel
+from super_gradients.conversion import ExportQuantizationMode

logger = get_logger(__name__)

@@ -2518,10 +2521,27 @@ def ptq(
         :return: Validation results of the calibrated model.
         """

+        logger.debug("Performing post-training quantization (PTQ)...")
+        logger.debug(f"Experiment name {self.experiment_name}")
+
+        run_id = core_utils.get_param(self.training_params, "run_id", None)
+        logger.debug(f"Experiment run id {run_id}")
+
+        self.checkpoints_dir_path = get_checkpoints_dir_path(ckpt_root_dir=self.ckpt_root_dir, experiment_name=self.experiment_name, run_id=run_id)
+        logger.debug(f"Checkpoints directory {self.checkpoints_dir_path}")
+
+        os.makedirs(self.checkpoints_dir_path, exist_ok=True)
+
+        from super_gradients.training.utils.quantization.fix_pytorch_quantization_modules import patch_pytorch_quantization_modules_if_needed
+
+        patch_pytorch_quantization_modules_if_needed()
+
         if quantization_params is None:
             quantization_params = load_recipe("quantization_params/default_quantization_params").quantization_params
             logger.info(f"Using default quantization params: {quantization_params}")

+        model = unwrap_model(model)  # Unwrap model in case it is wrapped with DataParallel or DistributedDataParallel
+
         selective_quantizer_params = get_param(quantization_params, "selective_quantizer_params")
         calib_params = get_param(quantization_params, "calib_params")
         model.to(device_config.device)
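The unwrap_model call is relevant when ptq() runs right after DataParallel or DDP training, where the model arrives wrapped. A minimal sketch of what the unwrapping amounts to, assuming the standard PyTorch wrapper classes (this is an illustration, not super_gradients' actual unwrap_model implementation):

import torch.nn as nn

def unwrap_model_sketch(model: nn.Module) -> nn.Module:
    # Assumption: wrapped models expose the underlying network as `.module`,
    # as nn.DataParallel and nn.parallel.DistributedDataParallel do.
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model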
@@ -2559,17 +2579,31 @@
         logger.info("\n".join(results))

         input_shape = next(iter(valid_loader))[0].shape
-        os.makedirs(self.checkpoints_dir_path, exist_ok=True)
-        qdq_onnx_path = os.path.join(self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape))}_ptq.onnx")
-
-        # TODO: modify SG's convert_to_onnx for quantized models and use it instead
-        export_quantized_module_to_onnx(
-            model=model.cpu(),
-            onnx_filename=qdq_onnx_path,
-            input_shape=input_shape,
-            input_size=input_shape,
-            train=False,
-            deepcopy_model=deepcopy_model_for_export,
-        )
+        input_shape_with_batch_size_one = tuple([1] + list(input_shape[1:]))
+        qdq_onnx_path = os.path.join(
+            self.checkpoints_dir_path, f"{self.experiment_name}_{'x'.join((str(x) for x in input_shape_with_batch_size_one))}_ptq.onnx"
+        )
+        logger.debug(f"Output ONNX file path {qdq_onnx_path}")
+
+        if isinstance(model, ExportableObjectDetectionModel):
+            model: ExportableObjectDetectionModel = typing.cast(ExportableObjectDetectionModel, model)
+            export_result = model.export(
+                output=qdq_onnx_path,
+                quantization_mode=ExportQuantizationMode.INT8,
+                input_image_shape=(input_shape_with_batch_size_one[2], input_shape_with_batch_size_one[3]),
+                preprocessing=False,
+                postprocessing=True,
+            )
+            logger.info(repr(export_result))
+        else:
+            # TODO: modify SG's convert_to_onnx for quantized models and use it instead
+            export_quantized_module_to_onnx(
+                model=model.cpu(),
+                onnx_filename=qdq_onnx_path,
+                input_shape=input_shape_with_batch_size_one,
+                input_size=input_shape_with_batch_size_one,
+                train=False,
+                deepcopy_model=deepcopy_model_for_export,
+            )

         return valid_metrics_dict
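As a usage note, the exported file lands in the run's checkpoints directory with a fixed batch size of 1 and Q/DQ nodes baked in, so it can be opened with standard ONNX tooling. A hedged sketch of consuming it with ONNX Runtime; the path and input shape below are illustrative, not produced by this PR:

import numpy as np
import onnxruntime as ort

# Hypothetical path of the form "<ckpt_dir>/<experiment_name>_1x3xHxW_ptq.onnx".
onnx_path = "checkpoints/my_experiment/my_experiment_1x3x640x640_ptq.onnx"

session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

# Batch size is fixed to 1 by the export above, hence the leading 1 in the dummy input.
dummy_input = np.random.rand(1, 3, 640, 640).astype(np.float32)
outputs = session.run(None, {input_name: dummy_input})
print([o.shape for o in outputs])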