From aa1562f9e044a86388776c7604ca0eb6a592c194 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 11 Mar 2022 13:16:34 +0100
Subject: [PATCH 1/7] DetectMultiBackend() `--half` handling

---
 detect.py        | 15 +++------------
 models/common.py | 12 +++++++-----
 val.py           | 16 ++++------------
 3 files changed, 14 insertions(+), 29 deletions(-)

diff --git a/detect.py b/detect.py
index ba43ed9e1eed..06c23ea342f1 100644
--- a/detect.py
+++ b/detect.py
@@ -89,19 +89,10 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
 
     # Load model
     device = select_device(device)
-    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
+    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, half=half)
     stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
     imgsz = check_img_size(imgsz, s=stride)  # check image size
 
-    # Half
-    half &= (pt or jit or onnx or engine) and device.type != 'cpu'  # FP16 supported on limited backends with CUDA
-    if pt or jit:
-        model.model.half() if half else model.model.float()
-    elif engine and model.trt_fp16_input != half:
-        LOGGER.info('model ' + (
-            'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.')
-        half = model.trt_fp16_input
-
     # Dataloader
     if webcam:
         view_img = check_imshow()
@@ -114,12 +105,12 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
         vid_path, vid_writer = [None] * bs, [None] * bs
 
     # Run inference
-    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz), half=half)  # warmup
+    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
     dt, seen = [0.0, 0.0, 0.0], 0
     for path, im, im0s, vid_cap, s in dataset:
         t1 = time_sync()
         im = torch.from_numpy(im).to(device)
-        im = im.half() if half else im.float()  # uint8 to fp16/32
+        im = im.half() if model.half else im.float()  # uint8 to fp16/32
         im /= 255  # 0 - 255 to 0.0 - 1.0
         if len(im.shape) == 3:
             im = im[None]  # expand for batch dim
diff --git a/models/common.py b/models/common.py
index 70ee7105abfc..66934edde408 100644
--- a/models/common.py
+++ b/models/common.py
@@ -277,7 +277,7 @@ def forward(self, x):
 
 class DetectMultiBackend(nn.Module):
     # YOLOv5 MultiBackend class for python inference on various backends
-    def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
+    def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None, half=False):
         # Usage:
         #   PyTorch:              weights = *.pt
         #   TorchScript:                    *.torchscript
@@ -297,6 +297,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w)  # get backend
         stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
         w = attempt_download(w)  # download if not local
+        half &= (pt or jit or onnx or engine) and isinstance(device, torch.device) and device.type != 'cpu'  # FP16
         if data:  # data.yaml path (optional)
             with open(data, errors='ignore') as f:
                 names = yaml.safe_load(f)['names']  # class names
@@ -305,11 +306,13 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
             model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
             stride = max(int(model.stride.max()), 32)  # model stride
             names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+            model.half() if half else model.float()
             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
         elif jit:  # TorchScript
             LOGGER.info(f'Loading {w} for TorchScript inference...')
             extra_files = {'config.txt': ''}  # model metadata
             model = torch.jit.load(w, _extra_files=extra_files)
+            model.half() if half else model.float()
             if extra_files['config.txt']:
                 d = json.loads(extra_files['config.txt'])  # extra_files dict
                 stride, names = int(d['stride']), d['names']
@@ -338,7 +341,6 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
             import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
             check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
             Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
-            trt_fp16_input = False
             logger = trt.Logger(trt.Logger.INFO)
             with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                 model = runtime.deserialize_cuda_engine(f.read())
@@ -350,7 +352,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
                 data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
                 bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
                 if model.binding_is_input(index) and dtype == np.float16:
-                    trt_fp16_input = True
+                    half = True
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
             context = model.create_execution_context()
             batch_size = bindings['images'].shape[0]
@@ -458,11 +460,11 @@ def forward(self, im, augment=False, visualize=False, val=False):
             y = torch.tensor(y) if isinstance(y, np.ndarray) else y
         return (y, []) if val else y
 
-    def warmup(self, imgsz=(1, 3, 640, 640), half=False):
+    def warmup(self, imgsz=(1, 3, 640, 640)):
         # Warmup model by running inference once
         if self.pt or self.jit or self.onnx or self.engine:  # warmup types
             if isinstance(self.device, torch.device) and self.device.type != 'cpu':  # only warmup GPU models
-                im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float)  # input image
+                im = torch.zeros(*imgsz).to(self.device).type(torch.half if self.half else torch.float)  # input image
                 self.forward(im)  # warmup
 
     @staticmethod
diff --git a/val.py b/val.py
index dfbfa3935210..f57f3df2727f 100644
--- a/val.py
+++ b/val.py
@@ -125,7 +125,6 @@ def run(data,
     training = model is not None
     if training:  # called by train.py
         device, pt, jit, engine = next(model.parameters()).device, True, False, False  # get model device, PyTorch model
-        half &= device.type != 'cpu'  # half precision only supported on CUDA
         model.half() if half else model.float()
 
     else:  # called directly
@@ -136,20 +135,13 @@ def run(data,
         (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
 
         # Load model
-        model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
+        model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, half=half)
         stride, pt, jit, onnx, engine = model.stride, model.pt, model.jit, model.onnx, model.engine
         imgsz = check_img_size(imgsz, s=stride)  # check image size
-        half &= (pt or jit or onnx or engine) and device.type != 'cpu'  # FP16 supported on limited backends with CUDA
-        if pt or jit:
-            model.model.half() if half else model.model.float()
-        elif engine:
+        half = model.half  # FP16 supported on limited backends with CUDA
+        if engine:
             batch_size = model.batch_size
-            if model.trt_fp16_input != half:
-                LOGGER.info('model ' + (
-                    'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.')
-                half = model.trt_fp16_input
         else:
-            half = False
             batch_size = 1  # export.py models default to batch-size 1
             device = torch.device('cpu')
             LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')
@@ -166,7 +158,7 @@ def run(data,
 
     # Dataloader
     if not training:
-        model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz), half=half)  # warmup
+        model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz))  # warmup
         pad = 0.0 if task in ('speed', 'benchmark') else 0.5
         rect = False if task == 'benchmark' else pt  # square inference for benchmarks
         task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images

From b84a56d6e77b6ddd9d4c88d0167f58d63ea237a3 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 11 Mar 2022 13:29:09 +0100
Subject: [PATCH 2/7] CI fixes

---
 detect.py | 2 +-
 val.py    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/detect.py b/detect.py
index 06c23ea342f1..3eb56c3f565a 100644
--- a/detect.py
+++ b/detect.py
@@ -90,7 +90,7 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
     # Load model
     device = select_device(device)
     model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, half=half)
-    stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
+    stride, names, pt = model.stride, model.names, model.pt
     imgsz = check_img_size(imgsz, s=stride)  # check image size
 
     # Dataloader
diff --git a/val.py b/val.py
index f57f3df2727f..36ac8d80f4da 100644
--- a/val.py
+++ b/val.py
@@ -136,7 +136,7 @@ def run(data,
 
         # Load model
         model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, half=half)
-        stride, pt, jit, onnx, engine = model.stride, model.pt, model.jit, model.onnx, model.engine
+        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
         imgsz = check_img_size(imgsz, s=stride)  # check image size
         half = model.half  # FP16 supported on limited backends with CUDA
         if engine:

From a2eadf88d4081de23a7063004f83f868a20fd157 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 11 Mar 2022 13:44:54 +0100
Subject: [PATCH 3/7] rename .half to .fp16 to avoid conflict

---
 detect.py        |  4 ++--
 models/common.py | 10 +++++-----
 val.py           |  4 ++--
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/detect.py b/detect.py
index 3eb56c3f565a..ccb9fbf5103f 100644
--- a/detect.py
+++ b/detect.py
@@ -89,7 +89,7 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
 
     # Load model
     device = select_device(device)
-    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, half=half)
+    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
     stride, names, pt = model.stride, model.names, model.pt
     imgsz = check_img_size(imgsz, s=stride)  # check image size
 
@@ -110,7 +110,7 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
     for path, im, im0s, vid_cap, s in dataset:
         t1 = time_sync()
         im = torch.from_numpy(im).to(device)
-        im = im.half() if model.half else im.float()  # uint8 to fp16/32
+        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
         im /= 255  # 0 - 255 to 0.0 - 1.0
         if len(im.shape) == 3:
             im = im[None]  # expand for batch dim
diff --git a/models/common.py b/models/common.py
index 66934edde408..004e81027464 100644
--- a/models/common.py
+++ b/models/common.py
@@ -277,7 +277,7 @@ def forward(self, x):
 
 class DetectMultiBackend(nn.Module):
     # YOLOv5 MultiBackend class for python inference on various backends
-    def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None, half=False):
+    def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
         # Usage:
         #   PyTorch:              weights = *.pt
         #   TorchScript:                    *.torchscript
@@ -297,7 +297,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None, half
         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w)  # get backend
         stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
         w = attempt_download(w)  # download if not local
-        half &= (pt or jit or onnx or engine) and isinstance(device, torch.device) and device.type != 'cpu'  # FP16
+        fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu'  # FP16
         if data:  # data.yaml path (optional)
             with open(data, errors='ignore') as f:
                 names = yaml.safe_load(f)['names']  # class names
@@ -306,13 +306,13 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None, half
             model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
             stride = max(int(model.stride.max()), 32)  # model stride
             names = model.module.names if hasattr(model, 'module') else model.names  # get class names
-            model.half() if half else model.float()
+            model.half() if fp16 else model.float()
             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
         elif jit:  # TorchScript
             LOGGER.info(f'Loading {w} for TorchScript inference...')
             extra_files = {'config.txt': ''}  # model metadata
             model = torch.jit.load(w, _extra_files=extra_files)
-            model.half() if half else model.float()
+            model.half() if fp16 else model.float()
             if extra_files['config.txt']:
                 d = json.loads(extra_files['config.txt'])  # extra_files dict
                 stride, names = int(d['stride']), d['names']
@@ -352,7 +352,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None, half
                 data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
                 bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
                 if model.binding_is_input(index) and dtype == np.float16:
-                    half = True
+                    fp16 = True
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
             context = model.create_execution_context()
             batch_size = bindings['images'].shape[0]
diff --git a/val.py b/val.py
index 36ac8d80f4da..814ddb661bc0 100644
--- a/val.py
+++ b/val.py
@@ -135,10 +135,10 @@ def run(data,
         (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
 
         # Load model
-        model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, half=half)
+        model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
         stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
         imgsz = check_img_size(imgsz, s=stride)  # check image size
-        half = model.half  # FP16 supported on limited backends with CUDA
+        half = model.fp16  # FP16 supported on limited backends with CUDA
         if engine:
             batch_size = model.batch_size
         else:

From 38c62e12ccb00edf4217ef01c12b78fe1a1aeb43 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 11 Mar 2022 13:46:21 +0100
Subject: [PATCH 4/7] warmup fix

---
 models/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/common.py b/models/common.py
index 004e81027464..064f358ca30e 100644
--- a/models/common.py
+++ b/models/common.py
@@ -464,7 +464,7 @@ def warmup(self, imgsz=(1, 3, 640, 640)):
         # Warmup model by running inference once
         if self.pt or self.jit or self.onnx or self.engine:  # warmup types
             if isinstance(self.device, torch.device) and self.device.type != 'cpu':  # only warmup GPU models
-                im = torch.zeros(*imgsz).to(self.device).type(torch.half if self.half else torch.float)  # input image
+                im = torch.zeros(*imgsz).to(self.device).type(torch.half if self.fp16 else torch.float)  # input image
                 self.forward(im)  # warmup
 
     @staticmethod

From 81baba074e88fa16c99d7c4963d12c19d83e4ef0 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 11 Mar 2022 13:51:51 +0100
Subject: [PATCH 5/7] val update

---
 val.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/val.py b/val.py
index 814ddb661bc0..64c4d4ff9dae 100644
--- a/val.py
+++ b/val.py
@@ -142,9 +142,10 @@ def run(data,
         if engine:
             batch_size = model.batch_size
         else:
-            batch_size = 1  # export.py models default to batch-size 1
-            device = torch.device('cpu')
-            LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')
+            device = model.device
+            if not pt or jit:
+                batch_size = 1  # export.py models default to batch-size 1
+                LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
 
     # Data
     data = check_dataset(data)  # check

From 6eb74043a3b5e30552e1ab5826c9a09999f15007 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 11 Mar 2022 14:16:25 +0100
Subject: [PATCH 6/7] engine update

---
 models/common.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/models/common.py b/models/common.py
index 064f358ca30e..dd2bc57e99fd 100644
--- a/models/common.py
+++ b/models/common.py
@@ -341,6 +341,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
             import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
             check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
             Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
+            fp16 = False  # default updated below
             logger = trt.Logger(trt.Logger.INFO)
             with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                 model = runtime.deserialize_cuda_engine(f.read())

From 604a592c9e92ef10fb4ca11b7037b5b6bb3377e5 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 11 Mar 2022 14:16:45 +0100
Subject: [PATCH 7/7] engine update

---
 models/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/common.py b/models/common.py
index dd2bc57e99fd..c9f41c720e22 100644
--- a/models/common.py
+++ b/models/common.py
@@ -341,11 +341,11 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
             import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
             check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
             Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
-            fp16 = False  # default updated below
             logger = trt.Logger(trt.Logger.INFO)
             with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                 model = runtime.deserialize_cuda_engine(f.read())
             bindings = OrderedDict()
+            fp16 = False  # default updated below
             for index in range(model.num_bindings):
                 name = model.get_binding_name(index)
                 dtype = trt.nptype(model.get_binding_dtype(index))
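
Usage after this series (a minimal sketch, not part of the patches: it mirrors the updated detect.py and val.py call sites; the weights file, device string and image size are illustrative, and it assumes the YOLOv5 repo root is on sys.path):

    import torch

    from models.common import DetectMultiBackend
    from utils.torch_utils import select_device

    device = select_device('0')  # FP16 is only honored on non-CPU devices
    model = DetectMultiBackend('yolov5s.pt', device=device, fp16=True)  # 'half=' before the PATCH 3/7 rename
    model.warmup(imgsz=(1, 3, 640, 640))  # warmup no longer takes 'half='; it reads model.fp16
    im = torch.zeros(1, 3, 640, 640).to(device)  # dummy batch, shape (1, 3, H, W)
    im = im.half() if model.fp16 else im.float()  # model.fp16 is False on CPU or unsupported backends
    pred = model(im)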