Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DetectMultiBackend() --half handling #6945

Merged
merged 8 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,10 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size(imgsz, s=stride) # check image size

# Half
half &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 supported on limited backends with CUDA
if pt or jit:
model.model.half() if half else model.model.float()
elif engine and model.trt_fp16_input != half:
LOGGER.info('model ' + (
'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.')
half = model.trt_fp16_input

# Dataloader
if webcam:
view_img = check_imshow()
Expand All @@ -114,12 +105,12 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz), half=half) # warmup
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
dt, seen = [0.0, 0.0, 0.0], 0
for path, im, im0s, vid_cap, s in dataset:
t1 = time_sync()
im = torch.from_numpy(im).to(device)
im = im.half() if half else im.float() # uint8 to fp16/32
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
Expand Down
13 changes: 8 additions & 5 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def forward(self, x):

class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
Expand All @@ -297,6 +297,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
w = attempt_download(w) # download if not local
fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16
if data: # data.yaml path (optional)
with open(data, errors='ignore') as f:
names = yaml.safe_load(f)['names'] # class names
Expand All @@ -305,11 +306,13 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
elif jit: # TorchScript
LOGGER.info(f'Loading {w} for TorchScript inference...')
extra_files = {'config.txt': ''} # model metadata
model = torch.jit.load(w, _extra_files=extra_files)
model.half() if fp16 else model.float()
if extra_files['config.txt']:
d = json.loads(extra_files['config.txt']) # extra_files dict
stride, names = int(d['stride']), d['names']
Expand Down Expand Up @@ -338,19 +341,19 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
trt_fp16_input = False
logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(f.read())
bindings = OrderedDict()
fp16 = False # default updated below
for index in range(model.num_bindings):
name = model.get_binding_name(index)
dtype = trt.nptype(model.get_binding_dtype(index))
shape = tuple(model.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
if model.binding_is_input(index) and dtype == np.float16:
trt_fp16_input = True
fp16 = True
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
Expand Down Expand Up @@ -458,11 +461,11 @@ def forward(self, im, augment=False, visualize=False, val=False):
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y

def warmup(self, imgsz=(1, 3, 640, 640), half=False):
def warmup(self, imgsz=(1, 3, 640, 640)):
# Warmup model by running inference once
if self.pt or self.jit or self.onnx or self.engine: # warmup types
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
im = torch.zeros(*imgsz).to(self.device).type(torch.half if self.fp16 else torch.float) # input image
self.forward(im) # warmup

@staticmethod
Expand Down
25 changes: 9 additions & 16 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def run(data,
training = model is not None
if training: # called by train.py
device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model

half &= device.type != 'cpu' # half precision only supported on CUDA
model.half() if half else model.float()
else: # called directly
Expand All @@ -136,23 +135,17 @@ def run(data,
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

# Load model
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
stride, pt, jit, onnx, engine = model.stride, model.pt, model.jit, model.onnx, model.engine
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size
half &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 supported on limited backends with CUDA
if pt or jit:
model.model.half() if half else model.model.float()
elif engine:
half = model.fp16 # FP16 supported on limited backends with CUDA
if engine:
batch_size = model.batch_size
if model.trt_fp16_input != half:
LOGGER.info('model ' + (
'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.')
half = model.trt_fp16_input
else:
half = False
batch_size = 1 # export.py models default to batch-size 1
device = torch.device('cpu')
LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')
device = model.device
if not pt or jit:
batch_size = 1 # export.py models default to batch-size 1
LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')

# Data
data = check_dataset(data) # check
Expand All @@ -166,7 +159,7 @@ def run(data,

# Dataloader
if not training:
model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz), half=half) # warmup
model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz)) # warmup
pad = 0.0 if task in ('speed', 'benchmark') else 0.5
rect = False if task == 'benchmark' else pt # square inference for benchmarks
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
Expand Down