From 3943d588d5a8f65715160c43dd65ffeb3284d20b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 4 Aug 2022 23:26:30 +0200 Subject: [PATCH] Fix TensorRT --dynamic excess outputs bug (#8869) * Fix TensorRT --dynamic excess outputs bug Potential fix for https://github.com/ultralytics/yolov5/issues/8790 * Cleanup * Update common.py * Update common.py * New fix --- models/common.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/models/common.py b/models/common.py index c898d94a921a..cfa688ba940b 100644 --- a/models/common.py +++ b/models/common.py @@ -387,13 +387,13 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, context = model.create_execution_context() bindings = OrderedDict() fp16 = False # default updated below - dynamic_input = False + dynamic = False for index in range(model.num_bindings): name = model.get_binding_name(index) dtype = trt.nptype(model.get_binding_dtype(index)) if model.binding_is_input(index): if -1 in tuple(model.get_binding_shape(index)): # dynamic - dynamic_input = True + dynamic = True context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2])) if dtype == np.float16: fp16 = True @@ -471,12 +471,14 @@ def forward(self, im, augment=False, visualize=False, val=False): im = im.cpu().numpy() # FP32 y = self.executable_network([im])[self.output_layer] elif self.engine: # TensorRT - if im.shape != self.bindings['images'].shape and self.dynamic_input: - self.context.set_binding_shape(self.model.get_binding_index('images'), im.shape) # reshape if dynamic + if self.dynamic and im.shape != self.bindings['images'].shape: + i_in, i_out = (self.model.get_binding_index(x) for x in ('images', 'output')) + self.context.set_binding_shape(i_in, im.shape) # reshape if dynamic self.bindings['images'] = self.bindings['images']._replace(shape=im.shape) - assert im.shape == self.bindings['images'].shape, ( - f"image shape {im.shape} exceeds model max shape {self.bindings['images'].shape}" if self.dynamic_input - else f"image shape {im.shape} does not match model shape {self.bindings['images'].shape}") + self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out))) + s = self.bindings['images'].shape + assert im.shape == s, f"image shape {im.shape} " + \ + f"exceeds model max shape {s}" if self.dynamic else f"does not match model shape {s}" self.binding_addrs['images'] = int(im.data_ptr()) self.context.execute_v2(list(self.binding_addrs.values())) y = self.bindings['output'].data