Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speed up evaluation #20

Merged
merged 3 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions segment/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def process_batch_masks(predn, pred_masks, gt_masks, labels, iouv, overlap):
mode="bilinear",
align_corners=False,
).squeeze(0)
gt_masks = gt_masks.gt_(0.5)

iou = mask_iou(
gt_masks.view(gt_masks.shape[0], -1),
Expand Down Expand Up @@ -171,7 +172,7 @@ def run(
mask_downsample_ratio=1,
compute_loss=None,
):
process = process_mask_upsample if plots else process_mask
process = process_mask_upsample if save_json else process_mask
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
Expand Down Expand Up @@ -304,9 +305,6 @@ def run(
proto_out = train_out[1][si]
pred_masks = process(proto_out, pred[:, 6:], pred[:, :4],
shape=im[si].shape[1:]).permute(2, 0, 1).contiguous().float()
if plots and batch_i < 3:
# filter top 15 to plot
plot_masks.append(torch.as_tensor(pred_masks[:15], dtype=torch.uint8).cpu())

# Predictions
if single_cls:
Expand All @@ -326,6 +324,12 @@ def run(
stats.append(
(correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)

# convert pred_masks to uint8
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if plots and batch_i < 3:
# filter top 15 to plot
plot_masks.append(pred_masks[:15].cpu())

# Save/log
if save_txt:
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
Expand All @@ -336,13 +340,6 @@ def run(

# Plot images
if plots and batch_i < 3:
if masks.shape[1:] != im.shape[2:]:
masks = F.interpolate(
masks.unsqueeze(0).float(),
im.shape[2:],
mode="bilinear",
align_corners=False,
).squeeze(0)
plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
if len(plot_masks):
plot_masks = torch.cat(plot_masks, dim=0)
Expand Down
4 changes: 2 additions & 2 deletions utils/segment/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def process_mask_upsample(proto_out, out_masks, bboxes, shape):
"""

c, mh, mw = proto_out.shape # CHW
masks = (out_masks.tanh() @ proto_out.view(c, -1)).sigmoid().view(-1, mh, mw)
masks = (out_masks.tanh() @ proto_out.float().view(c, -1)).sigmoid().view(-1, mh, mw)
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = crop(masks.permute(1, 2, 0).contiguous(), bboxes) # HWC
return masks.gt_(0.5)
Expand All @@ -63,7 +63,7 @@ def process_mask(proto_out, out_masks, bboxes, shape, upsample=False):

c, mh, mw = proto_out.shape # CHW
ih, iw = shape
masks = (out_masks.tanh() @ proto_out.view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
masks = (out_masks.tanh() @ proto_out.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW

downsampled_bboxes = bboxes.clone()
downsampled_bboxes[:, 0] *= mw / iw
Expand Down
9 changes: 5 additions & 4 deletions utils/segment/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg'
if paths:
annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
if len(targets) > 0:
j = targets[:, 0] == i
ti = targets[j] # image targets
idx = targets[:, 0] == i
ti = targets[idx] # image targets

boxes = xywh2xyxy(ti[:, 2:6]).T
classes = ti[:, 1].astype('int')
Expand Down Expand Up @@ -126,13 +126,14 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg'
image_masks = np.repeat(image_masks, nl, axis=0)
image_masks = np.where(image_masks == index, 1.0, 0.0)
else:
image_masks = masks[j]
image_masks = masks[idx]

im = np.asarray(annotator.im).copy()
for j, box in enumerate(boxes.T.tolist()):
if labels or conf[j] > 0.25: # 0.25 conf thresh
color = colors(classes[j])
if scale < 1:
mh, mw = image_masks[j].shape
if mh != h or mw != w:
mask = image_masks[j].astype(np.uint8)
mask = cv2.resize(mask, (w, h))
mask = mask.astype(np.bool)
Expand Down