From 2c49ecb3ed8959650eb72c18f5f99b33ed58c93c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 3 Apr 2022 22:51:11 +0200 Subject: [PATCH] TorchScript single-output fix (#7261) --- export.py | 18 ++++++++++++------ models/common.py | 7 ++++--- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/export.py b/export.py index 574bf8d9ed61..87be00376778 100644 --- a/export.py +++ b/export.py @@ -73,12 +73,18 @@ def export_formats(): # YOLOv5 export formats - x = [['PyTorch', '-', '.pt', True], ['TorchScript', 'torchscript', '.torchscript', True], - ['ONNX', 'onnx', '.onnx', True], ['OpenVINO', 'openvino', '_openvino_model', False], - ['TensorRT', 'engine', '.engine', True], ['CoreML', 'coreml', '.mlmodel', False], - ['TensorFlow SavedModel', 'saved_model', '_saved_model', True], ['TensorFlow GraphDef', 'pb', '.pb', True], - ['TensorFlow Lite', 'tflite', '.tflite', False], ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False], - ['TensorFlow.js', 'tfjs', '_web_model', False]] + x = [ + ['PyTorch', '-', '.pt', True], + ['TorchScript', 'torchscript', '.torchscript', True], + ['ONNX', 'onnx', '.onnx', True], + ['OpenVINO', 'openvino', '_openvino_model', False], + ['TensorRT', 'engine', '.engine', True], + ['CoreML', 'coreml', '.mlmodel', False], + ['TensorFlow SavedModel', 'saved_model', '_saved_model', True], + ['TensorFlow GraphDef', 'pb', '.pb', True], + ['TensorFlow Lite', 'tflite', '.tflite', False], + ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False], + ['TensorFlow.js', 'tfjs', '_web_model', False],] return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU']) diff --git a/models/common.py b/models/common.py index 8396caa1af5c..dcd3e5f408dd 100644 --- a/models/common.py +++ b/models/common.py @@ -406,9 +406,10 @@ def wrap_frozen_graph(gd, inputs, outputs): def forward(self, im, augment=False, visualize=False, val=False): # YOLOv5 MultiBackend inference b, ch, h, w = im.shape # batch, channel, height, width - if self.pt or self.jit: # PyTorch - y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize) - return y if val else y[0] + if self.pt: # PyTorch + y = self.model(im, augment=augment, visualize=visualize)[0] + elif self.jit: # TorchScript + y = self.model(im)[0] elif self.dnn: # ONNX OpenCV DNN im = im.cpu().numpy() # torch to numpy self.net.setInput(im)