TensorRT SegmentationModel fix #9465

Merged · 15 commits · Sep 18, 2022
export.py: 23 changes (12 additions, 11 deletions)

@@ -66,7 +66,7 @@
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 
 from models.experimental import attempt_load
-from models.yolo import ClassificationModel, Detect
+from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
 from utils.dataloaders import LoadImages
 from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
                            check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
@@ -134,6 +134,15 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX
     LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
     f = file.with_suffix('.onnx')
 
+    output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
+    if dynamic:
+        dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
+        if isinstance(model, SegmentationModel):
+            dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
+            dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
+        elif isinstance(model, DetectionModel):
+            dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
+
     torch.onnx.export(
         model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
         im.cpu() if dynamic else im,
@@ -142,16 +151,8 @@
         opset_version=opset,
         do_constant_folding=True,
         input_names=['images'],
-        output_names=['output'],
-        dynamic_axes={
-            'images': {
-                0: 'batch',
-                2: 'height',
-                3: 'width'},  # shape(1,3,640,640)
-            'output': {
-                0: 'batch',
-                1: 'anchors'}  # shape(1,25200,85)
-        } if dynamic else None)
+        output_names=output_names,
+        dynamic_axes=dynamic or None)
 
     # Checks
     model_onnx = onnx.load(f)  # load onnx model
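
For reference, here is a minimal sketch (not part of the PR; the helper name is made up) of the output-name and dynamic-axes selection the new code performs, so the effect for segmentation vs. detection models is easy to see:

# Sketch only: mirrors the selection logic added above; build_dynamic_export_args
# is a hypothetical helper, not a function in export.py.
def build_dynamic_export_args(is_segmentation, dynamic):
    # Segmentation models export two heads: detections ('output0') and mask prototypes ('output1').
    output_names = ['output0', 'output1'] if is_segmentation else ['output0']
    axes = None
    if dynamic:
        axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
        axes['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
        if is_segmentation:
            axes['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
    return output_names, axes

# A dynamic segmentation export gets two named outputs, each with a dynamic batch axis.
names, axes = build_dynamic_export_args(is_segmentation=True, dynamic=True)
assert names == ['output0', 'output1'] and set(axes) == {'images', 'output0', 'output1'}

Keeping the detection case on a single 'output0' name is what lets the TensorRT wrapper below treat all outputs uniformly by name.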
models/common.py: 27 changes (16 additions, 11 deletions)

@@ -390,18 +390,21 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
                 model = runtime.deserialize_cuda_engine(f.read())
             context = model.create_execution_context()
             bindings = OrderedDict()
+            output_names = []
             fp16 = False  # default updated below
             dynamic = False
-            for index in range(model.num_bindings):
-                name = model.get_binding_name(index)
-                dtype = trt.nptype(model.get_binding_dtype(index))
-                if model.binding_is_input(index):
-                    if -1 in tuple(model.get_binding_shape(index)):  # dynamic
+            for i in range(model.num_bindings):
+                name = model.get_binding_name(i)
+                dtype = trt.nptype(model.get_binding_dtype(i))
+                if model.binding_is_input(i):
+                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic
                         dynamic = True
-                        context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2]))
+                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
                     if dtype == np.float16:
                         fp16 = True
-                shape = tuple(context.get_binding_shape(index))
+                else:  # output
+                    output_names.append(name)
+                shape = tuple(context.get_binding_shape(i))
                 im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
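
As an aside (not part of the PR), the same binding API used above can be used to inspect which outputs an engine actually exposes; the engine filename below is hypothetical:

# Sketch: print a (hypothetical) engine's bindings to see the names the loop
# above would collect into output_names.
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
with open('yolov5s-seg.engine', 'rb') as f, trt.Runtime(logger) as runtime:  # hypothetical path
    engine = runtime.deserialize_cuda_engine(f.read())
for i in range(engine.num_bindings):
    kind = 'input' if engine.binding_is_input(i) else 'output'
    print(i, kind, engine.get_binding_name(i), tuple(engine.get_binding_shape(i)))
# A segmentation engine is expected to show one input ('images') and two outputs
# ('output0', 'output1'), which is why output_names is now a list rather than a single name.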
@@ -495,15 +498,17 @@ def forward(self, im, augment=False, visualize=False):
             y = list(self.executable_network([im]).values())
         elif self.engine:  # TensorRT
             if self.dynamic and im.shape != self.bindings['images'].shape:
-                i_in, i_out = (self.model.get_binding_index(x) for x in ('images', 'output'))
-                self.context.set_binding_shape(i_in, im.shape)  # reshape if dynamic
+                i = self.model.get_binding_index('images')
+                self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
                 self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
-                self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out)))
+                for name in self.output_names:
+                    i = self.model.get_binding_index(name)
+                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
             s = self.bindings['images'].shape
             assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
             self.binding_addrs['images'] = int(im.data_ptr())
             self.context.execute_v2(list(self.binding_addrs.values()))
-            y = self.bindings['output'].data
+            y = [self.bindings[x].data for x in sorted(self.output_names)]
         elif self.coreml:  # CoreML
             im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
             im = Image.fromarray((im[0] * 255).astype('uint8'))
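
With this change the TensorRT branch returns a list of tensors (sorted by binding name) instead of a single tensor. A small, hypothetical sketch of how a caller could unpack that result, using the shapes from the export.py comments above:

# Illustrative only: unpack the list returned by the TensorRT branch above.
import torch

def split_outputs(y):
    # Segmentation engines yield [output0, output1] = [detections, mask prototypes];
    # detection engines yield a single tensor.
    if isinstance(y, (list, tuple)):
        return (y[0], None) if len(y) == 1 else (y[0], y[-1])
    return y, None

# Dummy tensors shaped like the comments in export.py above.
pred, proto = split_outputs([torch.zeros(1, 25200, 85), torch.zeros(1, 32, 160, 160)])
assert pred.shape == (1, 25200, 85) and proto.shape == (1, 32, 160, 160)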