Improved Profile() inference timing #9024

Merged · 7 commits · Aug 18, 2022
37 changes: 17 additions & 20 deletions classify/predict.py
@@ -22,8 +22,8 @@
 from models.common import DetectMultiBackend
 from utils.augmentations import classify_transforms
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages
-from utils.general import LOGGER, check_file, check_requirements, colorstr, increment_path, print_args
-from utils.torch_utils import select_device, smart_inference_mode, time_sync
+from utils.general import LOGGER, Profile, check_file, check_requirements, colorstr, increment_path, print_args
+from utils.torch_utils import select_device, smart_inference_mode


 @smart_inference_mode()
@@ -44,7 +44,7 @@ def run(
     if is_url and is_file:
         source = check_file(source)  # download

-    seen, dt = 1, [0.0, 0.0, 0.0]
+    dt = Profile(), Profile(), Profile()
     device = select_device(device)

     # Directories
@@ -55,30 +55,27 @@ def run(
     model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
     model.warmup(imgsz=(1, 3, imgsz, imgsz))  # warmup
     dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz))
-    for path, im, im0s, vid_cap, s in dataset:
+    for seen, (path, im, im0s, vid_cap, s) in enumerate(dataset):
         # Image
-        t1 = time_sync()
-        im = im.unsqueeze(0).to(device)
-        im = im.half() if model.fp16 else im.float()
-        t2 = time_sync()
-        dt[0] += t2 - t1
+        with dt[0]:
+            im = im.unsqueeze(0).to(device)
+            im = im.half() if model.fp16 else im.float()

         # Inference
-        results = model(im)
-        t3 = time_sync()
-        dt[1] += t3 - t2
+        with dt[1]:
+            results = model(im)

         # Post-process
-        p = F.softmax(results, dim=1)  # probabilities
-        i = p.argsort(1, descending=True)[:, :5].squeeze().tolist()  # top 5 indices
-        dt[2] += time_sync() - t3
-        # if save:
-        #     imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
-        seen += 1
-        LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}")
+        with dt[2]:
+            p = F.softmax(results, dim=1)  # probabilities
+            i = p.argsort(1, descending=True)[:, :5].squeeze().tolist()  # top 5 indices
+        # if save:
+        #     imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
+        LOGGER.info(
+            f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}, {dt[1].dt * 1E3:.1f}ms")

     # Print results
-    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
+    t = tuple(x.t / (seen + 1) * 1E3 for x in dt)  # speeds per image
     shape = (1, 3, imgsz, imgsz)
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
     LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
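The `Profile` class referenced throughout this PR lives in `utils/general.py`. As used in these diffs, it accumulates the total elapsed time of every `with` block in `.t` and keeps the most recent interval in `.dt`, synchronizing CUDA before each reading so pending GPU work is counted; that is the job `time_sync()` used to do inline. A minimal sketch consistent with that usage (the actual class may differ in detail):

```python
import contextlib
import time

import torch


class Profile(contextlib.ContextDecorator):
    # Minimal sketch: use as an @Profile() decorator or a 'with Profile():' context manager
    def __init__(self, t=0.0):
        self.t = t  # total accumulated time (s)
        self.cuda = torch.cuda.is_available()

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):
        self.dt = self.time() - self.start  # delta-time of the last block (s)
        self.t += self.dt  # accumulate into the running total

    def time(self):
        if self.cuda:
            torch.cuda.synchronize()  # wait for queued GPU ops, as time_sync() did
        return time.time()
```

Note also that `seen` now comes from `enumerate(dataset)` and is therefore zero-based, which is why the final average divides by `seen + 1`.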
29 changes: 13 additions & 16 deletions classify/val.py
@@ -23,8 +23,8 @@

 from models.common import DetectMultiBackend
 from utils.dataloaders import create_classification_dataloader
-from utils.general import LOGGER, check_img_size, check_requirements, colorstr, increment_path, print_args
-from utils.torch_utils import select_device, smart_inference_mode, time_sync
+from utils.general import LOGGER, Profile, check_img_size, check_requirements, colorstr, increment_path, print_args
+from utils.torch_utils import select_device, smart_inference_mode


 @smart_inference_mode()
@@ -83,27 +83,24 @@ def run(
                                                      workers=workers)

     model.eval()
-    pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
+    pred, targets, loss, dt = [], [], 0, (Profile(), Profile(), Profile())
     n = len(dataloader)  # number of batches
     action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
     desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
     bar = tqdm(dataloader, desc, n, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}', position=0)
     with torch.cuda.amp.autocast(enabled=device.type != 'cpu'):
         for images, labels in bar:
-            t1 = time_sync()
-            images, labels = images.to(device, non_blocking=True), labels.to(device)
-            t2 = time_sync()
-            dt[0] += t2 - t1
+            with dt[0]:
+                images, labels = images.to(device, non_blocking=True), labels.to(device)

-            y = model(images)
-            t3 = time_sync()
-            dt[1] += t3 - t2
+            with dt[1]:
+                y = model(images)

-            pred.append(y.argsort(1, descending=True)[:, :5])
-            targets.append(labels)
-            if criterion:
-                loss += criterion(y, labels)
-            dt[2] += time_sync() - t3
+            with dt[2]:
+                pred.append(y.argsort(1, descending=True)[:, :5])
+                targets.append(labels)
+                if criterion:
+                    loss += criterion(y, labels)

     loss /= n
     pred, targets = torch.cat(pred), torch.cat(targets)
@@ -122,7 +119,7 @@ def run(
             LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")

     # Print results
-    t = tuple(x / len(dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
+    t = tuple(x.t / len(dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
     shape = (1, 3, imgsz, imgsz)
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
     LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
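Both classify scripts follow the same three-phase pattern: one `Profile` per phase, a `with` block around each phase, and a division by the image count at the end. A standalone sketch of that pattern, with placeholder phase bodies:

```python
dt = Profile(), Profile(), Profile()  # pre-process, inference, post-process timers
n = 0
for _ in range(100):  # stand-in for iterating a dataloader
    with dt[0]:
        pass  # pre-process work goes here
    with dt[1]:
        pass  # inference work goes here
    with dt[2]:
        pass  # post-process work goes here
    n += 1

t = tuple(x.t / n * 1E3 for x in dt)  # ms per image, as in the 'Print results' lines
print('Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image' % t)
```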
35 changes: 16 additions & 19 deletions detect.py
@@ -41,10 +41,10 @@

 from models.common import DetectMultiBackend
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
-from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
+from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                            increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import select_device, smart_inference_mode, time_sync
+from utils.torch_utils import select_device, smart_inference_mode


 @smart_inference_mode()
@@ -107,26 +107,23 @@ def run(

     # Run inference
     model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
-    seen, windows, dt = 0, [], [0.0, 0.0, 0.0]
+    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
     for path, im, im0s, vid_cap, s in dataset:
-        t1 = time_sync()
-        im = torch.from_numpy(im).to(device)
-        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
-        im /= 255  # 0 - 255 to 0.0 - 1.0
-        if len(im.shape) == 3:
-            im = im[None]  # expand for batch dim
-        t2 = time_sync()
-        dt[0] += t2 - t1
+        with dt[0]:
+            im = torch.from_numpy(im).to(device)
+            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
+            im /= 255  # 0 - 255 to 0.0 - 1.0
+            if len(im.shape) == 3:
+                im = im[None]  # expand for batch dim

         # Inference
-        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
-        pred = model(im, augment=augment, visualize=visualize)
-        t3 = time_sync()
-        dt[1] += t3 - t2
+        with dt[1]:
+            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
+            pred = model(im, augment=augment, visualize=visualize)

         # NMS
-        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
-        dt[2] += time_sync() - t3
+        with dt[2]:
+            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

         # Second-stage classifier (optional)
         # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
@@ -201,10 +198,10 @@ def run(
                    vid_writer[i].write(im0)

         # Print time (inference-only)
-        LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')
+        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

     # Print results
-    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
+    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
     if save_txt or save_img:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
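Note the two different attributes in play here: the per-image log line reads `dt[1].dt`, the interval of the most recent `with` block, while the final summary reads `dt[1].t`, the running total. A short illustration reusing the names from the diff above:

```python
with dt[1]:
    pred = model(im, augment=augment, visualize=visualize)  # one inference pass

LOGGER.info(f'{dt[1].dt * 1E3:.1f}ms')       # last inference only (per-image line)
LOGGER.info(f'{dt[1].t * 1E3:.1f}ms total')  # accumulated over all images (summary)
```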
117 changes: 59 additions & 58 deletions models/common.py
@@ -21,10 +21,11 @@
 from torch.cuda import amp

 from utils.dataloaders import exif_transpose, letterbox
-from utils.general import (LOGGER, ROOT, check_requirements, check_suffix, check_version, colorstr, increment_path,
-                           make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh, yaml_load)
+from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
+                           increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh,
+                           yaml_load)
 from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import copy_attr, smart_inference_mode, time_sync
+from utils.torch_utils import copy_attr, smart_inference_mode


 def autopad(k, p=None):  # kernel, padding
@@ -587,75 +588,75 @@ def _apply(self, fn):
         return self

     @smart_inference_mode()
-    def forward(self, imgs, size=640, augment=False, profile=False):
+    def forward(self, ims, size=640, augment=False, profile=False):
         # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
-        #   file:       imgs = 'data/images/zidane.jpg'  # str or PosixPath
+        #   file:       ims = 'data/images/zidane.jpg'  # str or PosixPath
         #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
         #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
         #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
         #   numpy:           = np.zeros((640,1280,3))  # HWC
         #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
         #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

-        t = [time_sync()]
-        p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device)  # for device, type
-        autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
-        if isinstance(imgs, torch.Tensor):  # torch
-            with amp.autocast(autocast):
-                return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference
-
-        # Pre-process
-        n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, (list, tuple)) else (1, [imgs])  # number, list of images
-        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
-        for i, im in enumerate(imgs):
-            f = f'image{i}'  # filename
-            if isinstance(im, (str, Path)):  # filename or uri
-                im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
-                im = np.asarray(exif_transpose(im))
-            elif isinstance(im, Image.Image):  # PIL Image
-                im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
-            files.append(Path(f).with_suffix('.jpg').name)
-            if im.shape[0] < 5:  # image in CHW
-                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
-            im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
-            s = im.shape[:2]  # HWC
-            shape0.append(s)  # image shape
-            g = (size / max(s))  # gain
-            shape1.append([y * g for y in s])
-            imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
-        shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)]  # inf shape
-        x = [letterbox(im, shape1, auto=False)[0] for im in imgs]  # pad
-        x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
-        x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
-        t.append(time_sync())
+        dt = (Profile(), Profile(), Profile())
+        with dt[0]:
+            p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device)  # param
+            autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
+            if isinstance(ims, torch.Tensor):  # torch
+                with amp.autocast(autocast):
+                    return self.model(ims.to(p.device).type_as(p), augment, profile)  # inference
+
+            # Pre-process
+            n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims])  # number, list of images
+            shape0, shape1, files = [], [], []  # image and inference shapes, filenames
+            for i, im in enumerate(ims):
+                f = f'image{i}'  # filename
+                if isinstance(im, (str, Path)):  # filename or uri
+                    im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
+                    im = np.asarray(exif_transpose(im))
+                elif isinstance(im, Image.Image):  # PIL Image
+                    im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
+                files.append(Path(f).with_suffix('.jpg').name)
+                if im.shape[0] < 5:  # image in CHW
+                    im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
+                im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
+                s = im.shape[:2]  # HWC
+                shape0.append(s)  # image shape
+                g = (size / max(s))  # gain
+                shape1.append([y * g for y in s])
+                ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
+            shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)]  # inf shape
+            x = [letterbox(im, shape1, auto=False)[0] for im in ims]  # pad
+            x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
+            x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32

         with amp.autocast(autocast):
             # Inference
-            y = self.model(x, augment, profile)  # forward
-            t.append(time_sync())
+            with dt[1]:
+                y = self.model(x, augment, profile)  # forward

             # Post-process
-            y = non_max_suppression(y if self.dmb else y[0],
-                                    self.conf,
-                                    self.iou,
-                                    self.classes,
-                                    self.agnostic,
-                                    self.multi_label,
-                                    max_det=self.max_det)  # NMS
-            for i in range(n):
-                scale_coords(shape1, y[i][:, :4], shape0[i])
+            with dt[2]:
+                y = non_max_suppression(y if self.dmb else y[0],
+                                        self.conf,
+                                        self.iou,
+                                        self.classes,
+                                        self.agnostic,
+                                        self.multi_label,
+                                        max_det=self.max_det)  # NMS
+                for i in range(n):
+                    scale_coords(shape1, y[i][:, :4], shape0[i])

-            t.append(time_sync())
-            return Detections(imgs, y, files, t, self.names, x.shape)
+            return Detections(ims, y, files, dt, self.names, x.shape)


 class Detections:
     # YOLOv5 detections class for inference results
-    def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
+    def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
         super().__init__()
         d = pred[0].device  # device
-        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs]  # normalizations
-        self.imgs = imgs  # list of images as numpy arrays
+        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims]  # normalizations
+        self.ims = ims  # list of images as numpy arrays
         self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
         self.names = names  # class names
         self.files = files  # image filenames
@@ -665,12 +666,12 @@ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
         self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
         self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
         self.n = len(self.pred)  # number of images (batch size)
-        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
+        self.t = tuple(x.t / self.n * 1E3 for x in times)  # timestamps (ms)
         self.s = shape  # inference BCHW shape

     def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
         crops = []
-        for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
+        for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
             s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # string
             if pred.shape[0]:
                 for c in pred[:, -1].unique():
@@ -705,7 +706,7 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False
                 if i == self.n - 1:
                     LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
             if render:
-                self.imgs[i] = np.asarray(im)
+                self.ims[i] = np.asarray(im)
         if crop:
             if save:
                 LOGGER.info(f'Saved results to {save_dir}\n')
@@ -728,7 +729,7 @@ def crop(self, save=True, save_dir='runs/detect/exp'):

     def render(self, labels=True):
         self.display(render=True, labels=labels)  # render results
-        return self.imgs
+        return self.ims

     def pandas(self):
         # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
@@ -743,9 +744,9 @@ def pandas(self):
     def tolist(self):
         # return a list of Detections objects, i.e. 'for result in results.tolist():'
         r = range(self.n)  # iterable
-        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
+        x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
         # for d in x:
-        #     for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
+        #     for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
         #         setattr(d, k, getattr(d, k)[0])  # pop out of list
         return x
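The `Detections` change is the one observable API shift in this file: `times` is now the three `Profile` objects rather than four raw timestamps differenced pairwise, so per-image milliseconds fall directly out of each accumulated `.t`. A hypothetical worked example of the new arithmetic (the `imgs` to `ims` rename that accompanies it, touching `self.ims`, `render()`, and `tolist()`, is purely cosmetic):

```python
# Hypothetical numbers: a batch of n = 2 images whose (pre-process, inference,
# NMS) profiles accumulated 0.004 s, 0.030 s and 0.002 s respectively.
n = 2
times_t = (0.004, 0.030, 0.002)  # the .t value of each Profile
t = tuple(x / n * 1E3 for x in times_t)
print(t)  # (2.0, 15.0, 1.0) -> ms per image, as computed in Detections.__init__
```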