Skip to content

Commit

Permalink
Merge pull request #1 from andreiionutdamian/feature/detect-torchscript
Browse files Browse the repository at this point in the history
update `detect.py` for torchscript support
  • Loading branch information
andreiionutdamian authored Oct 12, 2021
2 parents 4376b7f + 14a3b8f commit 0587735
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 38 deletions.
28 changes: 18 additions & 10 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import argparse
import os
import sys
from pathlib import Path

Expand All @@ -19,7 +20,7 @@
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = ROOT.relative_to(Path.cwd()) # relative
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.experimental import attempt_load
from utils.datasets import LoadImages, LoadStreams
Expand Down Expand Up @@ -55,6 +56,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
):
source = str(source)
save_img = not nosave and not source.endswith('.txt') # save inference images
Expand All @@ -71,7 +73,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
half &= device.type != 'cpu' # half precision only supported on CUDA

# Load model
w = weights[0] if isinstance(weights, list) else weights
w = str(weights[0] if isinstance(weights, list) else weights)
classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
Expand All @@ -80,10 +82,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
if 'torchscript' in w:
# this is torchscript saved not a pickle that can be loaded with th.load
# we assume the file exist in the target folder
if Path(w).is_file():
model = torch.jit.load(w)
else:
raise ValueError('Cannot find torchscript file {}'.format(Path(w))
model = torch.jit.load(w)
else:
model = attempt_load(weights, map_location=device) # load FP32 model
stride = int(model.stride.max()) # model stride
Expand All @@ -94,9 +93,13 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
modelc = load_classifier(name='resnet50', n=2) # initialize
modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
elif onnx:
check_requirements(('onnx', 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
if dnn:
# check_requirements(('opencv-python>=4.5.4',))
net = cv2.dnn.readNetFromONNX(w)
else:
check_requirements(('onnx', 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
check_requirements(('tensorflow>=2.4.1',))
import tensorflow as tf
Expand Down Expand Up @@ -152,7 +155,11 @@ def wrap_frozen_graph(gd, inputs, outputs):
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
if dnn:
net.setInput(img)
pred = torch.tensor(net.forward())
else:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
Expand Down Expand Up @@ -288,6 +295,7 @@ def parse_opt():
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(FILE.stem, opt)
Expand Down
3 changes: 2 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

import argparse
import os
import subprocess
import sys
import time
Expand All @@ -34,7 +35,7 @@
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = ROOT.relative_to(Path.cwd()) # relative
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.common import Conv
from models.experimental import attempt_load
Expand Down
10 changes: 10 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,16 @@ def autoshape(self):
LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
return self

def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
m = self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self

@torch.no_grad()
def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
Expand Down
4 changes: 4 additions & 0 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
for m in model.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
m.inplace = inplace # pytorch 1.7.0 compatibility
if type(m) is Detect:
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
delattr(m, 'anchor_grid')
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
elif type(m) is Conv:
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility

Expand Down
2 changes: 1 addition & 1 deletion models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detec
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [tf.zeros(1)] * self.nl # init grid
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
[self.nl, 1, -1, 1, 2])
self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
self.training = False # set to False after building model
Expand Down
31 changes: 22 additions & 9 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [torch.zeros(1)] * self.nl # init grid
a = torch.tensor(anchors).float().view(self.nl, -1, 2)
self.register_buffer('anchors', a) # shape(nl,na,2)
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
self.inplace = inplace # use in-place ops (e.g. slice assignment)

Expand All @@ -59,24 +58,27 @@ def forward(self, x):

if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

y = x[i].sigmoid()
if self.inplace:
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))

return x if self.training else (torch.cat(z, 1), x)

@staticmethod
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
def _make_grid(self, nx=20, ny=20, i=0):
d = self.anchors[i].device
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
return grid, anchor_grid


class Model(nn.Module):
Expand Down Expand Up @@ -232,6 +234,17 @@ def autoshape(self): # add AutoShape module
def info(self, verbose=False, img_size=640): # print model information
model_info(self, verbose, img_size)

def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
m = self.model[-1] # Detect()
if isinstance(m, Detect):
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self


def parse_model(d, ch): # model_dict, input_channels(3)
LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ numpy>=1.18.5
opencv-python>=4.1.2
Pillow>=7.1.2
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
Expand All @@ -16,7 +17,7 @@ tensorboard>=2.4.1
# wandb

# Plotting ------------------------------------
pandas
pandas>=1.1.4
seaborn>=0.11.0

# Export --------------------------------------
Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = ROOT.relative_to(Path.cwd()) # relative
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

import val # for end-of-epoch mAP
from models.experimental import attempt_load
Expand Down Expand Up @@ -99,7 +99,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
plots = not evolve # create plots
cuda = device.type != 'cpu'
init_seeds(1 + RANK)
with torch_distributed_zero_first(RANK):
with torch_distributed_zero_first(LOCAL_RANK):
data_dict = data_dict or check_dataset(data) # check if None
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
Expand All @@ -111,7 +111,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
check_suffix(weights, '.pt') # check weights
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(RANK):
with torch_distributed_zero_first(LOCAL_RANK):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
Expand Down Expand Up @@ -208,7 +208,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Trainloader
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
prefix=colorstr('train: '))
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
Expand Down
10 changes: 4 additions & 6 deletions utils/autoanchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@

def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
a = m.anchor_grid.prod(-1).view(-1) # anchor area
a = m.anchors.prod(-1).view(-1) # anchor area
da = a[-1] - a[0] # delta a
ds = m.stride[-1] - m.stride[0] # delta s
if da.sign() != ds.sign(): # same order
print('Reversing anchor order')
m.anchors[:] = m.anchors.flip(0)
m.anchor_grid[:] = m.anchor_grid.flip(0)


def check_anchors(dataset, model, thr=4.0, imgsz=640):
Expand All @@ -41,20 +40,19 @@ def metric(k): # compute metric
bpr = (best > 1. / thr).float().mean() # best possible recall
return bpr, aat

anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors
bpr, aat = metric(anchors)
anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
bpr, aat = metric(anchors.cpu().view(-1, 2))
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
if bpr < 0.98: # threshold to recompute
print('. Attempting to improve anchors, please wait...')
na = m.anchor_grid.numel() // 2 # number of anchors
na = m.anchors.numel() // 2 # number of anchors
try:
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
except Exception as e:
print(f'{prefix}ERROR: {e}')
new_bpr = metric(anchors)[0]
if new_bpr > bpr: # replace anchors
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m)
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
Expand Down
4 changes: 2 additions & 2 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()

fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
# for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
# for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
for f in sorted(save_dir.glob('study*.txt')):
y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
x = np.arange(y.shape[1]) if x is None else np.array(x)
Expand All @@ -284,7 +284,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
ax2.grid(alpha=0.2)
ax2.set_yticks(np.arange(20, 60, 5))
ax2.set_xlim(0, 57)
ax2.set_ylim(30, 55)
ax2.set_ylim(25, 55)
ax2.set_xlabel('GPU Speed (ms/img)')
ax2.set_ylabel('COCO AP val')
ax2.legend(loc='lower right')
Expand Down
10 changes: 6 additions & 4 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = ROOT.relative_to(Path.cwd()) # relative
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.experimental import attempt_load
from utils.datasets import create_dataloader
Expand Down Expand Up @@ -147,8 +147,9 @@ def run(data,
if not training:
if device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
pad = 0.0 if task == 'speed' else 0.5
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=pad, rect=True,
prefix=colorstr(f'{task}: '))[0]

seen = 0
Expand Down Expand Up @@ -333,20 +334,21 @@ def main(opt):
run(**vars(opt))

elif opt.task == 'speed': # speed benchmarks
# python val.py --task speed --data coco.yaml --batch 1 --weights yolov5n.pt yolov5s.pt...
for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
run(opt.data, weights=w, batch_size=opt.batch_size, imgsz=opt.imgsz, conf_thres=.25, iou_thres=.45,
device=opt.device, save_json=False, plots=False)

elif opt.task == 'study': # run over a range of settings and save/plot
# python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5s.pt yolov5m.pt yolov5l.pt yolov5x.pt
# python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5n.pt yolov5s.pt...
x = list(range(256, 1536 + 128, 128)) # x axis (image sizes)
for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt' # filename to save to
y = [] # y axis
for i in x: # img-size
print(f'\nRunning {f} point {i}...')
r, _, t = run(opt.data, weights=w, batch_size=opt.batch_size, imgsz=i, conf_thres=opt.conf_thres,
iou_thres=opt.iou_thres, save_json=opt.save_json, plots=False)
iou_thres=opt.iou_thres, device=opt.device, save_json=opt.save_json, plots=False)
y.append(r + t) # results and times
np.savetxt(f, y, fmt='%10.4g') # save
os.system('zip -r study.zip study_*.txt')
Expand Down

0 comments on commit 0587735

Please sign in to comment.