From 932dc78496ca532a41780335468589ad7f0147f7 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 14 Mar 2022 15:07:13 +0100 Subject: [PATCH] YOLOv5 Export Benchmarks for GPU (#6963) * Add benchmarks.py GPU support * Updates * Updates * Updates * Updates * Add --half * Add TRT requirements * Cleanup * Add TF to warmup types * Update export.py * Update export.py * Update benchmarks.py --- export.py | 24 ++++++++++++------------ models/common.py | 7 ++++--- utils/benchmarks.py | 18 +++++++++++++++--- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/export.py b/export.py index c50de15cf0b8..d4f980fdb993 100644 --- a/export.py +++ b/export.py @@ -75,18 +75,18 @@ def export_formats(): # YOLOv5 export formats - x = [['PyTorch', '-', '.pt'], - ['TorchScript', 'torchscript', '.torchscript'], - ['ONNX', 'onnx', '.onnx'], - ['OpenVINO', 'openvino', '_openvino_model'], - ['TensorRT', 'engine', '.engine'], - ['CoreML', 'coreml', '.mlmodel'], - ['TensorFlow SavedModel', 'saved_model', '_saved_model'], - ['TensorFlow GraphDef', 'pb', '.pb'], - ['TensorFlow Lite', 'tflite', '.tflite'], - ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'], - ['TensorFlow.js', 'tfjs', '_web_model']] - return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix']) + x = [['PyTorch', '-', '.pt', True], + ['TorchScript', 'torchscript', '.torchscript', True], + ['ONNX', 'onnx', '.onnx', True], + ['OpenVINO', 'openvino', '_openvino_model', False], + ['TensorRT', 'engine', '.engine', True], + ['CoreML', 'coreml', '.mlmodel', False], + ['TensorFlow SavedModel', 'saved_model', '_saved_model', True], + ['TensorFlow GraphDef', 'pb', '.pb', True], + ['TensorFlow Lite', 'tflite', '.tflite', False], + ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False], + ['TensorFlow.js', 'tfjs', '_web_model', False]] + return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU']) def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')): diff --git a/models/common.py b/models/common.py index 83aecb7569d6..4ad040fcd7f1 100644 --- a/models/common.py +++ b/models/common.py @@ -464,10 +464,11 @@ def forward(self, im, augment=False, visualize=False, val=False): def warmup(self, imgsz=(1, 3, 640, 640)): # Warmup model by running inference once - if self.pt or self.jit or self.onnx or self.engine: # warmup types - if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models + if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types + if self.device.type != 'cpu': # only warmup GPU models im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input - self.forward(im) # warmup + for _ in range(2 if self.jit else 1): # + self.forward(im) # warmup @staticmethod def model_type(p='path/to/model.pt'): diff --git a/utils/benchmarks.py b/utils/benchmarks.py index 962df812a9d3..bdbbdc43b639 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -19,6 +19,7 @@ Requirements: $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU + $ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com # TensorRT Usage: $ python utils/benchmarks.py --weights yolov5s.pt --img 640 @@ -41,20 +42,29 @@ import val from utils import notebook_init from utils.general import LOGGER, print_args +from utils.torch_utils import select_device def run(weights=ROOT / 'yolov5s.pt', # weights path imgsz=640, # inference size (pixels) batch_size=1, # batch size data=ROOT / 'data/coco128.yaml', # dataset.yaml path + device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu + half=False, # use FP16 half-precision inference ): y, t = [], time.time() formats = export.export_formats() - for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix) + device = select_device(device) + for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable) try: - w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1] + if device.type != 'cpu': + assert gpu, f'{name} inference not supported on GPU' + if f == '-': + w = weights # PyTorch format + else: + w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others assert suffix in str(w), 'export failed' - result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device='cpu', task='benchmark') + result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half) metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls)) speeds = result[2] # times (preprocess, inference, postprocess) y.append([name, metrics[3], speeds[1]]) # mAP, t_inference @@ -78,6 +88,8 @@ def parse_opt(): parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--batch-size', type=int, default=1, help='batch size') parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') opt = parser.parse_args() print_args(FILE.stem, opt) return opt