Skip to content

Commit

Permalink
Add --tfl-detect for TFLite Detection
Browse files Browse the repository at this point in the history
  • Loading branch information
zldrobit committed Oct 12, 2020
1 parent 3d9a54d commit 1ab4883
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np

from models.experimental import attempt_load
from models.yolo import Detect
from utils.datasets import LoadStreams, LoadImages
from utils.general import (
check_img_size, non_max_suppression, apply_classifier, scale_coords,
Expand Down Expand Up @@ -195,9 +196,39 @@ def _imports_graph_def():
input_data = img.permute(0, 2, 3, 1).cpu().numpy()
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
pred = torch.tensor(output_data)

if not opt.tfl_detect:
output_data = interpreter.get_tensor(output_details[0]['index'])
pred = torch.tensor(output_data)
else:
import yaml
yaml_file = Path(opt.cfg).name
with open(opt.cfg) as f:
yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict

anchors = yaml['anchors']
nc = yaml['nc']
nl = len(anchors)
x = [torch.tensor(interpreter.get_tensor(output_details[i]['index']), device=device) for i in range(nl)]
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny * nx, 2)).float()

no = nc + 5
grid = [torch.zeros(1)] * nl # init grid
a = torch.tensor(anchors).float().view(nl, -1, 2).to(device)
anchor_grid = a.clone().view(nl, 1, -1, 1, 2) # shape(nl,1,na,1,2)
z = [] # inference output
for i in range(nl):
_, _, ny_nx, _ = x[i].shape
nx = ny = int(np.sqrt(ny_nx))
grid[i] = _make_grid(nx, ny).to(x[i].device)
stride = imgsz // ny
y = x[i].sigmoid()
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i].to(x[i].device)) * stride # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i] # wh
z.append(y.view(-1, no))

pred = torch.unsqueeze(torch.cat(z, 0), 0)

# Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
Expand Down Expand Up @@ -285,6 +316,8 @@ def _imports_graph_def():
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--tfl-detect', action='store_true', help='add Detect module in TFLite')
parser.add_argument('--cfg', type=str, default='./models/yolov5s.yaml', help='cfg path')
opt = parser.parse_args()
print(opt)

Expand Down

0 comments on commit 1ab4883

Please sign in to comment.