Support TORCHSCRIPT export for NCNN
triple-Mu committed Mar 7, 2023
1 parent 3568585 commit 4e6b821
Showing 4 changed files with 126 additions and 27 deletions.
3 changes: 2 additions & 1 deletion projects/easydeploy/model/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backend import MMYoloBackend
from .backendwrapper import ORTWrapper, TRTWrapper
from .model import DeployModel

__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper']
__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper', 'MMYoloBackend']
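
Editor's note: MMYoloBackend comes from this package's existing backend.py, which the commit does not touch. Because the ONNX export script below constructs it via MMYoloBackend(args.backend.lower()), it is assumed to be a str-valued Enum; a minimal sketch listing only the members this commit references:

# Sketch only: the real backend.py may define more (or differently named) members.
from enum import Enum


class MMYoloBackend(Enum):
    ONNXRUNTIME = 'onnxruntime'
    OPENVINO = 'openvino'
    TENSORRT8 = 'tensorrt8'
    TENSORRT7 = 'tensorrt7'
    NCNN = 'ncnn'
    HORIZONX3 = 'horizonx3'
    TORCHSCRIPT = 'torchscript'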
43 changes: 30 additions & 13 deletions projects/easydeploy/model/model.py
@@ -11,32 +11,34 @@
from mmyolo.models import RepVGGBlock
from mmyolo.models.dense_heads import (RTMDetHead, YOLOv5Head, YOLOv7Head,
YOLOXHead)
from mmyolo.models.layers import CSPLayerWithTwoConv
from ..backbone import DeployC2f, DeployFocus, GConvFocus, NcnnFocus
from ..backbone import DeployFocus, GConvFocus, NcnnFocus
from ..bbox_code import (rtmdet_bbox_decoder, yolov5_bbox_decoder,
yolox_bbox_decoder)
from ..nms import batched_nms, efficient_nms, onnx_nms
from .backend import MMYoloBackend


class DeployModel(nn.Module):
transpose = False

def __init__(self,
baseModel: nn.Module,
backend: MMYoloBackend,
postprocess_cfg: Optional[ConfigDict] = None):
super().__init__()
self.baseModel = baseModel
self.baseHead = baseModel.bbox_head
self.backend = backend
if postprocess_cfg is None:
self.with_postprocess = False
else:
self.with_postprocess = True
self.baseHead = baseModel.bbox_head
self.__init_sub_attributes()
self.detector_type = type(self.baseHead)
self.pre_top_k = postprocess_cfg.get('pre_top_k', 1000)
self.keep_top_k = postprocess_cfg.get('keep_top_k', 100)
self.iou_threshold = postprocess_cfg.get('iou_threshold', 0.65)
self.score_threshold = postprocess_cfg.get('score_threshold', 0.25)
self.backend = postprocess_cfg.get('backend', 1)
self.__switch_deploy()
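
Editor's note: the postprocess keys read above can be spelled out explicitly. A sketch of the assumed explicit form, using the defaults from this diff (it mirrors the ConfigDict the ONNX export script builds):

from mmengine.config import ConfigDict

# Assumed explicit postprocess_cfg; values are the defaults in __init__ above.
postprocess_cfg = ConfigDict(
    pre_top_k=1000,
    keep_top_k=100,
    iou_threshold=0.65,
    score_threshold=0.25)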

def __init_sub_attributes(self):
@@ -47,21 +47,25 @@ def __init_sub_attributes(self):
self.num_classes = self.baseHead.num_classes

def __switch_deploy(self):
if self.backend in (MMYoloBackend.HORIZONX3, MMYoloBackend.NCNN,
MMYoloBackend.TORCHSCRIPT):
self.transpose = True
for layer in self.baseModel.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
elif isinstance(layer, Focus):
# onnxruntime tensorrt8 tensorrt7
if self.backend in (1, 2, 3):
# onnxruntime openvino tensorrt8 tensorrt7
if self.backend in (MMYoloBackend.ONNXRUNTIME,
MMYoloBackend.OPENVINO,
MMYoloBackend.TENSORRT8,
MMYoloBackend.TENSORRT7):
self.baseModel.backbone.stem = DeployFocus(layer)
# ncnn
elif self.backend == 4:
elif self.backend == MMYoloBackend.NCNN:
self.baseModel.backbone.stem = NcnnFocus(layer)
# switch focus to group conv
else:
self.baseModel.backbone.stem = GConvFocus(layer)
elif isinstance(layer, CSPLayerWithTwoConv):
setattr(layer, '__class__', DeployC2f)
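
Editor's note: a usage sketch of the new backend argument, mirroring what export_torchscript.py further down does (config_path and checkpoint_path are placeholders):

from mmdet.apis import init_detector

from projects.easydeploy.model import DeployModel, MMYoloBackend

baseModel = init_detector(config_path, checkpoint_path, device='cuda:0')
deploy_model = DeployModel(
    baseModel=baseModel,
    backend=MMYoloBackend.NCNN,  # sets transpose=True, swaps Focus -> NcnnFocus
    postprocess_cfg=None)        # raw feature maps only; no bbox decode or NMS
deploy_model.eval()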

def pred_by_feat(self,
cls_scores: List[Tensor],
@@ -129,11 +135,11 @@ def pred_by_feat(self,
self.score_threshold, self.pre_top_k, self.keep_top_k)

def select_nms(self):
if self.backend == 1:
if self.backend in (MMYoloBackend.ONNXRUNTIME, MMYoloBackend.OPENVINO):
nms_func = onnx_nms
elif self.backend == 2:
elif self.backend == MMYoloBackend.TENSORRT8:
nms_func = efficient_nms
elif self.backend == 3:
elif self.backend == MMYoloBackend.TENSORRT7:
nms_func = batched_nms
else:
raise NotImplementedError
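
Editor's note: only the ONNX-family backends reach a valid branch here, which is why the export scripts force model_only for NCNN and TorchScript. The dispatch above is equivalent to this sketch (names as imported at the top of model.py):

# Equivalent lookup; any other backend falls through to NotImplementedError.
NMS_BY_BACKEND = {
    MMYoloBackend.ONNXRUNTIME: onnx_nms,
    MMYoloBackend.OPENVINO: onnx_nms,
    MMYoloBackend.TENSORRT8: efficient_nms,
    MMYoloBackend.TENSORRT7: batched_nms,
}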
@@ -147,4 +153,15 @@ def forward(self, inputs: Tensor):
if self.with_postprocess:
return self.pred_by_feat(*neck_outputs)
else:
return neck_outputs
outputs = []
if self.transpose:
for feats in zip(*neck_outputs):
if self.backend in (MMYoloBackend.NCNN,
MMYoloBackend.TORCHSCRIPT):
outputs.append(
torch.cat(
[feat.permute(0, 2, 3, 1) for feat in feats],
-1))
else:
outputs.append(torch.cat(feats, 1).permute(0, 2, 3, 1))
return tuple(outputs)
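
Editor's note: a minimal shape sketch of the transposed output path above, assuming a single 80-class YOLOv5-style level at stride 8 with a 640x640 input (the 85-channel cls/box/obj split is illustrative):

import torch

cls_score = torch.randn(1, 80, 80, 80)  # (N, C_cls, H, W)
bbox_pred = torch.randn(1, 4, 80, 80)   # (N, 4, H, W)
objectness = torch.randn(1, 1, 80, 80)  # (N, 1, H, W)
feats = (cls_score, bbox_pred, objectness)

# NCNN / TORCHSCRIPT branch: permute each map to NHWC, concat on channels.
out = torch.cat([feat.permute(0, 2, 3, 1) for feat in feats], -1)
assert out.shape == (1, 80, 80, 85)

# HORIZONX3 branch: concat on dim 1, then a single permute; same layout.
alt = torch.cat(feats, 1).permute(0, 2, 3, 1)
assert alt.shape == out.shape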
36 changes: 23 additions & 13 deletions projects/easydeploy/tools/export.py
@@ -7,10 +7,10 @@
import torch
from mmdet.apis import init_detector
from mmengine.config import ConfigDict
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist

from mmyolo.utils import register_all_modules
from projects.easydeploy.model import DeployModel
from projects.easydeploy.model import DeployModel, MMYoloBackend

warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
@@ -43,7 +43,10 @@ def parse_args():
parser.add_argument(
'--opset', type=int, default=11, help='ONNX opset version')
parser.add_argument(
'--backend', type=int, default=1, help='Backend for export onnx')
'--backend',
type=str,
default='onnxruntime',
help='Backend for export onnx')
parser.add_argument(
'--pre-topk',
type=int,
@@ -77,10 +80,16 @@ def build_model_from_cfg(config_path, checkpoint_path, device):

def main():
args = parse_args()
register_all_modules()

mkdir_or_exist(args.work_dir)

backend = MMYoloBackend(args.backend.lower())
if backend in (MMYoloBackend.ONNXRUNTIME, MMYoloBackend.OPENVINO,
MMYoloBackend.TENSORRT8, MMYoloBackend.TENSORRT7):
if not args.model_only:
print_log('Export ONNX with bbox decoder and NMS ...')
else:
args.model_only = True
print_log(f'Cannot export postprocess for {args.backend.lower()}.\n'
f'Setting "args.model_only=True" by default.')
if args.model_only:
postprocess_cfg = None
output_names = None
@@ -89,21 +98,22 @@ def main():
pre_top_k=args.pre_topk,
keep_top_k=args.keep_topk,
iou_threshold=args.iou_threshold,
score_threshold=args.score_threshold,
backend=args.backend)
score_threshold=args.score_threshold)
output_names = ['num_dets', 'boxes', 'scores', 'labels']
baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)

deploy_model = DeployModel(
baseModel=baseModel, postprocess_cfg=postprocess_cfg)
baseModel=baseModel, backend=backend, postprocess_cfg=postprocess_cfg)
deploy_model.eval()

fake_input = torch.randn(args.batch_size, 3,
*args.img_size).to(args.device)
# dry run
deploy_model(fake_input)

save_onnx_path = os.path.join(args.work_dir, 'end2end.onnx')
save_onnx_path = os.path.join(
args.work_dir,
os.path.basename(args.checkpoint).replace('pth', 'onnx'))
# export onnx
with BytesIO() as f:
torch.onnx.export(
@@ -118,7 +128,7 @@ def main():
onnx.checker.check_model(onnx_model)

# Fix TensorRT onnx output shape, just for viewing
if args.backend in (2, 3):
if backend in (MMYoloBackend.TENSORRT8, MMYoloBackend.TENSORRT7):
shapes = [
args.batch_size, 1, args.batch_size, args.keep_topk, 4,
args.batch_size, args.keep_topk, args.batch_size,
@@ -133,9 +143,9 @@ def main():
onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
except Exception as e:
print(f'Simplify failure: {e}')
print_log(f'Simplify failure: {e}')
onnx.save(onnx_model, save_onnx_path)
print(f'ONNX export success, save into {save_onnx_path}')
print_log(f'ONNX export success, save into {save_onnx_path}')


if __name__ == '__main__':
main()
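
Editor's note: the output_names in this script pair with the static shapes the TensorRT shape fix above writes back; a sketch with assumed batch size B and keep_topk K:

# Assumed shapes of the four end2end.onnx outputs.
B, K = 1, 100  # --batch-size and --keep-topk defaults
expected_shapes = {
    'num_dets': (B, 1),  # valid detections per image
    'boxes': (B, K, 4),  # assumed xyxy format
    'scores': (B, K),
    'labels': (B, K),
}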
71 changes: 71 additions & 0 deletions projects/easydeploy/tools/export_torchscript.py
@@ -0,0 +1,71 @@
import argparse
import os
import warnings

import torch
from mmdet.apis import init_detector
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist

from projects.easydeploy.model import DeployModel, MMYoloBackend

warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)
warnings.filterwarnings(action='ignore', category=ResourceWarning)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--work-dir', default='./work_dir', help='Path to save export model')
parser.add_argument(
'--img-size',
nargs='+',
type=int,
default=[640, 640],
help='Image size of height and width')
parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
args.img_size *= 2 if len(args.img_size) == 1 else 1
return args
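
Editor's note: the img_size normalization above relies on list arithmetic; a quick illustration:

# [640] * 2 -> [640, 640]; a two-value size is multiplied by 1, i.e. kept.
img_size = [640]
img_size *= 2 if len(img_size) == 1 else 1
assert img_size == [640, 640]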


def build_model_from_cfg(config_path, checkpoint_path, device):
model = init_detector(config_path, checkpoint_path, device=device)
model.eval()
return model


def main():
args = parse_args()
mkdir_or_exist(args.work_dir)

baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)

deploy_model = DeployModel(
baseModel=baseModel,
backend=MMYoloBackend.TORCHSCRIPT,
postprocess_cfg=None)
deploy_model.eval()

fake_input = torch.randn(args.batch_size, 3,
*args.img_size).to(args.device)
# dry run
deploy_model(fake_input)

save_torchscript_path = os.path.join(
args.work_dir,
os.path.basename(args.checkpoint).replace('pth', 'torchscript'))
mod = torch.jit.trace(deploy_model, fake_input)
mod.save(save_torchscript_path)
print_log(f'TORCHSCRIPT export success, save into {save_torchscript_path}')


if __name__ == '__main__':
main()
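
Editor's note: once saved, the TorchScript module can be reloaded without the mmyolo codebase; a minimal sketch (the file name is hypothetical, following the pth -> torchscript renaming above):

import torch

mod = torch.jit.load('work_dir/yolov5_s.torchscript', map_location='cuda:0')
feats = mod(torch.randn(1, 3, 640, 640, device='cuda:0'))
# feats is a tuple of NHWC per-level tensors, as built in DeployModel.forward.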
