
Commit

Cat apriori to autolabels (ultralytics#1484)
glenn-jocher committed Nov 23, 2020
1 parent 51c4f02 commit 8ec92e5
Showing 3 changed files with 20 additions and 8 deletions.
3 changes: 2 additions & 1 deletion detect.py
@@ -137,7 +137,8 @@ def detect(save_img=False):
                     vid_writer.write(im0)
 
     if save_txt or save_img:
-        print('Results saved to %s' % save_dir)
+        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+        print(f"Results saved to {save_dir}{s}")
 
     print('Done. (%.3fs)' % (time.time() - t0))

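The reworked summary reports how many per-image label files --save-txt produced. A minimal, self-contained sketch of the same counting pattern, using a hypothetical output directory:

from pathlib import Path

save_dir = Path('runs/detect/exp')  # hypothetical output directory
save_txt = True  # assume --save-txt was passed

# Count the per-image *.txt files written under save_dir/labels
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}")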
13 changes: 7 additions & 6 deletions test.py
@@ -101,22 +101,22 @@ def test(data,
         img /= 255.0  # 0 - 255 to 0.0 - 1.0
         targets = targets.to(device)
         nb, _, height, width = img.shape  # batch size, channels, height, width
-        whwh = torch.Tensor([width, height, width, height]).to(device)
+        targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)
 
         # Disable gradients
         with torch.no_grad():
             # Run model
             t = time_synchronized()
             inf_out, train_out = model(img, augment=augment)  # inference and training outputs
             t0 += time_synchronized() - t
 
             # Compute loss
-            if training:  # if model has loss hyperparameters
+            if training:
                 loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # box, obj, cls
 
             # Run NMS
             t = time_synchronized()
-            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
+            lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_txt else []  # for autolabelling
+            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb)
             t1 += time_synchronized() - t
 
         # Statistics per image
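The new lb list is what feeds a-priori labels into NMS: targets holds one (image_index, class, x, y, w, h) row per object for the whole batch, already scaled to pixels by the change above, and the comprehension splits it into one (class, x, y, w, h) tensor per image. A toy illustration with made-up values:

import torch

nb = 2  # batch size
# (image_index, class, x, y, w, h) rows for the whole batch, in pixels
targets = torch.tensor([[0., 1., 50., 50., 20., 20.],
                        [0., 3., 80., 40., 10., 30.],
                        [1., 2., 60., 60., 25., 25.]])

# One (class, x, y, w, h) tensor per image, as passed to non_max_suppression(..., labels=lb)
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)]
print([t.shape for t in lb])  # [torch.Size([2, 5]), torch.Size([1, 5])]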
@@ -174,7 +174,7 @@ def test(data,
                 tcls_tensor = labels[:, 0]
 
                 # target boxes
-                tbox = xywh2xyxy(labels[:, 1:5]) * whwh
+                tbox = xywh2xyxy(labels[:, 1:5])
                 scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1])  # native-space labels
 
                 # Per target class
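The * whwh factor is gone because targets are now scaled to pixels once, up front; xywh2xyxy is left doing pure format conversion before scale_coords maps the boxes back to native image space. A self-contained sketch of that conversion (xywh2xyxy re-stated here under the assumption it matches utils/general.py):

import torch

def xywh2xyxy(x):
    # (center-x, center-y, w, h) -> (x1, y1, x2, y2)
    y = x.clone()
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top-left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top-left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom-right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom-right y
    return y

boxes = torch.tensor([[50., 50., 20., 10.]])  # one pixel-space xywh box
print(xywh2xyxy(boxes))  # tensor([[40., 45., 60., 55.]])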
@@ -264,7 +264,8 @@ def test(data,
 
     # Return results
     if not training:
-        print('Results saved to %s' % save_dir)
+        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+        print(f"Results saved to {save_dir}{s}")
     model.float()  # for training
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
12 changes: 11 additions & 1 deletion utils/general.py
@@ -263,7 +263,7 @@ def wh_iou(wh1, wh2):
     return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)
 
 
-def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
+def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, classes=None, agnostic=False, labels=()):
     """Performs Non-Maximum Suppression (NMS) on inference results
     Returns:
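Callers now opt in to autolabelling by passing per-image priors through the new labels argument, while merge is no longer selectable at the call site (see the next hunk). A hedged usage sketch, assuming this repo's conventions that prediction is [batch, boxes, 5 + num_classes] and each labels entry holds (class, x, y, w, h) pixel rows:

import torch
from utils.general import non_max_suppression

nc = 80  # number of classes
inf_out = torch.rand(2, 100, 5 + nc)  # dummy inference output for a batch of 2
lb = [torch.tensor([[1., 50., 50., 20., 20.]]),  # image 0: one known (class, x, y, w, h) box
      torch.zeros(0, 5)]                         # image 1: no priors
output = non_max_suppression(inf_out, conf_thres=0.1, iou_thres=0.6, labels=lb)
print([o.shape for o in output])  # one (n, 6) tensor of (x1, y1, x2, y2, conf, cls) per image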
@@ -279,6 +279,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
     time_limit = 10.0  # seconds to quit after
     redundant = True  # require redundant detections
     multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
+    merge = False  # use merge-NMS
 
     t = time.time()
     output = [torch.zeros(0, 6)] * prediction.shape[0]
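merge-NMS moves from a keyword argument to a hard-coded local, so the public signature only grows by labels. When the flag is on, each kept box is replaced by a confidence-weighted blend of the boxes that overlap it; a rough sketch of that idea (paraphrased, not the function's exact code, which sits further down this function):

import torch
from torchvision.ops import box_iou

def merge_overlaps(boxes, scores, keep, iou_thres=0.6):
    # Weighted-box merge: average each kept box with its high-IoU neighbours,
    # weighting each neighbour by (overlap mask * detection score)
    iou = box_iou(boxes[keep], boxes) > iou_thres  # (k, n) overlap mask
    weights = iou * scores[None]                   # (k, n) per-neighbour weights
    return torch.mm(weights, boxes) / weights.sum(1, keepdim=True)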
@@ -287,6 +288,15 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
         x = x[xc[xi]]  # confidence
 
+        # Cat apriori labels if autolabelling
+        if labels and len(labels[xi]):
+            l = labels[xi]
+            v = torch.zeros((len(l), nc + 5), device=x.device)
+            v[:, :4] = l[:, 1:5]  # box
+            v[:, 4] = 1.0  # conf
+            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
+            x = torch.cat((x, v), 0)
+
         # If none remain process next image
         if not x.shape[0]:
             continue
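This block is the commit's namesake: each a-priori label becomes a synthetic candidate row (box in columns 0-3, objectness pinned to 1.0, one-hot class score) that is concatenated with the model's own candidates, so ground truth flows through NMS when autolabelling. A standalone toy run of the same construction:

import torch

nc = 3                                        # number of classes
x = torch.rand(5, nc + 5)                     # model candidates: (x, y, w, h, obj, cls0..cls2)
l = torch.tensor([[2., 50., 50., 20., 20.]])  # one prior label: (class, x, y, w, h)

v = torch.zeros((len(l), nc + 5))
v[:, :4] = l[:, 1:5]                          # box
v[:, 4] = 1.0                                 # conf: priors are certain
v[range(len(l)), l[:, 0].long() + 5] = 1.0    # one-hot class
x = torch.cat((x, v), 0)                      # priors join the NMS pool
print(x.shape)  # torch.Size([6, 8])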

