Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torchscript export #547

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion projects/easydeploy/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(self,
postprocess_cfg: Optional[ConfigDict] = None):
super().__init__()
self.baseModel = baseModel
self.mean = baseModel.data_preprocessor.mean
self.std = baseModel.data_preprocessor.std
if postprocess_cfg is None:
self.with_postprocess = False
else:
Expand Down Expand Up @@ -143,8 +145,12 @@ def select_nms(self):
return nms_func

def forward(self, inputs: Tensor):
    """Normalize the input batch, run the wrapped model, and format outputs.

    When post-processing is enabled the decoded predictions are returned;
    otherwise the raw multi-level head outputs are concatenated per level
    and permuted to channels-last layout.
    """
    # Fold the data-preprocessor normalization into the exported graph so
    # the traced model accepts raw (un-normalized) image tensors.
    normalized = (inputs - self.mean) / self.std
    feats = self.baseModel(normalized)

    if self.with_postprocess:
        return self.pred_by_feat(*feats)
    # zip(*feats) regroups the per-branch tensors level by level.
    return [
        torch.cat(level, 1).permute(0, 2, 3, 1) for level in zip(*feats)
    ]
72 changes: 72 additions & 0 deletions projects/easydeploy/tools/export_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import argparse
import os
import warnings

import torch
from mmcv.cnn import fuse_conv_bn
from mmdet.apis import init_detector
from mmengine.utils.path import mkdir_or_exist

from mmyolo.utils import register_all_modules
from projects.easydeploy.model import DeployModel

# Silence the noisy warning categories emitted by torch.jit.trace and the
# model-building stack during export; they do not affect the exported
# artifact, only console output.
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)
warnings.filterwarnings(action='ignore', category=ResourceWarning)


def parse_args():
    """Parse command-line arguments for TorchScript export.

    Returns:
        argparse.Namespace: Parsed arguments, with ``img_size`` normalized
        to an explicit ``[height, width]`` pair.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--work-dir', default='./work_dir', help='Path to save export model')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()
    # Accept either a single value (square image) or an explicit
    # (height, width) pair; anything else would silently build a malformed
    # input shape downstream, so fail fast with a usage error.
    if len(args.img_size) == 1:
        args.img_size = args.img_size * 2
    elif len(args.img_size) != 2:
        parser.error(
            f'--img-size expects 1 or 2 integers, got {len(args.img_size)}')
    return args


def build_model_from_cfg(config_path, checkpoint_path, device):
    """Build a detector from config + checkpoint and switch it to eval mode."""
    detector = init_detector(config_path, checkpoint_path, device=device)
    detector.eval()
    return detector


def main():
    """Build the detector, wrap it for deployment, and trace to TorchScript."""
    args = parse_args()
    register_all_modules()

    mkdir_or_exist(args.work_dir)

    baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)

    # postprocess_cfg=None disables in-graph decoding: the traced model
    # emits the raw (reformatted) head outputs.
    deploy_model = DeployModel(baseModel=baseModel, postprocess_cfg=None)
    deploy_model = fuse_conv_bn(deploy_model).eval()

    dummy_input = torch.randn(args.batch_size, 3,
                              *args.img_size).to(args.device)
    # Dry run so any lazy initialization happens before tracing.
    deploy_model(dummy_input)

    save_torchscript_path = os.path.join(args.work_dir, 'end2end.torchscript')

    # Trace (rather than script) the model; strict=False tolerates
    # non-tensor container outputs.
    with torch.jit.optimized_execution(True):
        traced_model = torch.jit.trace(deploy_model, dummy_input, strict=False)
    traced_model.save(save_torchscript_path)
    print(f'TORCHSCRIPT export success, save into {save_torchscript_path}')


# Script entry point: parse CLI arguments and run the TorchScript export.
if __name__ == '__main__':
    main()