YOLOv5 Export Benchmarks for GPU #6963

Merged · 14 commits · Mar 14, 2022
export.py: 24 changes (12 additions, 12 deletions)
@@ -75,18 +75,18 @@

 def export_formats():
     # YOLOv5 export formats
-    x = [['PyTorch', '-', '.pt'],
-         ['TorchScript', 'torchscript', '.torchscript'],
-         ['ONNX', 'onnx', '.onnx'],
-         ['OpenVINO', 'openvino', '_openvino_model'],
-         ['TensorRT', 'engine', '.engine'],
-         ['CoreML', 'coreml', '.mlmodel'],
-         ['TensorFlow SavedModel', 'saved_model', '_saved_model'],
-         ['TensorFlow GraphDef', 'pb', '.pb'],
-         ['TensorFlow Lite', 'tflite', '.tflite'],
-         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'],
-         ['TensorFlow.js', 'tfjs', '_web_model']]
-    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix'])
+    x = [['PyTorch', '-', '.pt', True],
+         ['TorchScript', 'torchscript', '.torchscript', True],
+         ['ONNX', 'onnx', '.onnx', True],
+         ['OpenVINO', 'openvino', '_openvino_model', False],
+         ['TensorRT', 'engine', '.engine', True],
+         ['CoreML', 'coreml', '.mlmodel', False],
+         ['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
+         ['TensorFlow GraphDef', 'pb', '.pb', True],
+         ['TensorFlow Lite', 'tflite', '.tflite', False],
+         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
+         ['TensorFlow.js', 'tfjs', '_web_model', False]]
+    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])


 def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
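The new boolean GPU column makes each format's GPU capability queryable from the returned DataFrame. As a minimal sketch (the filtering below is illustrative, not part of this PR), downstream code could select only GPU-capable formats before benchmarking:

```python
from export import export_formats  # assumes the yolov5 repo root is on sys.path

formats = export_formats()
gpu_formats = formats[formats['GPU']]  # keep only rows where GPU inference is supported
print(gpu_formats[['Format', 'Argument']])
# The 'Argument' values (e.g. 'onnx', 'engine') are what export.run(include=[...]) expects.
```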
models/common.py: 7 changes (4 additions, 3 deletions)
@@ -464,10 +464,11 @@ def forward(self, im, augment=False, visualize=False, val=False):

     def warmup(self, imgsz=(1, 3, 640, 640)):
         # Warmup model by running inference once
-        if self.pt or self.jit or self.onnx or self.engine:  # warmup types
-            if isinstance(self.device, torch.device) and self.device.type != 'cpu':  # only warmup GPU models
+        if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)):  # warmup types
+            if self.device.type != 'cpu':  # only warmup GPU models
                 im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
-                self.forward(im)  # warmup
+                for _ in range(2 if self.jit else 1):  # TorchScript models need an extra pass to settle
+                    self.forward(im)  # warmup

     @staticmethod
     def model_type(p='path/to/model.pt'):
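For context, warmup() is run once before timed inference so that first-run graph optimizations do not pollute the measurements. A usage sketch, assuming a CUDA device and a local yolov5s.pt checkpoint:

```python
import torch

from models.common import DetectMultiBackend

device = torch.device('cuda:0')  # warmup is intentionally skipped on CPU
model = DetectMultiBackend('yolov5s.pt', device=device)
model.warmup(imgsz=(1, 3, 640, 640))  # one dummy forward pass (two for TorchScript)
im = torch.zeros(1, 3, 640, 640, device=device)
pred = model(im)  # subsequent calls no longer pay the first-run overhead
```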
utils/benchmarks.py: 18 changes (15 additions, 3 deletions)
@@ -19,6 +19,7 @@
 Requirements:
     $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu  # CPU
     $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow  # GPU
+    $ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com  # TensorRT

 Usage:
     $ python utils/benchmarks.py --weights yolov5s.pt --img 640
@@ -41,20 +42,29 @@
 import val
 from utils import notebook_init
 from utils.general import LOGGER, print_args
+from utils.torch_utils import select_device


 def run(weights=ROOT / 'yolov5s.pt',  # weights path
         imgsz=640,  # inference size (pixels)
         batch_size=1,  # batch size
         data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        half=False,  # use FP16 half-precision inference
         ):
     y, t = [], time.time()
     formats = export.export_formats()
-    for i, (name, f, suffix) in formats.iterrows():  # index, (name, file, suffix)
+    device = select_device(device)
+    for i, (name, f, suffix, gpu) in formats.iterrows():  # index, (name, file, suffix, gpu-capable)
         try:
-            w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1]
+            if device.type != 'cpu':
+                assert gpu, f'{name} inference not supported on GPU'
+            if f == '-':
+                w = weights  # PyTorch format
+            else:
+                w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1]  # all others
             assert suffix in str(w), 'export failed'
-            result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device='cpu', task='benchmark')
+            result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half)
             metrics = result[0]  # metrics (mp, mr, map50, map, *losses(box, obj, cls))
             speeds = result[2]  # times (preprocess, inference, postprocess)
             y.append([name, metrics[3], speeds[1]])  # mAP, t_inference
@@ -78,6 +88,8 @@ def parse_opt():
     parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
     parser.add_argument('--batch-size', type=int, default=1, help='batch size')
     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
     opt = parser.parse_args()
     print_args(FILE.stem, opt)
     return opt
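Taken together, a GPU benchmark run could look like the sketch below. The device index, the FP16 flag, and the programmatic call are illustrative; timings depend entirely on hardware:

```python
# Equivalent CLI: $ python utils/benchmarks.py --weights yolov5s.pt --img 640 --device 0 --half
from utils.benchmarks import run  # assumes the yolov5 repo root is on sys.path

run(weights='yolov5s.pt',  # checkpoint to export and benchmark
    imgsz=640,             # inference size (pixels)
    batch_size=1,
    device='0',            # CUDA device 0; use '' or 'cpu' for the CPU-only path
    half=True)             # FP16 inference for GPU-capable formats
```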