Skip to content

Commit

Permalink
Fix TensorRT --dynamic excess outputs bug (ultralytics#8869)
Browse files Browse the repository at this point in the history
* Fix TensorRT --dynamic excess outputs bug

Potential fix for ultralytics#8790

* Cleanup

* Update common.py

* Update common.py

* New fix
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent f944e91 commit 3943d58
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,13 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
context = model.create_execution_context()
bindings = OrderedDict()
fp16 = False # default updated below
dynamic_input = False
dynamic = False
for index in range(model.num_bindings):
name = model.get_binding_name(index)
dtype = trt.nptype(model.get_binding_dtype(index))
if model.binding_is_input(index):
if -1 in tuple(model.get_binding_shape(index)): # dynamic
dynamic_input = True
dynamic = True
context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2]))
if dtype == np.float16:
fp16 = True
Expand Down Expand Up @@ -471,12 +471,14 @@ def forward(self, im, augment=False, visualize=False, val=False):
im = im.cpu().numpy() # FP32
y = self.executable_network([im])[self.output_layer]
elif self.engine: # TensorRT
if im.shape != self.bindings['images'].shape and self.dynamic_input:
self.context.set_binding_shape(self.model.get_binding_index('images'), im.shape) # reshape if dynamic
if self.dynamic and im.shape != self.bindings['images'].shape:
i_in, i_out = (self.model.get_binding_index(x) for x in ('images', 'output'))
self.context.set_binding_shape(i_in, im.shape) # reshape if dynamic
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
assert im.shape == self.bindings['images'].shape, (
f"image shape {im.shape} exceeds model max shape {self.bindings['images'].shape}" if self.dynamic_input
else f"image shape {im.shape} does not match model shape {self.bindings['images'].shape}")
self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out)))
s = self.bindings['images'].shape
assert im.shape == s, f"image shape {im.shape} " + \
f"exceeds model max shape {s}" if self.dynamic else f"does not match model shape {s}"
self.binding_addrs['images'] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = self.bindings['output'].data
Expand Down

0 comments on commit 3943d58

Please sign in to comment.