Merge pull request #26 from Laughing-q/instance_seg
update Annotator.masks()
AyushExel committed Sep 5, 2022
2 parents 29c03da + d53c825 commit 7cdda21
Showing 3 changed files with 52 additions and 40 deletions.
13 changes: 7 additions & 6 deletions segment/predict.py
@@ -43,8 +43,7 @@
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
-from utils.segment.general import process_mask, scale_image
-from utils.segment.plots import plot_masks
+from utils.segment.general import process_mask
from utils.torch_utils import select_device, smart_inference_mode


@@ -77,6 +76,7 @@ def run(
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
vid_stride=1, # video frame-rate stride
+retina_masks=False,
):
source = str(source)
save_img = not nosave and not source.endswith('.txt') # save inference images
@@ -157,8 +157,7 @@ def run(
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string

# Mask plotting
-im_masks = plot_masks(im[i], masks, colors=[colors(x, True) for x in det[:, 5]])  # shape(imh,imw,3)
-annotator.im = scale_image(im.shape[2:], im_masks, im0.shape)  # scale to original h, w
+annotator.masks(masks, colors=[colors(x, True) for x in det[:, 5]], img_gpu=None if retina_masks else im[i])

# Write results
for *xyxy, conf, cls in reversed(det[:, :6]):
@@ -183,7 +182,8 @@ def run(
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
cv2.imshow(str(p), im0)
-cv2.waitKey(1)  # 1 millisecond
+if cv2.waitKey(1) == ord('q'):  # 1 millisecond
+    exit()

# Save results (image with detections)
if save_img:
@@ -205,7 +205,7 @@ def run(
vid_writer[i].write(im0)

# Print time (inference-only)
-LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
+# LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

# Print results
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
@@ -246,6 +246,7 @@ def parse_opt():
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
+parser.add_argument('--retina-masks', action='store_true', help='whether to plot masks in native resolution')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(vars(opt))
52 changes: 45 additions & 7 deletions utils/plots.py
@@ -22,6 +22,7 @@
from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
is_ascii, xywh2xyxy, xyxy2xywh)
+from utils.segment.general import scale_image
from utils.metrics import fitness

# Settings
@@ -113,14 +114,51 @@ def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
thickness=tf,
lineType=cv2.LINE_AA)

-    def masks(self, masks, colors, alpha=0.5, eps=1e-7):
-        # Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
-        if len(masks):
-            masks = masks.astype(np.float32) / 255.0  # shape(h,w,n)
-            colors = np.array(colors, dtype=np.uint8)  # shape(n,3)
-            s = masks.sum(2, keepdims=True)
-            masks = masks @ colors / (s + eps)  # (h,w,n) @ (n,3) = (h,w,3)
+    def masks(self, masks, colors, img_gpu=None, alpha=0.5):
+        """Plot masks at once.
+        Args:
+            masks (tensor): predicted masks on cuda, shape: [n, h, w]
+            colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
+            img_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
+            retina_masks (bool): whether to plot masks in native resolution.
+        """
+        if self.pil:
+            # convert to numpy first
+            self.im = np.asarray(self.im).copy()
+        if img_gpu is None:
+            # Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
+            if len(masks) == 0:
+                return
+            if isinstance(masks, torch.Tensor):
+                masks = torch.as_tensor(masks, dtype=torch.uint8)
+                masks = masks.permute(1, 2, 0).contiguous()
+                masks = masks.cpu().numpy()
+            # masks = np.ascontiguousarray(masks.transpose(1, 2, 0))
+            masks = scale_image(masks.shape[:2], masks, self.im.shape)
+            masks = np.asarray(masks, dtype=np.float32)
+            colors = np.asarray(colors, dtype=np.float32)  # shape(n,3)
+            s = masks.sum(2, keepdims=True).clip(0, 1)  # add all masks together
+            masks = (masks @ colors).clip(0, 255)  # (h,w,n) @ (n,3) = (h,w,3)
+            self.im[:] = masks * alpha + self.im * (1 - s * alpha)
+        else:
+            if len(masks) == 0:
+                self.im[:] = img_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
+            colors = torch.tensor(colors, device=img_gpu.device, dtype=torch.float32) / 255.0
+            colors = colors[:, None, None]  # shape(n,1,1,3)
+            masks = masks.unsqueeze(3)  # shape(n,h,w,1)
+            masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
+
+            inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
+            mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)
+
+            img_gpu = img_gpu.flip(dims=[0])  # flip channel
+            img_gpu = img_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
+            img_gpu = img_gpu * inv_alph_masks[-1] + mcs
+            im_mask = (img_gpu * 255).byte().cpu().numpy()
+            self.im[:] = scale_image(img_gpu.shape, im_mask, self.im.shape)
+        if self.pil:
+            # convert im back to PIL and update draw
+            self.fromarray(self.im)

def rectangle(self, xy, fill=None, outline=None, width=1):
# Add rectangle to image (PIL-only)
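
A minimal usage sketch of the new method's CPU path, which is what segment/predict.py takes when --retina-masks is set (img_gpu=None). The frame size, mask shapes and values below are made up for illustration, and it assumes a standard YOLOv5 checkout where utils.plots.Annotator, utils.plots.colors and Annotator.result() are importable:

import numpy as np
import torch

from utils.plots import Annotator, colors

im0 = np.zeros((480, 640, 3), dtype=np.uint8)  # original BGR frame, shape (h, w, 3)
annotator = Annotator(im0, line_width=2)

# Two dummy binary masks at a letterboxed inference resolution of 384x640
masks = torch.zeros((2, 384, 640), dtype=torch.uint8)
masks[0, 50:150, 100:200] = 1
masks[1, 200:300, 300:450] = 1

# img_gpu=None takes the numpy branch: masks are rescaled to im0's shape via
# scale_image() and alpha-blended into self.im in place
annotator.masks(masks, colors=[colors(c, True) for c in (0, 1)], img_gpu=None)
out = annotator.result()  # annotated frame as a numpy array

Blending at the original resolution keeps mask edges crisp, at the cost of upsampling the masks to the full frame size before blending.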
27 changes: 0 additions & 27 deletions utils/segment/plots.py
@@ -13,33 +13,6 @@
from ..plots import Annotator, colors


-def plot_masks(im, masks, colors, alpha=0.5):
-    """
-    Args:
-        im (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
-        masks (tensor): predicted masks on cuda, shape: [n, h, w]
-        colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
-    Return:
-        ndarray: img after draw masks, shape: [h, w, 3]
-    """
-    if len(masks) == 0:
-        return im.permute(1, 2, 0).contiguous().cpu().numpy() * 255
-
-    colors = torch.tensor(colors, device=im.device).float() / 255.0
-    colors = colors[:, None, None]  # shape(n,1,1,3)
-    masks = masks.unsqueeze(3)  # shape(n,h,w,1)
-    masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
-
-    inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
-    mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)
-
-    im = im.flip(dims=[0])  # flip channel
-    im = im.permute(1, 2, 0).contiguous()  # shape(h,w,3)
-    im = im * inv_alph_masks[-1] + mcs
-    return (im * 255).byte().cpu().numpy()


@threaded
def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg', names=None):
# Plot image grid with labels
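
The removed helper is not lost functionality: the same GPU blending now lives in the else branch of Annotator.masks() and runs whenever an image tensor is passed as img_gpu, which is segment/predict.py's default path (no --retina-masks). A rough, self-contained sketch with dummy tensors; it also runs on CPU, since the device is taken from img_gpu:

import numpy as np
import torch

from utils.plots import Annotator, colors

im0 = np.zeros((480, 640, 3), dtype=np.uint8)  # original frame to draw on
annotator = Annotator(im0)

img = torch.rand(3, 384, 640)  # letterboxed model input, CHW, range [0, 1]
masks = torch.zeros(2, 384, 640)  # float masks at model resolution
masks[0, 50:150, 100:200] = 1.0
masks[1, 200:300, 300:450] = 1.0

# Roughly what plot_masks(img, masks, colors) followed by scale_image() did:
# blend on img's device at model resolution, then rescale the blended frame onto im0
annotator.masks(masks, colors=[colors(c, True) for c in (0, 1)], img_gpu=img)

Blending at model resolution and rescaling the frame once is cheaper than upsampling the masks first, which is presumably why it is the default.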
