Skip to content

Commit

Permalink
Update DetectMultiBackend for tuple outputs (ultralytics#9274)
Browse files Browse the repository at this point in the history
Update
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 59b9ad1 commit 20ce1c9
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,17 +465,15 @@ def forward(self, im, augment=False, visualize=False, val=False):

if self.pt: # PyTorch
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
if isinstance(y, tuple):
y = y[0]
elif self.jit: # TorchScript
y = self.model(im)[0]
y = self.model(im)
elif self.dnn: # ONNX OpenCV DNN
im = im.cpu().numpy() # torch to numpy
self.net.setInput(im)
y = self.net.forward()
elif self.onnx: # ONNX Runtime
im = im.cpu().numpy() # torch to numpy
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})[0]
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
elif self.xml: # OpenVINO
im = im.cpu().numpy() # FP32
y = self.executable_network([im])[self.output_layer]
Expand Down Expand Up @@ -522,6 +520,8 @@ def forward(self, im, augment=False, visualize=False, val=False):
y = (y.astype(np.float32) - zero_point) * scale # re-scale
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels

if isinstance(y, (list, tuple)):
y = y[0]
if isinstance(y, np.ndarray):
y = torch.from_numpy(y).to(self.device)
return (y, []) if val else y
Expand Down

0 comments on commit 20ce1c9

Please sign in to comment.