Skip to content

Commit

Permalink
[PPQ] Step 0/n Support int8 conversion for NV Platform (#487)
Browse files Browse the repository at this point in the history
* ptq with fx, error occurs

* ptq with ppq and zeroq

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typos

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename and add new args

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adjust format

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* check

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add args device

* split ptq.py to three files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove all_in_one, change export2onnx name

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactoring

* fix args errors and change class names

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Zhiqiang Wang <zhiqwang@foxmail.com>
  • Loading branch information
3 people committed Mar 10, 2023
1 parent 3eab7be commit 8b578eb
Show file tree
Hide file tree
Showing 6 changed files with 542 additions and 0 deletions.
36 changes: 36 additions & 0 deletions deployment/ppq/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# ppq int8 ptq Example

![Linux](https://img.shields.io/badge/Linux-FCC624?style=for-the-badge&logo=linux&logoColor=black) ![Windows](https://img.shields.io/badge/Windows-0078D6?style=for-the-badge&logo=windows&logoColor=white)

The ppq int8 ptq example of `yolort`.

## Dependencies

- ppq
- torch
- OpenCV
- onnx

## Usage

This section explains how to use the ppq interface. We recommend reading the [ppq samples tutorial](https://github.com/openppl-public/ppq/tree/master/ppq/samples) first. The scripts in this directory can be used for the following steps:

1. Distill your calibration data (optional: use this step if you have no calibration images and your model contains batch normalization layers)

```
python distill_data.py --checkpoint_path=./model/yolov5s.pt
```

1. Export your custom model to onnx format

```
python create_onnx.py --checkpoint_path=./model/yolov5s.pt
```

1. Quantize the ONNX-format float model, producing a JSON quantization config and a float ONNX model

```
python ptq.py --onnx_input_path=./yolov5.onnx
```

See `utils.py` for more details.
121 changes: 121 additions & 0 deletions deployment/ppq/create_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from collections import namedtuple
from pathlib import Path

import torch
from torch import nn

from yolort.models._checkpoint import load_from_ultralytics
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import YOLOHead


class YOLO(nn.Module):
    """
    A backbone followed by a ``YOLOHead``; ``forward`` returns the head outputs as-is.
    """

    def __init__(self, backbone: nn.Module, strides, num_anchors, num_classes: int):
        super().__init__()
        self.backbone = backbone
        # The head is sized from the backbone's ``out_channels`` attribute.
        self.head = YOLOHead(backbone.out_channels, num_anchors, strides, num_classes)

    def forward(self, samples):
        """
        Run ``samples`` through the backbone, then feed the features to the head.
        """
        return self.head(self.backbone(samples))


class ModelWrapper(torch.nn.Module):
    """
    Wrapper class for models whose ``forward`` returns a dict or a list.

    The wrapped model's result is normalized as follows:

    - a ``dict`` becomes a ``ModelEndpoints`` namedtuple whose fields are the
      sorted dict keys;
    - a ``list`` becomes a plain ``tuple``;
    - any other value is returned unchanged.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Store the model to be wrapped.
        """
        super().__init__()
        self.model = model

    def forward(self, input_x):
        """
        Call the wrapped model on ``input_x`` and normalize its return value.
        """
        data = self.model(input_x)

        if isinstance(data, dict):
            # ``namedtuple`` comes from ``collections`` and must be imported at
            # module level; without that import this branch raises NameError.
            endpoints = namedtuple("ModelEndpoints", sorted(data.keys()))  # type: ignore
            return endpoints(**data)  # type: ignore

        if isinstance(data, list):
            return tuple(data)

        return data


def make_model(checkpoint_path, version):
    """
    Build a ``YOLO`` model from an ultralytics checkpoint and wrap it for export.

    The checkpoint is parsed with ``load_from_ultralytics``; the resulting info
    dict supplies the backbone configuration, the head configuration and the
    state dict. The model is returned wrapped in ``ModelWrapper`` and in eval mode.
    """
    info = load_from_ultralytics(checkpoint_path, version=version)

    backbone = darknet_pan_backbone(
        f"darknet_{info['size']}_{version.replace('.', '_')}",
        info["depth_multiple"],
        info["width_multiple"],
        version=version,
        use_p6=info["use_p6"],
    )

    # Anchor count is half the length of the first anchor-grid entry.
    anchors_per_level = len(info["anchor_grids"][0]) // 2
    yolo = YOLO(backbone, info["strides"], anchors_per_level, info["num_classes"])
    yolo.load_state_dict(info["state_dict"])

    # ``eval()`` returns the module itself, so the wrapper is what is handed back.
    return ModelWrapper(yolo).eval()


def main():
    """
    Command-line entry point: load a checkpoint via ``make_model`` and export it to ONNX.

    Raises:
        FileNotFoundError: if ``--checkpoint_path`` does not exist.
    """
    import argparse

    parser = argparse.ArgumentParser("ptq tool.", add_help=True)

    parser.add_argument(
        "--checkpoint_path", type=str, default="yolov5s.pt", help="The path of checkpoint weights"
    )
    parser.add_argument(
        "--version", type=str, default="r6.0", help="checkpoint version passed to the loader, e.g. r6.0"
    )
    # Accepted for CLI compatibility; nothing in this script reads it.
    parser.add_argument("--threshold", type=float, default=0.25, help="threshold")
    parser.add_argument("--device", type=str, default="cuda", help="device used for the export, e.g. cuda or cpu")
    # nargs=3 keeps a user-supplied value a list of ints like the default;
    # without it a CLI value would be a single int and ``[1] + input_size`` fails.
    parser.add_argument(
        "--input_size",
        default=[3, 640, 640],
        type=int,
        nargs=3,
        help="input size as three ints, e.g. 3 640 640",
    )
    parser.add_argument("--opset_version", type=int, default=11, help="opset version")
    parser.add_argument("--onnx_input_name", type=str, default="dummy_input", help="onnx input name")
    parser.add_argument("--onnx_output_name", type=str, default="dummy_output", help="onnx output name")
    parser.add_argument("--onnx_output_path", type=str, default="yolov5.onnx", help="onnx output path")

    args = parser.parse_args()

    print(f"Command Line Args: {args}")

    # Model initialization. Raise instead of assert so the check survives ``python -O``.
    checkpoint_path = Path(args.checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Not found checkpoint file at '{checkpoint_path}'")
    model = make_model(checkpoint_path, args.version)
    model = model.to(args.device)
    model.eval()

    # Export torch checkpoint to ONNX
    dummy_inputs = torch.randn([1, *args.input_size], device=args.device)
    torch.onnx.export(
        model,
        dummy_inputs,
        args.onnx_output_path,
        # Must be passed by keyword: the 4th positional parameter of
        # ``torch.onnx.export`` is ``export_params``, not the opset version.
        opset_version=args.opset_version,
        do_constant_folding=False,
        input_names=[args.onnx_input_name],
        output_names=[args.onnx_output_name],
    )


if __name__ == "__main__":
    main()
74 changes: 74 additions & 0 deletions deployment/ppq/distill_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import argparse
import os
import shutil
from pathlib import Path

from create_onnx import make_model

from utils import get_distill_data


def main():
    """
    Command-line entry point: generate distilled calibration data for a checkpoint.

    Loads the model via ``make_model``, prepares the output directory (optionally
    wiping it first) and calls ``get_distill_data`` only when the directory holds
    fewer entries than ``--num_batches``.

    Raises:
        FileNotFoundError: if ``--checkpoint_path`` does not exist.
    """
    parser = argparse.ArgumentParser("ptq tool.", add_help=True)

    parser.add_argument(
        "--checkpoint_path", type=str, default="yolov5s.pt", help="The path of checkpoint weights"
    )
    # nargs=3 keeps a user-supplied value a list of ints like the default.
    parser.add_argument(
        "--input_size",
        default=[3, 640, 640],
        type=int,
        nargs=3,
        help="input size as three ints, e.g. 3 640 640",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
    parser.add_argument(
        "--version", type=str, default="r6.0", help="checkpoint version passed to the loader, e.g. r6.0"
    )
    # Accepted for CLI compatibility; nothing in this script reads it.
    parser.add_argument("--threshold", type=float, default=0.25, help="threshold")
    parser.add_argument("--device", type=str, default="cuda", help="device to run the model on, e.g. cuda or cpu")
    parser.add_argument(
        "--regenerate_data",
        type=int,
        default=1,
        help="non-zero to delete existing distilled data and generate new data in its place",
    )
    parser.add_argument(
        "--distilled_data_path", type=str, default="./distilled_data/", help="The path of distilled data"
    )
    parser.add_argument(
        "--calibration_data_path",
        type=str,
        default="./distilled_data/",
        help="The path of calibration data, if zeroq is not used, you should set it",
    )
    parser.add_argument("--distill_iterations", type=int, default=50, help="distill iterations")
    parser.add_argument("--num_batches", type=int, default=10, help="num of batches")

    args = parser.parse_args()
    print(f"Command Line Args: {args}")

    # Model initialization. Raise instead of assert so the check survives ``python -O``.
    checkpoint_path = Path(args.checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Not found checkpoint file at '{checkpoint_path}'")
    model = make_model(checkpoint_path, args.version)
    model = model.to(args.device)
    model.eval()

    # Prepare the output directory.
    distilled_data_path = Path(args.distilled_data_path)
    # Overrides the CLI value; nothing below reads it.
    args.calibration_data_path = distilled_data_path
    if args.regenerate_data and distilled_data_path.exists():
        shutil.rmtree(distilled_data_path)
    os.makedirs(distilled_data_path, exist_ok=True)

    # Only the number of existing entries matters here, so no sorting is needed
    # (the previous bare ``sorted(...)`` call discarded its result anyway).
    num_existing = len(os.listdir(distilled_data_path))

    # Top up only when the directory holds fewer entries than requested.
    if num_existing < args.num_batches:
        args.num_batches = args.num_batches - num_existing
        get_distill_data(
            args.distilled_data_path,
            model,
            args.input_size,
            args.batch_size,
            # presumably the index of the first new batch -- confirm against utils.get_distill_data
            num_existing + 1,
            args.distill_iterations,
            args.num_batches,
        )


if __name__ == "__main__":
    main()
136 changes: 136 additions & 0 deletions deployment/ppq/ptq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import argparse
import os

import onnx
import ppq.lib as PFL
import torch

from onnxsim import simplify

from ppq import graphwise_error_analyse, TargetPlatform, TorchExecutor
from ppq.api import export_ppq_graph, load_onnx_graph
from ppq.quantization.optim import (
ParameterQuantizePass,
PassiveParameterQuantizePass,
QuantAlignmentPass,
QuantizeFusionPass,
QuantizeSimplifyPass,
RuntimeCalibrationPass,
)

from utils import collate_fn, prepare_data_loaders


# Quantization target platform for this script: ppq's TRT_INT8 target.
PLATFORM = TargetPlatform.TRT_INT8


def main():
    """
    Command-line entry point: post-training quantization of an ONNX model with ppq.

    Steps: simplify the input ONNX model, load it as a ppq graph, mark the
    operators between Conv layers for ``PLATFORM`` quantization, calibrate with
    the data found under ``--calibration_data_path``, optionally print the
    graph-wise error analysis, and export the quantized graph plus its config
    into ``--quantized_res_path``.
    """
    parser = argparse.ArgumentParser("ptq tool.", add_help=True)

    parser.add_argument(
        "--calibration_data_path",
        type=str,
        default="./distilled_data/",
        help="The path of calibration data, if zeroq is not used, you should set it",
    )
    parser.add_argument("--show_error_cal", type=int, default=1, help="flag to show error cal")

    # nargs=3 keeps a user-supplied value a list of ints like the default;
    # without it a CLI value would be a single int and ``[1] + input_size`` fails.
    parser.add_argument(
        "--input_size",
        default=[3, 640, 640],
        type=int,
        nargs=3,
        help="input size as three ints, e.g. 3 640 640",
    )

    parser.add_argument("--calib_steps", type=int, default=64, help="number of calibration steps")

    parser.add_argument("--onnx_input_path", type=str, default="./model/yolov5.onnx", help="onnx input path")

    parser.add_argument(
        "--quantized_res_path",
        type=str,
        default="./model/",
        help="quantized outputs",
    )

    parser.add_argument("--device", type=str, default="cuda", help="device to run on, e.g. cuda or cpu")

    args = parser.parse_args()

    sim_onnx_output_name = "sim_float_yolov5.onnx"
    quantized_onnx_output_name = "quantized_float_yolov5.onnx"
    # The exported config is JSON, so give it a .json extension rather than .onnx.
    quantized_json_output_name = "quantized_json_yolov5.json"

    print(f"Command Line Args: {args}")

    # Make sure the output directory exists before anything is saved into it.
    os.makedirs(args.quantized_res_path, exist_ok=True)
    sim_onnx_path = os.path.join(args.quantized_res_path, sim_onnx_output_name)

    # quantization
    onnx_model = onnx.load(args.onnx_input_path)
    simplified, _ = simplify(onnx_model)
    onnx.save(simplified, sim_onnx_path)
    graph = load_onnx_graph(sim_onnx_path)

    quantizer = PFL.Quantizer(platform=PLATFORM, graph=graph)
    # Every operator starts as FP32; selected ones are switched to PLATFORM below.
    dispatching = {op.name: TargetPlatform.FP32 for op in graph.operations.values()}
    # dataloader
    dataloader = prepare_data_loaders(args.calibration_data_path, args.input_size)
    # Quantize all operators between the first Conv and the last Conv;
    # leave the other operators unquantized.
    from ppq.IR import SearchableGraph

    search_engine = SearchableGraph(graph)
    for op in search_engine.opset_matching(
        sp_expr=lambda x: x.type == "Conv",
        rp_expr=lambda x, y: True,
        ep_expr=lambda x: x.type == "Conv",
        direction="down",
    ):
        dispatching[op.name] = PLATFORM

    # Initialize quantization info for the selected operators.
    for op in graph.operations.values():
        if dispatching[op.name] == PLATFORM:
            quantizer.quantize_operation(op_name=op.name, platform=dispatching[op.name])

    # Initialize the executor. The tracing input is created on ``--device``
    # (previously it was always moved to CUDA, ignoring the option).
    executor = TorchExecutor(graph=graph, device=args.device)
    executor.tracing_operation_meta(inputs=torch.zeros(size=[1, *args.input_size], device=args.device))
    executor.load_graph(graph=graph)

    # Build the optimization pipeline.
    pipeline = PFL.Pipeline(
        [
            QuantizeSimplifyPass(),
            QuantizeFusionPass(activation_type=quantizer.activation_fusion_types),
            ParameterQuantizePass(),
            RuntimeCalibrationPass(),
            PassiveParameterQuantizePass(),
            QuantAlignmentPass(force_overlap=True),
            # Fine-tune your network:
            # LearnedStepSizePass(steps=1500)
            # If the network needs training, it must happen before ParameterBakingPass:
            # ParameterBakingPass()
        ]
    )

    # Run the pipeline to perform the quantization.
    pipeline.optimize(
        graph=graph,
        dataloader=dataloader,
        verbose=True,
        calib_steps=args.calib_steps,
        collate_fn=collate_fn,
        executor=executor,
    )

    if args.show_error_cal:
        graphwise_error_analyse(
            graph=graph,
            running_device=args.device,
            dataloader=dataloader,
            collate_fn=collate_fn,
        )

    export_ppq_graph(
        graph=graph,
        platform=PLATFORM,
        graph_save_to=os.path.join(args.quantized_res_path, quantized_onnx_output_name),
        config_save_to=os.path.join(args.quantized_res_path, quantized_json_output_name),
    )


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions deployment/ppq/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# quantization --------------------------------
ppq>=0.6.6
Loading

0 comments on commit 8b578eb

Please sign in to comment.