update Annotator.masks() #26

Merged · 5 commits · Sep 5, 2022
13 changes: 7 additions & 6 deletions segment/predict.py
@@ -43,8 +43,7 @@
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
-from utils.segment.general import process_mask, scale_image
-from utils.segment.plots import plot_masks
+from utils.segment.general import process_mask
from utils.torch_utils import select_device, smart_inference_mode


@@ -77,6 +76,7 @@ def run(
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
vid_stride=1, # video frame-rate stride
+        retina_masks=False,
):
source = str(source)
save_img = not nosave and not source.endswith('.txt') # save inference images
@@ -157,8 +157,7 @@ def run(
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string

# Mask plotting
-                im_masks = plot_masks(im[i], masks, colors=[colors(x, True) for x in det[:, 5]])  # shape(imh,imw,3)
-                annotator.im = scale_image(im.shape[2:], im_masks, im0.shape)  # scale to original h, w
+                annotator.masks(masks, colors=[colors(x, True) for x in det[:, 5]], img_gpu=None if retina_masks else im[i])

# Write results
for *xyxy, conf, cls in reversed(det[:, :6]):
@@ -183,7 +182,8 @@ def run(
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
cv2.imshow(str(p), im0)
-                cv2.waitKey(1)  # 1 millisecond
+                if cv2.waitKey(1) == ord('q'):  # 1 millisecond
+                    exit()

# Save results (image with detections)
if save_img:
@@ -205,7 +205,7 @@ def run(
vid_writer[i].write(im0)

# Print time (inference-only)
-            LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
+            # LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

# Print results
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
@@ -246,6 +246,7 @@ def parse_opt():
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
+    parser.add_argument('--retina-masks', action='store_true', help='whether to plot masks in native resolution')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(vars(opt))
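For context, a minimal sketch of how the new call site behaves after this change. It is a simplified excerpt, not the full detection loop; variable names such as im, im0, det, masks and retina_masks follow the hunks above, while the Annotator constructor arguments shown here are assumptions for illustration.

from utils.plots import Annotator, colors

annotator = Annotator(im0, line_width=3)  # im0: original image as a numpy array (assumed constructor args)
if len(det):
    # one colour per predicted class index (det[:, 5]); masks has shape [n, h, w]
    annotator.masks(masks,
                    colors=[colors(x, True) for x in det[:, 5]],
                    img_gpu=None if retina_masks else im[i])  # im[i]: letterboxed GPU image in [0, 1]
im0 = annotator.result()  # blended image, ready for cv2.imshow or saving
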
52 changes: 45 additions & 7 deletions utils/plots.py
@@ -22,6 +22,7 @@
from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
is_ascii, xywh2xyxy, xyxy2xywh)
+from utils.segment.general import scale_image
from utils.metrics import fitness

# Settings
@@ -113,14 +114,51 @@ def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 2
thickness=tf,
lineType=cv2.LINE_AA)

-    def masks(self, masks, colors, alpha=0.5, eps=1e-7):
-        # Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
-        if len(masks):
-            masks = masks.astype(np.float32) / 255.0  # shape(h,w,n)
-            colors = np.array(colors, dtype=np.uint8)  # shape(n,3)
-            s = masks.sum(2, keepdims=True)
-            masks = masks @ colors / (s + eps)  # (h,w,n) @ (n,3) = (h,w,3)
+    def masks(self, masks, colors, img_gpu=None, alpha=0.5):
+        """Plot masks at once.
+        Args:
+            masks (tensor): predicted masks on cuda, shape: [n, h, w]
+            colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
+            img_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
+            retina_masks (bool): whether to plot masks in native resolution.
+        """
+        if self.pil:
+            # convert to numpy first
+            self.im = np.asarray(self.im).copy()
+        if img_gpu is None:
+            # Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
+            if len(masks) == 0:
+                return
+            if isinstance(masks, torch.Tensor):
+                masks = torch.as_tensor(masks, dtype=torch.uint8)
+                masks = masks.permute(1, 2, 0).contiguous()
+                masks = masks.cpu().numpy()
+            # masks = np.ascontiguousarray(masks.transpose(1, 2, 0))
+            masks = scale_image(masks.shape[:2], masks, self.im.shape)
+            masks = np.asarray(masks, dtype=np.float32)
+            colors = np.asarray(colors, dtype=np.float32)  # shape(n,3)
+            s = masks.sum(2, keepdims=True).clip(0, 1)  # add all masks together
+            masks = (masks @ colors).clip(0, 255)  # (h,w,n) @ (n,3) = (h,w,3)
+            self.im[:] = masks * alpha + self.im * (1 - s * alpha)
+        else:
+            if len(masks) == 0:
+                self.im[:] = img_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
+            colors = torch.tensor(colors, device=img_gpu.device, dtype=torch.float32) / 255.0
+            colors = colors[:, None, None]  # shape(n,1,1,3)
+            masks = masks.unsqueeze(3)  # shape(n,h,w,1)
+            masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
+
+            inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
+            mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)
+
+            img_gpu = img_gpu.flip(dims=[0])  # flip channel
+            img_gpu = img_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
+            img_gpu = img_gpu * inv_alph_masks[-1] + mcs
+            im_mask = (img_gpu * 255).byte().cpu().numpy()
+            self.im[:] = scale_image(img_gpu.shape, im_mask, self.im.shape)
+        if self.pil:
+            # convert im back to PIL and update draw
+            self.fromarray(self.im)

def rectangle(self, xy, fill=None, outline=None, width=1):
# Add rectangle to image (PIL-only)
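The img_gpu branch above layers overlapping masks with a cumulative product of inverse alphas, so each mask's colour is attenuated by the masks stacked before it and the background is attenuated by all of them. A standalone toy sketch of that arithmetic (random tensors; the shapes and the * 2 factor mirror the code above, everything else is made up for illustration):

import torch

alpha = 0.5
masks = (torch.rand(3, 4, 4) > 0.5).float().unsqueeze(3)  # 3 toy binary masks, shape (n, h, w, 1)
colors = torch.rand(3, 1, 1, 3)                           # one RGB colour per mask, range [0, 1]
img = torch.rand(3, 4, 4)                                 # toy image, shape (3, h, w), range [0, 1]

masks_color = masks * (colors * alpha)                    # per-mask colour contribution, (n, h, w, 3)
inv_alph_masks = (1 - masks * alpha).cumprod(0)           # transparency remaining after each layer, (n, h, w, 1)
mcs = (masks_color * inv_alph_masks).sum(0) * 2           # combined mask colour term, (h, w, 3)

img = img.flip(dims=[0]).permute(1, 2, 0).contiguous()    # flip channel order, then channels-last (h, w, 3)
blended = img * inv_alph_masks[-1] + mcs                  # background dimmed by all masks, plus mask colours
out = (blended * 255).byte().numpy()                      # uint8 result, mirroring the byte() conversion above
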
27 changes: 0 additions & 27 deletions utils/segment/plots.py
@@ -13,33 +13,6 @@
from ..plots import Annotator, colors


-def plot_masks(im, masks, colors, alpha=0.5):
-    """
-    Args:
-        im (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
-        masks (tensor): predicted masks on cuda, shape: [n, h, w]
-        colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
-    Return:
-        ndarray: img after draw masks, shape: [h, w, 3]
-
-    """
-    if len(masks) == 0:
-        return im.permute(1, 2, 0).contiguous().cpu().numpy() * 255
-
-    colors = torch.tensor(colors, device=im.device).float() / 255.0
-    colors = colors[:, None, None]  # shape(n,1,1,3)
-    masks = masks.unsqueeze(3)  # shape(n,h,w,1)
-    masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
-
-    inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
-    mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)
-
-    im = im.flip(dims=[0])  # flip channel
-    im = im.permute(1, 2, 0).contiguous()  # shape(h,w,3)
-    im = im * inv_alph_masks[-1] + mcs
-    return (im * 255).byte().cpu().numpy()
-
-
@threaded
def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg', names=None):
# Plot image grid with labels
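For completeness, the img_gpu=None branch of the new Annotator.masks() (the numpy path, used when --retina-masks is set) blends all masks with a single matrix product. A toy numpy-only sketch of the same arithmetic, with the scale_image step skipped by choosing matching shapes; all values here are made up for illustration:

import numpy as np

h, w, n, alpha = 4, 4, 2, 0.5
masks = (np.random.rand(h, w, n) > 0.5).astype(np.float32)       # binary masks, shape (h, w, n)
colors = np.asarray([[255, 56, 56], [56, 255, 56]], np.float32)  # one RGB colour per mask, shape (n, 3)
im = np.random.randint(0, 256, (h, w, 3)).astype(np.float32)     # toy image, shape (h, w, 3)

s = masks.sum(2, keepdims=True).clip(0, 1)      # union of all masks, shape (h, w, 1)
overlay = (masks @ colors).clip(0, 255)         # (h, w, n) @ (n, 3) -> (h, w, 3) colour image
im[:] = overlay * alpha + im * (1 - s * alpha)  # blend only where a mask is present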