Add segment line predictions #9571

Merged · 9 commits · Sep 25, 2022
20 changes: 12 additions & 8 deletions segment/predict.py
@@ -42,9 +42,10 @@
 from models.common import DetectMultiBackend
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
 from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
-                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
+                           increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
+                           strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.segment.general import process_mask
+from utils.segment.general import masks2segments, process_mask
 from utils.torch_utils import select_device, smart_inference_mode


@@ -145,14 +146,16 @@ def run(
             save_path = str(save_dir / p.name)  # im.jpg
             txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
             s += '%gx%g ' % im.shape[2:]  # print string
-            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
             imc = im0.copy() if save_crop else im0  # for save_crop
             annotator = Annotator(im0, line_width=line_thickness, example=str(names))
             if len(det):
                 masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)  # HWC
+                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()  # rescale boxes to im0 size

-                # Rescale boxes from img_size to im0 size
-                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
+                # Segments
+                if save_txt:
+                    segments = reversed(masks2segments(masks))
+                    segments = [scale_segments(im.shape[2:], x, im0.shape).round() for x in segments]

                 # Print results
                 for c in det[:, 5].unique():
@@ -165,17 +168,18 @@ def run(
                                 im_gpu=None if retina_masks else im[i])

                 # Write results
-                for *xyxy, conf, cls in reversed(det[:, :6]):
+                for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
                     if save_txt:  # Write to file
-                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
-                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
+                        segj = segments[j].reshape(-1)  # (n,2) to (n*2)
+                        line = (cls, *segj, conf) if save_conf else (cls, *segj)  # label format
                         with open(f'{txt_path}.txt', 'a') as f:
                             f.write(('%g ' * len(line)).rstrip() % line + '\n')

                     if save_img or save_crop or view_img:  # Add bbox to image
                         c = int(cls)  # integer class
                         label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                         annotator.box_label(xyxy, label, color=colors(c, True))
+                        annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
                     if save_crop:
                         save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

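For context, after this change --save-txt writes one line per detection consisting of the class index followed by the flattened x y pairs of the mask contour (plus a trailing confidence when --save-conf is set), with coordinates scaled to the original image by scale_segments(...).round() instead of the previous normalized xywh box. Below is a minimal sketch of reading such a line back; the helper name parse_segment_label is hypothetical and not part of the repository:

import numpy as np

def parse_segment_label(line, save_conf=False):
    # One label line: 'class x1 y1 x2 y2 ...' with an optional trailing confidence
    values = [float(v) for v in line.split()]
    cls = int(values[0])
    conf = values.pop() if save_conf else None  # trailing confidence, present only with --save-conf
    polygon = np.array(values[1:], dtype=np.float32).reshape(-1, 2)  # (n, 2) xy points in im0 pixels
    return cls, polygon, conf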
14 changes: 14 additions & 0 deletions utils/segment/general.py
@@ -1,4 +1,5 @@
+import cv2
 import numpy as np
 import torch
 import torch.nn.functional as F

@@ -118,3 +119,16 @@ def masks_iou(mask1, mask2, eps=1e-7):
     intersection = (mask1 * mask2).sum(1).clamp(0)  # (N, )
     union = (mask1.sum(1) + mask2.sum(1))[None] - intersection  # (area1 + area2) - intersection
     return intersection / (union + eps)
+
+
+def masks2segments(masks, strategy='largest'):
+    # Convert masks(n,160,160) into segments(n,xy)
+    segments = []
+    for x in masks.int().numpy().astype('uint8'):
+        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+        if strategy == 'concat':  # concatenate all segments
+            c = np.concatenate([x.reshape(-1, 2) for x in c])
+        elif strategy == 'largest':  # select largest segment
+            c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
+        segments.append(c.astype('float32'))
+    return segments
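The new masks2segments helper runs cv2.findContours on each binary mask and, depending on the strategy, keeps either the single largest contour or a concatenation of all contours as an (n, 2) array of xy points. A minimal sketch of calling it on a toy mask (the masks must be on the CPU here, since the helper calls .numpy() directly):

import torch

from utils.segment.general import masks2segments

# One 160x160 mask containing a filled axis-aligned rectangle
masks = torch.zeros(1, 160, 160)
masks[0, 50:90, 30:90] = 1.0

segments = masks2segments(masks, strategy='largest')  # list of (n, 2) float32 arrays
print(segments[0].shape)  # e.g. (4, 2): the corner points of the rectangle's outer contour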