Merge pull request #20 from Laughing-q/instance_seg
speed up evaluation
AyushExel committed Aug 24, 2022
2 parents be5a244 + 61212a6 commit 2eb1a71
Showing 3 changed files with 15 additions and 17 deletions.
segment/val.py (19 changes: 8 additions & 11 deletions)
@@ -122,6 +122,7 @@ def process_batch_masks(predn, pred_masks, gt_masks, labels, iouv, overlap):
             mode="bilinear",
             align_corners=False,
         ).squeeze(0)
+        gt_masks = gt_masks.gt_(0.5)
 
     iou = mask_iou(
         gt_masks.view(gt_masks.shape[0], -1),
@@ -171,7 +172,7 @@ def run(
         mask_downsample_ratio=1,
         compute_loss=None,
 ):
-    process = process_mask_upsample if plots else process_mask
+    process = process_mask_upsample if save_json else process_mask
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
@@ -304,9 +305,6 @@
             proto_out = train_out[1][si]
             pred_masks = process(proto_out, pred[:, 6:], pred[:, :4],
                                  shape=im[si].shape[1:]).permute(2, 0, 1).contiguous().float()
-            if plots and batch_i < 3:
-                # filter top 15 to plot
-                plot_masks.append(torch.as_tensor(pred_masks[:15], dtype=torch.uint8).cpu())
 
             # Predictions
             if single_cls:
@@ -326,6 +324,12 @@
             stats.append(
                 (correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0]))  # (correct, conf, pcls, tcls)
 
+            # convert pred_masks to uint8
+            pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
+            if plots and batch_i < 3:
+                # filter top 15 to plot
+                plot_masks.append(pred_masks[:15].cpu())
+
             # Save/log
             if save_txt:
                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
@@ -336,13 +340,6 @@
 
         # Plot images
         if plots and batch_i < 3:
-            if masks.shape[1:] != im.shape[2:]:
-                masks = F.interpolate(
-                    masks.unsqueeze(0).float(),
-                    im.shape[2:],
-                    mode="bilinear",
-                    align_corners=False,
-                ).squeeze(0)
             plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
             if len(plot_masks):
                 plot_masks = torch.cat(plot_masks, dim=0)
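Note on the segment/val.py diff above: the full-resolution process_mask_upsample path is now used only when save_json needs it, ground-truth masks are binarized once inside process_batch_masks, and predicted masks are converted to uint8 before being stored for plotting, which keeps the tensors retained across the batch small. Below is a minimal, self-contained sketch of an IoU over flattened binary masks, mirroring the mask_iou(gt_masks.view(N, -1), ...) call in the first hunk; the helper name, shapes, and values are illustrative assumptions, not the repository's exact API.

    import torch

    def mask_iou_sketch(masks1, masks2, eps=1e-7):
        # masks1: (N, H*W), masks2: (M, H*W), binary {0, 1}; returns an (N, M) IoU matrix
        masks1, masks2 = masks1.float(), masks2.float()
        inter = masks1 @ masks2.T                                     # pairwise intersection areas
        union = masks1.sum(1)[:, None] + masks2.sum(1)[None] - inter  # pairwise union areas
        return inter / (union + eps)

    gt = (torch.rand(3, 160 * 160) > 0.5).float()    # 3 binarized ground-truth masks
    pred = (torch.rand(5, 160 * 160) > 0.5).float()  # 5 binarized predicted masks
    print(mask_iou_sketch(gt, pred).shape)           # torch.Size([3, 5])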
utils/segment/general.py (4 changes: 2 additions & 2 deletions)
@@ -44,7 +44,7 @@ def process_mask_upsample(proto_out, out_masks, bboxes, shape):
     """
 
     c, mh, mw = proto_out.shape  # CHW
-    masks = (out_masks.tanh() @ proto_out.view(c, -1)).sigmoid().view(-1, mh, mw)
+    masks = (out_masks.tanh() @ proto_out.float().view(c, -1)).sigmoid().view(-1, mh, mw)
     masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
     masks = crop(masks.permute(1, 2, 0).contiguous(), bboxes)  # HWC
     return masks.gt_(0.5)
@@ -63,7 +63,7 @@ def process_mask(proto_out, out_masks, bboxes, shape, upsample=False):
 
     c, mh, mw = proto_out.shape  # CHW
     ih, iw = shape
-    masks = (out_masks.tanh() @ proto_out.view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW
+    masks = (out_masks.tanh() @ proto_out.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW
 
     downsampled_bboxes = bboxes.clone()
     downsampled_bboxes[:, 0] *= mw / iw
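Note on the utils/segment/general.py diff above: both mask decoders now cast the prototype tensor to float before the matrix multiply, which keeps the prototype-coefficient composition in fp32 even when the model (and thus proto_out) runs in half precision. A short sketch of that composition under assumed shapes (32 prototypes at 160x160, 8 detections); the variable names follow the diff, the values are made up.

    import torch

    c, mh, mw = 32, 160, 160
    proto_out = torch.rand(c, mh, mw, dtype=torch.float16)  # prototype masks (CHW), e.g. from a half-precision model
    out_masks = torch.rand(8, c)                            # mask coefficients for 8 detections (fp32)

    # Mixing fp16 prototypes with fp32 coefficients in the matmul can raise a dtype
    # error or lose precision; the .float() cast keeps the whole composition in fp32.
    masks = (out_masks.tanh() @ proto_out.float().view(c, -1)).sigmoid().view(-1, mh, mw)
    print(masks.shape, masks.dtype)  # torch.Size([8, 160, 160]) torch.float32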
utils/segment/plots.py (9 changes: 5 additions & 4 deletions)
@@ -93,8 +93,8 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg'
         if paths:
             annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
         if len(targets) > 0:
-            j = targets[:, 0] == i
-            ti = targets[j]  # image targets
+            idx = targets[:, 0] == i
+            ti = targets[idx]  # image targets
 
             boxes = xywh2xyxy(ti[:, 2:6]).T
             classes = ti[:, 1].astype('int')
@@ -126,13 +126,14 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg'
                     image_masks = np.repeat(image_masks, nl, axis=0)
                     image_masks = np.where(image_masks == index, 1.0, 0.0)
                 else:
-                    image_masks = masks[j]
+                    image_masks = masks[idx]
 
                 im = np.asarray(annotator.im).copy()
                 for j, box in enumerate(boxes.T.tolist()):
                     if labels or conf[j] > 0.25:  # 0.25 conf thresh
                         color = colors(classes[j])
-                        if scale < 1:
+                        mh, mw = image_masks[j].shape
+                        if mh != h or mw != w:
                             mask = image_masks[j].astype(np.uint8)
                             mask = cv2.resize(mask, (w, h))
                             mask = mask.astype(np.bool)
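Note on the utils/segment/plots.py diff above: the boolean per-image index is renamed from j to idx so it is no longer shadowed by the box-loop variable j, and a mask is now resized only when its own shape differs from the target (h, w) rather than whenever the plot scale is below 1. A small sketch of that resize guard; the 160x160 mask and 640x640 target sizes are assumptions.

    import cv2
    import numpy as np

    h, w = 640, 640                                              # assumed target cell size in the plot
    image_mask = (np.random.rand(160, 160) > 0.5).astype(float)  # assumed downsampled instance mask

    mh, mw = image_mask.shape
    if mh != h or mw != w:                                       # resize only when shapes differ
        mask = cv2.resize(image_mask.astype(np.uint8), (w, h))   # cv2.resize expects (width, height)
        mask = mask.astype(bool)
    else:
        mask = image_mask.astype(bool)
    print(mask.shape)  # (640, 640)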
