diff --git a/models/common.py b/models/common.py index 2e5d5a198e33..5c82b18f102c 100644 --- a/models/common.py +++ b/models/common.py @@ -465,17 +465,15 @@ def forward(self, im, augment=False, visualize=False, val=False): if self.pt: # PyTorch y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im) - if isinstance(y, tuple): - y = y[0] elif self.jit: # TorchScript - y = self.model(im)[0] + y = self.model(im) elif self.dnn: # ONNX OpenCV DNN im = im.cpu().numpy() # torch to numpy self.net.setInput(im) y = self.net.forward() elif self.onnx: # ONNX Runtime im = im.cpu().numpy() # torch to numpy - y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})[0] + y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im}) elif self.xml: # OpenVINO im = im.cpu().numpy() # FP32 y = self.executable_network([im])[self.output_layer] @@ -522,6 +520,8 @@ def forward(self, im, augment=False, visualize=False, val=False): y = (y.astype(np.float32) - zero_point) * scale # re-scale y[..., :4] *= [w, h, w, h] # xywh normalized to pixels + if isinstance(y, (list, tuple)): + y = y[0] if isinstance(y, np.ndarray): y = torch.from_numpy(y).to(self.device) return (y, []) if val else y