Skip to content

Commit

Permalink
Rename MMYOLOBackend
Browse files Browse the repository at this point in the history
  • Loading branch information
triple-Mu committed Mar 7, 2023
1 parent 4e6b821 commit d80c5ec
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
4 changes: 2 additions & 2 deletions projects/easydeploy/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backend import MMYoloBackend
from .backend import MMYOLOBackend
from .backendwrapper import ORTWrapper, TRTWrapper
from .model import DeployModel

__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper', 'MMYoloBackend']
__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper', 'MMYOLOBackend']
28 changes: 14 additions & 14 deletions projects/easydeploy/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
from ..bbox_code import (rtmdet_bbox_decoder, yolov5_bbox_decoder,
yolox_bbox_decoder)
from ..nms import batched_nms, efficient_nms, onnx_nms
from .backend import MMYoloBackend
from .backend import MMYOLOBackend


class DeployModel(nn.Module):
transpose = False

def __init__(self,
baseModel: nn.Module,
backend: MMYoloBackend,
backend: MMYOLOBackend,
postprocess_cfg: Optional[ConfigDict] = None):
super().__init__()
self.baseModel = baseModel
Expand All @@ -49,21 +49,21 @@ def __init_sub_attributes(self):
self.num_classes = self.baseHead.num_classes

def __switch_deploy(self):
if self.backend in (MMYoloBackend.HORIZONX3, MMYoloBackend.NCNN,
MMYoloBackend.TORCHSCRIPT):
if self.backend in (MMYOLOBackend.HORIZONX3, MMYOLOBackend.NCNN,
MMYOLOBackend.TORCHSCRIPT):
self.transpose = True
for layer in self.baseModel.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
elif isinstance(layer, Focus):
# onnxruntime openvino tensorrt8 tensorrt7
if self.backend in (MMYoloBackend.ONNXRUNTIME,
MMYoloBackend.OPENVINO,
MMYoloBackend.TENSORRT8,
MMYoloBackend.TENSORRT7):
if self.backend in (MMYOLOBackend.ONNXRUNTIME,
MMYOLOBackend.OPENVINO,
MMYOLOBackend.TENSORRT8,
MMYOLOBackend.TENSORRT7):
self.baseModel.backbone.stem = DeployFocus(layer)
# ncnn
elif self.backend == MMYoloBackend.NCNN:
elif self.backend == MMYOLOBackend.NCNN:
self.baseModel.backbone.stem = NcnnFocus(layer)
# switch focus to group conv
else:
Expand Down Expand Up @@ -135,11 +135,11 @@ def pred_by_feat(self,
self.score_threshold, self.pre_top_k, self.keep_top_k)

def select_nms(self):
if self.backend in (MMYoloBackend.ONNXRUNTIME, MMYoloBackend.OPENVINO):
if self.backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO):
nms_func = onnx_nms
elif self.backend == MMYoloBackend.TENSORRT8:
elif self.backend == MMYOLOBackend.TENSORRT8:
nms_func = efficient_nms
elif self.backend == MMYoloBackend.TENSORRT7:
elif self.backend == MMYOLOBackend.TENSORRT7:
nms_func = batched_nms
else:
raise NotImplementedError
Expand All @@ -156,8 +156,8 @@ def forward(self, inputs: Tensor):
outputs = []
if self.transpose:
for feats in zip(*neck_outputs):
if self.backend in (MMYoloBackend.NCNN,
MMYoloBackend.TORCHSCRIPT):
if self.backend in (MMYOLOBackend.NCNN,
MMYOLOBackend.TORCHSCRIPT):
outputs.append(
torch.cat(
[feat.permute(0, 2, 3, 1) for feat in feats],
Expand Down
10 changes: 5 additions & 5 deletions projects/easydeploy/tools/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist

from projects.easydeploy.model import DeployModel, MMYoloBackend
from projects.easydeploy.model import DeployModel, MMYOLOBackend

warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
Expand Down Expand Up @@ -81,9 +81,9 @@ def build_model_from_cfg(config_path, checkpoint_path, device):
def main():
args = parse_args()
mkdir_or_exist(args.work_dir)
backend = MMYoloBackend(args.backend.lower())
if backend in (MMYoloBackend.ONNXRUNTIME, MMYoloBackend.OPENVINO,
MMYoloBackend.TENSORRT8, MMYoloBackend.TENSORRT7):
backend = MMYOLOBackend(args.backend.lower())
if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
if not args.model_only:
print_log('Export ONNX with bbox decoder and NMS ...')
else:
Expand Down Expand Up @@ -128,7 +128,7 @@ def main():
onnx.checker.check_model(onnx_model)

# Fix tensorrt onnx output shape, just for view
if backend in (MMYoloBackend.TENSORRT8, MMYoloBackend.TENSORRT8):
if backend in (MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT8):
shapes = [
args.batch_size, 1, args.batch_size, args.keep_topk, 4,
args.batch_size, args.keep_topk, args.batch_size,
Expand Down
4 changes: 2 additions & 2 deletions projects/easydeploy/tools/export_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist

from projects.easydeploy.model import DeployModel, MMYoloBackend
from projects.easydeploy.model import DeployModel, MMYOLOBackend

warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
Expand Down Expand Up @@ -50,7 +50,7 @@ def main():

deploy_model = DeployModel(
baseModel=baseModel,
backend=MMYoloBackend.TORCHSCRIPT,
backend=MMYOLOBackend.TORCHSCRIPT,
postprocess_cfg=None)
deploy_model.eval()

Expand Down

0 comments on commit d80c5ec

Please sign in to comment.