-
Notifications
You must be signed in to change notification settings - Fork 153
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PPQ] Step 0/n Support int8 conversion for NV Platform (#487)
* ptq with fx, error occurs * ptq with ppq and zeroq * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typos * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename and add new args * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * adjust format * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change utils.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add args device * split ptq.py to three files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove all_in_one, change export2onnx name * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactoring * fix args errors and change class names * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Zhiqiang Wang <zhiqwang@foxmail.com>
- Loading branch information
1 parent
3eab7be
commit 8b578eb
Showing
6 changed files
with
542 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# ppq int8 ptq Example | ||
|
||
![Linux](https://img.shields.io/badge/Linux-FCC624?style=for-the-badge&logo=linux&logoColor=black) ![Windows](https://img.shields.io/badge/Windows-0078D6?style=for-the-badge&logo=windows&logoColor=white) | ||
|
||
The ppq int8 ptq example of `yolort`. | ||
|
||
## Dependencies | ||
|
||
- ppq | ||
- torch | ||
- OpenCV | ||
- onnx | ||
|
||
## Usage | ||
|
||
Here we will mainly discuss how to use the ppq interface, we recommend that you check out [tutorial](https://github.com/openppl-public/ppq/tree/master/ppq/samples) first. This code can be used to do the following stuff: | ||
|
||
1. Distill your calibration data (Optional: If you don't have images for calibration and bn is in your model, you can use this) | ||
|
||
``` | ||
python distill_data.py --checkpoint_path=./model/yolov5s.pt | ||
``` | ||
|
||
1. Export your custom model to onnx format | ||
|
||
``` | ||
python create_onnx.py --checkpoint_path=./model/yolov5s.pt | ||
``` | ||
|
||
1. Quantization onnx-format float model to a json file and a float model | ||
|
||
``` | ||
python ptq.py --onnx_input_path=./model/yolov5s.pt | ||
``` | ||
|
||
More details can be checked in utils.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from pathlib import Path | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from yolort.models._checkpoint import load_from_ultralytics | ||
from yolort.models.backbone_utils import darknet_pan_backbone | ||
from yolort.models.box_head import YOLOHead | ||
|
||
|
||
class YOLO(nn.Module): | ||
def __init__(self, backbone: nn.Module, strides, num_anchors, num_classes: int): | ||
super().__init__() | ||
self.backbone = backbone | ||
self.head = YOLOHead(backbone.out_channels, num_anchors, strides, num_classes) | ||
|
||
def forward(self, samples): | ||
|
||
# get the features from the backbone | ||
features = self.backbone(samples) | ||
|
||
# compute the yolo heads outputs using the features | ||
head_outputs = self.head(features) | ||
return head_outputs | ||
|
||
|
||
class ModelWrapper(torch.nn.Module): | ||
""" | ||
Wrapper class for model with dict/list rvalues. | ||
""" | ||
|
||
def __init__(self, model: torch.nn.Module) -> None: | ||
""" | ||
Init call. | ||
""" | ||
super().__init__() | ||
self.model = model | ||
|
||
def forward(self, input_x): | ||
""" | ||
Wrap forward call. | ||
""" | ||
data = self.model(input_x) | ||
|
||
if isinstance(data, dict): | ||
data_named_tuple = namedtuple("ModelEndpoints", sorted(data.keys())) # type: ignore | ||
data = data_named_tuple(**data) # type: ignore | ||
|
||
elif isinstance(data, list): | ||
data = tuple(data) | ||
|
||
return data | ||
|
||
|
||
def make_model(checkpoint_path, version): | ||
|
||
model_info = load_from_ultralytics(checkpoint_path, version=version) | ||
|
||
backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}" | ||
depth_multiple = model_info["depth_multiple"] | ||
width_multiple = model_info["width_multiple"] | ||
use_p6 = model_info["use_p6"] | ||
backbone = darknet_pan_backbone( | ||
backbone_name, depth_multiple, width_multiple, version=version, use_p6=use_p6 | ||
) | ||
strides = model_info["strides"] | ||
num_anchors = len(model_info["anchor_grids"][0]) // 2 | ||
num_classes = model_info["num_classes"] | ||
model = YOLO(backbone, strides, num_anchors, num_classes) | ||
|
||
model.load_state_dict(model_info["state_dict"]) | ||
model = ModelWrapper(model) | ||
|
||
model = model.eval() | ||
|
||
return model | ||
|
||
|
||
def main(): | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser("ptq tool.", add_help=True) | ||
|
||
parser.add_argument( | ||
"--checkpoint_path", type=str, default="yolov5s.pt", help="The path of checkpoint weights" | ||
) | ||
parser.add_argument("--version", type=str, default="r6.0", help="opset version") | ||
parser.add_argument("--threshold", type=float, default=0.25, help="threshold") | ||
parser.add_argument("--device", type=str, default="cuda", help="opset version") | ||
parser.add_argument("--input_size", default=[3, 640, 640], type=int, help="input size") | ||
parser.add_argument("--opset_version", type=int, default=11, help="opset version") | ||
parser.add_argument("--onnx_input_name", type=str, default="dummy_input", help="onnx input name") | ||
parser.add_argument("--onnx_output_name", type=str, default="dummy_output", help="onnx output name") | ||
parser.add_argument("--onnx_output_path", type=str, default="yolov5.onnx", help="onnx output name") | ||
|
||
args = parser.parse_args() | ||
|
||
print(f"Command Line Args: {args}") | ||
|
||
# model initilize | ||
checkpoint_path = Path(args.checkpoint_path) | ||
assert checkpoint_path.exists(), f"Not found checkpoint file at '{checkpoint_path}'" | ||
model = make_model(checkpoint_path, args.version) | ||
model = model.to(args.device) | ||
model.eval() | ||
|
||
# Export torch checkpoint to ONNX | ||
dummy_inputs = torch.randn([1] + args.input_size, device=args.device) | ||
torch.onnx.export( | ||
model, | ||
dummy_inputs, | ||
args.onnx_output_path, | ||
args.opset_version, | ||
do_constant_folding=False, | ||
input_names=[args.onnx_input_name], | ||
output_names=[args.onnx_output_name], | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import argparse | ||
import os | ||
import shutil | ||
from pathlib import Path | ||
|
||
from create_onnx import make_model | ||
|
||
from utils import get_distill_data | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser("ptq tool.", add_help=True) | ||
|
||
parser.add_argument( | ||
"--checkpoint_path", type=str, default="yolov5s.pt", help="The path of checkpoint weights" | ||
) | ||
parser.add_argument("--input_size", default=[3, 640, 640], type=int, help="input size") | ||
parser.add_argument("--batch_size", default=1, type=int, help="batch size") | ||
parser.add_argument("--version", type=str, default="r6.0", help="opset version") | ||
parser.add_argument("--threshold", type=float, default=0.25, help="threshold") | ||
parser.add_argument("--device", type=str, default="cuda", help="opset version") | ||
parser.add_argument( | ||
"--regenerate_data", | ||
type=int, | ||
default=1, | ||
help="if you wangt to generate new data in place of old data", | ||
) | ||
parser.add_argument( | ||
"--distilled_data_path", type=str, default="./distilled_data/", help="The path of distilled data" | ||
) | ||
parser.add_argument( | ||
"--calibration_data_path", | ||
type=str, | ||
default="./distilled_data/", | ||
help="The path of calibration data, if zeroq is not used, you should set it", | ||
) | ||
parser.add_argument("--distill_iterations", type=int, default=50, help="distill iterations") | ||
parser.add_argument("--num_batches", type=int, default=10, help="num of batches") | ||
|
||
args = parser.parse_args() | ||
print(f"Command Line Args: {args}") | ||
|
||
# model initilize | ||
checkpoint_path = Path(args.checkpoint_path) | ||
assert checkpoint_path.exists(), f"Not found checkpoint file at '{checkpoint_path}'" | ||
model = make_model(checkpoint_path, args.version) | ||
model = model.to(args.device) | ||
model.eval() | ||
|
||
# distill data | ||
distilled_data_path = Path(args.distilled_data_path) | ||
args.calibration_data_path = distilled_data_path | ||
if args.regenerate_data and os.path.exists(distilled_data_path): | ||
shutil.rmtree(distilled_data_path) | ||
if not os.path.exists(distilled_data_path): | ||
os.makedirs(distilled_data_path) | ||
imgs_lists = os.listdir(distilled_data_path) | ||
sorted(imgs_lists) | ||
|
||
if len(imgs_lists) < args.num_batches: | ||
args.num_batches = args.num_batches - len(imgs_lists) | ||
get_distill_data( | ||
args.distilled_data_path, | ||
model, | ||
args.input_size, | ||
args.batch_size, | ||
len(imgs_lists) + 1, | ||
args.distill_iterations, | ||
args.num_batches, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import argparse | ||
import os | ||
|
||
import onnx | ||
import ppq.lib as PFL | ||
import torch | ||
|
||
from onnxsim import simplify | ||
|
||
from ppq import graphwise_error_analyse, TargetPlatform, TorchExecutor | ||
from ppq.api import export_ppq_graph, load_onnx_graph | ||
from ppq.quantization.optim import ( | ||
ParameterQuantizePass, | ||
PassiveParameterQuantizePass, | ||
QuantAlignmentPass, | ||
QuantizeFusionPass, | ||
QuantizeSimplifyPass, | ||
RuntimeCalibrationPass, | ||
) | ||
|
||
from utils import collate_fn, prepare_data_loaders | ||
|
||
|
||
PLATFORM = TargetPlatform.TRT_INT8 | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser("ptq tool.", add_help=True) | ||
|
||
parser.add_argument( | ||
"--calibration_data_path", | ||
type=str, | ||
default="./distilled_data/", | ||
help="The path of calibration data, if zeroq is not used, you should set it", | ||
) | ||
parser.add_argument("--show_error_cal", type=int, default=1, help="flag to show error cal") | ||
|
||
parser.add_argument("--input_size", default=[3, 640, 640], type=int, help="input size") | ||
|
||
parser.add_argument("--calib_steps", type=int, default=64, help="opset version") | ||
|
||
parser.add_argument("--onnx_input_path", type=str, default="./model/yolov5.onnx", help="onnx input path") | ||
|
||
parser.add_argument( | ||
"--quantized_res_path", | ||
type=str, | ||
default="./model/", | ||
help="quantized outputs", | ||
) | ||
|
||
parser.add_argument("--device", type=str, default="cuda", help="opset version") | ||
|
||
args = parser.parse_args() | ||
|
||
sim_onnx_output_name = "sim_float_yolov5.onnx" | ||
quantized_onnx_output_name = "quantized_float_yolov5.onnx" | ||
quantized_json_output_name = "quantized_json_yolov5.onnx" | ||
|
||
print(f"Command Line Args: {args}") | ||
|
||
# quantization | ||
onnx_model = onnx.load(args.onnx_input_path) | ||
simplified, _ = simplify(onnx_model) | ||
onnx.save(simplified, os.path.join(args.quantized_res_path, sim_onnx_output_name)) | ||
graph = load_onnx_graph(os.path.join(args.quantized_res_path, sim_onnx_output_name)) | ||
|
||
quantizer = PFL.Quantizer(platform=TargetPlatform.TRT_INT8, graph=graph) | ||
dispatching = {op.name: TargetPlatform.FP32 for op in graph.operations.values()} | ||
# dataloader | ||
dataloader = prepare_data_loaders(args.calibration_data_path, args.input_size) | ||
# 从第一个卷积到最后的卷积中间的所有算子量化,其他算子不量化 | ||
from ppq.IR import SearchableGraph | ||
|
||
search_engine = SearchableGraph(graph) | ||
for op in search_engine.opset_matching( | ||
sp_expr=lambda x: x.type == "Conv", | ||
rp_expr=lambda x, y: True, | ||
ep_expr=lambda x: x.type == "Conv", | ||
direction="down", | ||
): | ||
dispatching[op.name] = TargetPlatform.TRT_INT8 | ||
|
||
# 为算子初始化量化信息 | ||
for op in graph.operations.values(): | ||
if dispatching[op.name] == TargetPlatform.TRT_INT8: | ||
quantizer.quantize_operation(op_name=op.name, platform=dispatching[op.name]) | ||
|
||
# 初始化执行器 | ||
executor = TorchExecutor(graph=graph, device=args.device) | ||
executor.tracing_operation_meta(inputs=torch.zeros(size=[1] + args.input_size).cuda()) | ||
executor.load_graph(graph=graph) | ||
|
||
# 创建优化管线 | ||
pipeline = PFL.Pipeline( | ||
[ | ||
QuantizeSimplifyPass(), | ||
QuantizeFusionPass(activation_type=quantizer.activation_fusion_types), | ||
ParameterQuantizePass(), | ||
RuntimeCalibrationPass(), | ||
PassiveParameterQuantizePass(), | ||
QuantAlignmentPass(force_overlap=True), | ||
# 微调你的网络 | ||
# LearnedStepSizePass(steps=1500) | ||
# 如果需要训练网络,训练过程必须发生在ParameterBakingPass之前 | ||
# ParameterBakingPass() | ||
] | ||
) | ||
|
||
# 调用管线完成量化 | ||
pipeline.optimize( | ||
graph=graph, | ||
dataloader=dataloader, | ||
verbose=True, | ||
calib_steps=args.calib_steps, | ||
collate_fn=collate_fn, | ||
executor=executor, | ||
) | ||
|
||
if args.show_error_cal: | ||
graphwise_error_analyse( | ||
graph=graph, | ||
running_device=args.device, | ||
dataloader=dataloader, | ||
collate_fn=collate_fn, | ||
) | ||
|
||
export_ppq_graph( | ||
graph=graph, | ||
platform=TargetPlatform.TRT_INT8, | ||
graph_save_to=os.path.join(args.quantized_res_path, quantized_onnx_output_name), | ||
config_save_to=os.path.join(args.quantized_res_path, quantized_json_output_name), | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# quantization -------------------------------- | ||
ppq>=0.6.6 |
Oops, something went wrong.