From f65725ffc17195b16674abb61e77535cf871e1a1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Mar 2022 19:01:02 +0100 Subject: [PATCH 01/12] Add benchmarks.py GPU support --- utils/benchmarks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/utils/benchmarks.py b/utils/benchmarks.py index 962df812a9d3..acb63aed2d6e 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -41,15 +41,18 @@ import val from utils import notebook_init from utils.general import LOGGER, print_args +from utils.torch_utils import select_device def run(weights=ROOT / 'yolov5s.pt', # weights path imgsz=640, # inference size (pixels) batch_size=1, # batch size data=ROOT / 'data/coco128.yaml', # dataset.yaml path + device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu ): y, t = [], time.time() formats = export.export_formats() + device = select_device(device) for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix) try: w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1] @@ -78,6 +81,7 @@ def parse_opt(): parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--batch-size', type=int, default=1, help='batch size') parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') opt = parser.parse_args() print_args(FILE.stem, opt) return opt From ce0bce8f4c42f40ac7a3ba4a0be4969477e6fbd9 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Mar 2022 19:07:07 +0100 Subject: [PATCH 02/12] Updates --- utils/benchmarks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/benchmarks.py b/utils/benchmarks.py index acb63aed2d6e..b0d7d15ede69 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -55,9 +55,9 @@ def run(weights=ROOT / 'yolov5s.pt', # weights path device = select_device(device) for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix) try: - w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1] + w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device=device)[-1] assert suffix in str(w), 'export failed' - result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device='cpu', task='benchmark') + result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device=device, task='benchmark') metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls)) speeds = result[2] # times (preprocess, inference, postprocess) y.append([name, metrics[3], speeds[1]]) # mAP, t_inference From e34ff3c0064c1f94019fe397fed8d7fa1dd3b994 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Mar 2022 19:14:51 +0100 Subject: [PATCH 03/12] Updates --- utils/benchmarks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/benchmarks.py b/utils/benchmarks.py index b0d7d15ede69..acef79e8a756 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -57,7 +57,7 @@ def run(weights=ROOT / 'yolov5s.pt', # weights path try: w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device=device)[-1] assert suffix in str(w), 'export failed' - result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device=device, task='benchmark') + result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device=device, task='benchmark', half=False) metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls)) speeds = result[2] # times (preprocess, inference, postprocess) y.append([name, metrics[3], speeds[1]]) # mAP, t_inference From 34335d5bfefda81e221bd923acb9a3edc1fd1b52 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Mar 2022 19:18:30 +0100 Subject: [PATCH 04/12] Updates --- models/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/common.py b/models/common.py index 83aecb7569d6..3c7d4f6283a2 100644 --- a/models/common.py +++ b/models/common.py @@ -465,7 +465,7 @@ def forward(self, im, augment=False, visualize=False, val=False): def warmup(self, imgsz=(1, 3, 640, 640)): # Warmup model by running inference once if self.pt or self.jit or self.onnx or self.engine: # warmup types - if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models + if self.device.type != 'cpu': # only warmup GPU models im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input self.forward(im) # warmup From 73dec99accc9a3a45e06f446409f677f867cd28b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Mar 2022 19:22:08 +0100 Subject: [PATCH 05/12] Updates --- models/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/common.py b/models/common.py index 3c7d4f6283a2..03dd919e6ba4 100644 --- a/models/common.py +++ b/models/common.py @@ -467,7 +467,8 @@ def warmup(self, imgsz=(1, 3, 640, 640)): if self.pt or self.jit or self.onnx or self.engine: # warmup types if self.device.type != 'cpu': # only warmup GPU models im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input - self.forward(im) # warmup + for _ in range(2 if self.jit else 1): # + self.forward(im) # warmup @staticmethod def model_type(p='path/to/model.pt'): From f08e222c2711d217c01b9993b3bc8b20530f3f60 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Mar 2022 19:34:52 +0100 Subject: [PATCH 06/12] Add --half --- utils/benchmarks.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/utils/benchmarks.py b/utils/benchmarks.py index acef79e8a756..aa1dde710deb 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -49,15 +49,19 @@ def run(weights=ROOT / 'yolov5s.pt', # weights path batch_size=1, # batch size data=ROOT / 'data/coco128.yaml', # dataset.yaml path device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu + half=False, # use FP16 half-precision inference ): y, t = [], time.time() formats = export.export_formats() device = select_device(device) for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix) try: - w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device=device)[-1] + if f == '-': # PyTorch + w = weights + else: + w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] assert suffix in str(w), 'export failed' - result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device=device, task='benchmark', half=False) + result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half) metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls)) speeds = result[2] # times (preprocess, inference, postprocess) y.append([name, metrics[3], speeds[1]]) # mAP, t_inference @@ -82,6 +86,7 @@ def parse_opt(): parser.add_argument('--batch-size', type=int, default=1, help='batch size') parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') opt = parser.parse_args() print_args(FILE.stem, opt) return opt From ab2b1c05ebc4c8b26b2452f549a5481e9d8b8121 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Mar 2022 20:02:26 +0100 Subject: [PATCH 07/12] Add TRT requirements --- utils/benchmarks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/benchmarks.py b/utils/benchmarks.py index aa1dde710deb..1679a2815cf8 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -19,6 +19,7 @@ Requirements: $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU + $ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com # TensorRT Usage: $ python utils/benchmarks.py --weights yolov5s.pt --img 640 From 169f39abb0a8dce943097cdf95bdcafda560df4c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 14 Mar 2022 11:57:55 +0100 Subject: [PATCH 08/12] Cleanup --- utils/benchmarks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/benchmarks.py b/utils/benchmarks.py index 1679a2815cf8..ecbc32388475 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -57,10 +57,10 @@ def run(weights=ROOT / 'yolov5s.pt', # weights path device = select_device(device) for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix) try: - if f == '-': # PyTorch - w = weights + if f == '-': + w = weights # PyTorch format else: - w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] + w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others assert suffix in str(w), 'export failed' result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half) metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls)) From 412671924e56b754d071201c46aff423f6519810 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 14 Mar 2022 12:29:58 +0100 Subject: [PATCH 09/12] Add TF to warmup types --- models/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/common.py b/models/common.py index 03dd919e6ba4..4ad040fcd7f1 100644 --- a/models/common.py +++ b/models/common.py @@ -464,7 +464,7 @@ def forward(self, im, augment=False, visualize=False, val=False): def warmup(self, imgsz=(1, 3, 640, 640)): # Warmup model by running inference once - if self.pt or self.jit or self.onnx or self.engine: # warmup types + if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types if self.device.type != 'cpu': # only warmup GPU models im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input for _ in range(2 if self.jit else 1): # From 8a07c06dc0be1c0f18a8c529dd69fb478f346e78 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 14 Mar 2022 13:06:40 +0100 Subject: [PATCH 10/12] Update export.py --- export.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/export.py b/export.py index c50de15cf0b8..01d894ee4aae 100644 --- a/export.py +++ b/export.py @@ -75,18 +75,18 @@ def export_formats(): # YOLOv5 export formats - x = [['PyTorch', '-', '.pt'], - ['TorchScript', 'torchscript', '.torchscript'], - ['ONNX', 'onnx', '.onnx'], - ['OpenVINO', 'openvino', '_openvino_model'], - ['TensorRT', 'engine', '.engine'], - ['CoreML', 'coreml', '.mlmodel'], - ['TensorFlow SavedModel', 'saved_model', '_saved_model'], - ['TensorFlow GraphDef', 'pb', '.pb'], - ['TensorFlow Lite', 'tflite', '.tflite'], - ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'], - ['TensorFlow.js', 'tfjs', '_web_model']] - return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix']) + x = [['PyTorch', '-', '.pt', True], + ['TorchScript', 'torchscript', '.torchscript', True], + ['ONNX', 'onnx', '.onnx', True], + ['OpenVINO', 'openvino', '_openvino_model', False], + ['TensorRT', 'engine', '.engine', True], + ['CoreML', 'coreml', '.mlmodel', True], + ['TensorFlow SavedModel', 'saved_model', '_saved_model', True], + ['TensorFlow GraphDef', 'pb', '.pb', True], + ['TensorFlow Lite', 'tflite', '.tflite', False], + ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False], + ['TensorFlow.js', 'tfjs', '_web_model', False]] + return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU']) def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')): From 67093bfddde890fc59e37e84db4d9e63ac4b7ae5 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 14 Mar 2022 13:32:47 +0100 Subject: [PATCH 11/12] Update export.py --- export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/export.py b/export.py index 01d894ee4aae..d4f980fdb993 100644 --- a/export.py +++ b/export.py @@ -80,7 +80,7 @@ def export_formats(): ['ONNX', 'onnx', '.onnx', True], ['OpenVINO', 'openvino', '_openvino_model', False], ['TensorRT', 'engine', '.engine', True], - ['CoreML', 'coreml', '.mlmodel', True], + ['CoreML', 'coreml', '.mlmodel', False], ['TensorFlow SavedModel', 'saved_model', '_saved_model', True], ['TensorFlow GraphDef', 'pb', '.pb', True], ['TensorFlow Lite', 'tflite', '.tflite', False], From 0c1025f608756917d294831143a96ed4d255c313 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 14 Mar 2022 13:49:41 +0100 Subject: [PATCH 12/12] Update benchmarks.py --- utils/benchmarks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/utils/benchmarks.py b/utils/benchmarks.py index ecbc32388475..bdbbdc43b639 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -55,8 +55,10 @@ def run(weights=ROOT / 'yolov5s.pt', # weights path y, t = [], time.time() formats = export.export_formats() device = select_device(device) - for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix) + for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable) try: + if device.type != 'cpu': + assert gpu, f'{name} inference not supported on GPU' if f == '-': w = weights # PyTorch format else: