From db6ad831dc1a6189a5c9bbfea7293ea4d5d9c3f1 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sat, 24 Sep 2022 17:47:19 +0200
Subject: [PATCH 1/5] Add segment line predictions

Signed-off-by: Glenn Jocher
---
 segment/predict.py | 78 ++++++++++++++++++++++++++--------------------
 1 file changed, 44 insertions(+), 34 deletions(-)

diff --git a/segment/predict.py b/segment/predict.py
index 2241204715b5..ff20cdc52141 100644
--- a/segment/predict.py
+++ b/segment/predict.py
@@ -32,6 +32,7 @@
 from pathlib import Path

 import torch
+import numpy as np

 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]  # YOLOv5 root directory
@@ -42,7 +43,8 @@
 from models.common import DetectMultiBackend
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
 from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
-                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
+                           increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
+                           strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
 from utils.segment.general import process_mask
 from utils.torch_utils import select_device, smart_inference_mode
@@ -50,34 +52,34 @@

 @smart_inference_mode()
 def run(
-        weights=ROOT / 'yolov5s-seg.pt',  # model.pt path(s)
-        source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
-        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
-        imgsz=(640, 640),  # inference size (height, width)
-        conf_thres=0.25,  # confidence threshold
-        iou_thres=0.45,  # NMS IOU threshold
-        max_det=1000,  # maximum detections per image
-        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
-        view_img=False,  # show results
-        save_txt=False,  # save results to *.txt
-        save_conf=False,  # save confidences in --save-txt labels
-        save_crop=False,  # save cropped prediction boxes
-        nosave=False,  # do not save images/videos
-        classes=None,  # filter by class: --class 0, or --class 0 2 3
-        agnostic_nms=False,  # class-agnostic NMS
-        augment=False,  # augmented inference
-        visualize=False,  # visualize features
-        update=False,  # update all models
-        project=ROOT / 'runs/predict-seg',  # save results to project/name
-        name='exp',  # save results to project/name
-        exist_ok=False,  # existing project/name ok, do not increment
-        line_thickness=3,  # bounding box thickness (pixels)
-        hide_labels=False,  # hide labels
-        hide_conf=False,  # hide confidences
-        half=False,  # use FP16 half-precision inference
-        dnn=False,  # use OpenCV DNN for ONNX inference
-        vid_stride=1,  # video frame-rate stride
-        retina_masks=False,
+    weights=ROOT / 'yolov5s-seg.pt',  # model.pt path(s)
+    source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
+    data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
+    imgsz=(640, 640),  # inference size (height, width)
+    conf_thres=0.25,  # confidence threshold
+    iou_thres=0.45,  # NMS IOU threshold
+    max_det=1000,  # maximum detections per image
+    device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+    view_img=False,  # show results
+    save_txt=False,  # save results to *.txt
+    save_conf=False,  # save confidences in --save-txt labels
+    save_crop=False,  # save cropped prediction boxes
+    nosave=False,  # do not save images/videos
+    classes=None,  # filter by class: --class 0, or --class 0 2 3
+    agnostic_nms=False,  # class-agnostic NMS
+    augment=False,  # augmented inference
+    visualize=False,  # visualize features
+    update=False,  # update all models
+    project=ROOT / 'runs/predict-seg',  # save results to project/name
+    name='exp',  # save results to project/name
+    exist_ok=False,  # existing project/name ok, do not increment
+    line_thickness=3,  # bounding box thickness (pixels)
+    hide_labels=False,  # hide labels
+    hide_conf=False,  # hide confidences
+    half=False,  # use FP16 half-precision inference
+    dnn=False,  # use OpenCV DNN for ONNX inference
+    vid_stride=1,  # video frame-rate stride
+    retina_masks=False,
 ):
     source = str(source)
     save_img = not nosave and not source.endswith('.txt')  # save inference images
@@ -147,10 +149,16 @@ def run(
         s += '%gx%g ' % im.shape[2:]  # print string
         gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
         imc = im0.copy() if save_crop else im0  # for save_crop
-        annotator = Annotator(im0, line_width=line_thickness, example=str(names))
+        annotator = Annotator(im0, line_width=line_thickness, example=str(names), pil=True)
         if len(det):
             masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)  # HWC
+            seg = []
+            for x in masks.int().numpy().astype('uint8'):
+                c, hier = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+                c = np.concatenate([x.reshape(-1, 2) for x in c]).astype('float32')
+                seg.append(scale_segments(im.shape[2:], c, im0.shape).round())
+

             # Rescale boxes from img_size to im0 size
             det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

@@ -160,12 +168,13 @@ def run(
                 s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

             # Mask plotting
-            annotator.masks(masks,
-                            colors=[colors(x, True) for x in det[:, 5]],
-                            im_gpu=None if retina_masks else im[i])
+            # annotator.masks(masks,
+            #                 colors=[colors(x, True) for x in det[:, 5]],
+            #                 im_gpu=None if retina_masks else im[i])

             # Write results
-            for *xyxy, conf, cls in reversed(det[:, :6]):
+            det, seg = reversed(det), list(reversed(seg))
+            for j, (*xyxy, conf, cls) in enumerate(det[:, :6]):
                 if save_txt:  # Write to file
                     xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                     line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
@@ -176,6 +185,7 @@ def run(
                     c = int(cls)  # integer class
                     label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                     annotator.box_label(xyxy, label, color=colors(c, True))
+                    annotator.draw.polygon(seg[j], outline=colors(c, True), width=2)
                 if save_crop:
                     save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
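
The mask-to-polygon step this patch introduces can be exercised on its own. Below is a minimal standalone sketch of the same idea — the synthetic mask stands in for process_mask() output, and the inline gain math stands in for scale_segments() (letterbox pad handling omitted) — so treat it as an illustration, not the patch's exact pipeline:

# Sketch only: synthetic mask, simplified rescaling.
import cv2
import numpy as np

mask = np.zeros((160, 160), dtype='uint8')
mask[40:120, 30:100] = 1  # fake binary instance mask

# OpenCV 4.x returns (contours, hierarchy); each contour is an (n,1,2) int array
contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
seg = np.concatenate([c.reshape(-1, 2) for c in contours]).astype('float32')

# Map model-input coordinates back to original-image coordinates by the
# letterbox gain, which is the core of what scale_segments() does
model_hw, im0_hw = (160, 160), (480, 640)  # (height, width), illustrative values
gain = min(model_hw[0] / im0_hw[0], model_hw[1] / im0_hw[1])
seg /= gain
print(seg.shape)  # (n_points, 2) polygon, ready for Annotator.draw.polygon()

Note the switch to Annotator(..., pil=True) in the hunk above: annotator.draw is the underlying PIL ImageDraw object, which is what draw.polygon() is called on.
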
From 6decf68f719afa1516cd9dc40031f387a3854f75 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 24 Sep 2022 15:51:41 +0000
Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 segment/predict.py | 58 +++++++++++++++++++++++-----------------------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/segment/predict.py b/segment/predict.py
index ff20cdc52141..dfa4759b2717 100644
--- a/segment/predict.py
+++ b/segment/predict.py
@@ -31,8 +31,8 @@
 import sys
 from pathlib import Path

-import torch
 import numpy as np
+import torch

 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]  # YOLOv5 root directory
@@ -52,34 +52,34 @@

 @smart_inference_mode()
 def run(
-    weights=ROOT / 'yolov5s-seg.pt',  # model.pt path(s)
-    source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
-    data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
-    imgsz=(640, 640),  # inference size (height, width)
-    conf_thres=0.25,  # confidence threshold
-    iou_thres=0.45,  # NMS IOU threshold
-    max_det=1000,  # maximum detections per image
-    device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
-    view_img=False,  # show results
-    save_txt=False,  # save results to *.txt
-    save_conf=False,  # save confidences in --save-txt labels
-    save_crop=False,  # save cropped prediction boxes
-    nosave=False,  # do not save images/videos
-    classes=None,  # filter by class: --class 0, or --class 0 2 3
-    agnostic_nms=False,  # class-agnostic NMS
-    augment=False,  # augmented inference
-    visualize=False,  # visualize features
-    update=False,  # update all models
-    project=ROOT / 'runs/predict-seg',  # save results to project/name
-    name='exp',  # save results to project/name
-    exist_ok=False,  # existing project/name ok, do not increment
-    line_thickness=3,  # bounding box thickness (pixels)
-    hide_labels=False,  # hide labels
-    hide_conf=False,  # hide confidences
-    half=False,  # use FP16 half-precision inference
-    dnn=False,  # use OpenCV DNN for ONNX inference
-    vid_stride=1,  # video frame-rate stride
-    retina_masks=False,
+        weights=ROOT / 'yolov5s-seg.pt',  # model.pt path(s)
+        source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
+        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
+        imgsz=(640, 640),  # inference size (height, width)
+        conf_thres=0.25,  # confidence threshold
+        iou_thres=0.45,  # NMS IOU threshold
+        max_det=1000,  # maximum detections per image
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        view_img=False,  # show results
+        save_txt=False,  # save results to *.txt
+        save_conf=False,  # save confidences in --save-txt labels
+        save_crop=False,  # save cropped prediction boxes
+        nosave=False,  # do not save images/videos
+        classes=None,  # filter by class: --class 0, or --class 0 2 3
+        agnostic_nms=False,  # class-agnostic NMS
+        augment=False,  # augmented inference
+        visualize=False,  # visualize features
+        update=False,  # update all models
+        project=ROOT / 'runs/predict-seg',  # save results to project/name
+        name='exp',  # save results to project/name
+        exist_ok=False,  # existing project/name ok, do not increment
+        line_thickness=3,  # bounding box thickness (pixels)
+        hide_labels=False,  # hide labels
+        hide_conf=False,  # hide confidences
+        half=False,  # use FP16 half-precision inference
+        dnn=False,  # use OpenCV DNN for ONNX inference
+        vid_stride=1,  # video frame-rate stride
+        retina_masks=False,
 ):
     source = str(source)
     save_img = not nosave and not source.endswith('.txt')  # save inference images

From dd29d78411f2894bc4853a12b9a1e15b43ec7b93 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 25 Sep 2022 14:36:34 +0200
Subject: [PATCH 3/5] Update

---
 segment/predict.py       | 33 ++++++++++++++-------------------
 utils/segment/general.py | 14 ++++++++++++++
 2 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/segment/predict.py b/segment/predict.py
index dfa4759b2717..4085ed4438fd 100644
--- a/segment/predict.py
+++ b/segment/predict.py
@@ -31,7 +31,6 @@
 import sys
 from pathlib import Path

-import numpy as np
 import torch

 FILE = Path(__file__).resolve()
@@ -46,7 +45,7 @@
                            increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
                            strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.segment.general import process_mask
+from utils.segment.general import process_mask, masks2segments
 from utils.torch_utils import select_device, smart_inference_mode


@@ -149,18 +148,15 @@ def run(
         s += '%gx%g ' % im.shape[2:]  # print string
         gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
         imc = im0.copy() if save_crop else im0  # for save_crop
-        annotator = Annotator(im0, line_width=line_thickness, example=str(names), pil=True)
+        annotator = Annotator(im0, line_width=line_thickness, example=str(names))
         if len(det):
             masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)  # HWC
+            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()  # rescale boxes to im0 size

-            seg = []
-            for x in masks.int().numpy().astype('uint8'):
-                c, hier = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-                c = np.concatenate([x.reshape(-1, 2) for x in c]).astype('float32')
-                seg.append(scale_segments(im.shape[2:], c, im0.shape).round())
-
-            # Rescale boxes from img_size to im0 size
-            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
+            # Segments
+            if save_txt:
+                segments = reversed(masks2segments(masks))
+                segments = [scale_segments(im.shape[2:], x, im0.shape).round() for x in segments]

             # Print results
             for c in det[:, 5].unique():
@@ -168,16 +164,15 @@ def run(
                 s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

             # Mask plotting
-            # annotator.masks(masks,
-            #                 colors=[colors(x, True) for x in det[:, 5]],
-            #                 im_gpu=None if retina_masks else im[i])
+            annotator.masks(masks,
+                            colors=[colors(x, True) for x in det[:, 5]],
+                            im_gpu=None if retina_masks else im[i])

             # Write results
-            det, seg = reversed(det), list(reversed(seg))
-            for j, (*xyxy, conf, cls) in enumerate(det[:, :6]):
+            for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
                 if save_txt:  # Write to file
-                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
-                    line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
+                    segj = segments[j].reshape(-1)  # (n,2) to (n*2)
+                    line = (cls, *segj, conf) if save_conf else (cls, *segj)  # label format
                     with open(f'{txt_path}.txt', 'a') as f:
                         f.write(('%g ' * len(line)).rstrip() % line + '\n')
@@ -185,7 +180,7 @@ def run(
                     c = int(cls)  # integer class
                     label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                     annotator.box_label(xyxy, label, color=colors(c, True))
-                    annotator.draw.polygon(seg[j], outline=colors(c, True), width=2)
+                    annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
                 if save_crop:
                     save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
diff --git a/utils/segment/general.py b/utils/segment/general.py
index 36547ed0889c..4850bfea410e 100644
--- a/utils/segment/general.py
+++ b/utils/segment/general.py
@@ -1,6 +1,7 @@
 import cv2
 import torch
 import torch.nn.functional as F
+import numpy as np


 def crop_mask(masks, boxes):
@@ -118,3 +119,16 @@ def masks_iou(mask1, mask2, eps=1e-7):
     intersection = (mask1 * mask2).sum(1).clamp(0)  # (N, )
     union = (mask1.sum(1) + mask2.sum(1))[None] - intersection  # (area1 + area2) - intersection
     return intersection / (union + eps)
+
+
+def masks2segments(masks, strategy='largest'):
+    # Convert masks(n,160,160) into segments(n,xy)
+    segments = []
+    for x in masks.int().numpy().astype('uint8'):
+        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+        if strategy == 'concat':  # concatenate all segments
+            c = np.concatenate([x.reshape(-1, 2) for x in c])
+        elif strategy == 'largest':  # select largest segment
+            c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
+        segments.append(c.astype('float32'))
+    return segments
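
The masks2segments() helper added in this patch is self-contained and easy to sanity-check in isolation. A small sketch with a fabricated two-mask batch (the function body is copied verbatim from the hunk above; only the dummy masks are assumptions):

import cv2
import numpy as np
import torch

def masks2segments(masks, strategy='largest'):  # as added to utils/segment/general.py
    # Convert masks(n,160,160) into segments(n,xy)
    segments = []
    for x in masks.int().numpy().astype('uint8'):
        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]  # OpenCV 4.x return order
        if strategy == 'concat':  # concatenate all segments
            c = np.concatenate([x.reshape(-1, 2) for x in c])
        elif strategy == 'largest':  # select largest segment
            c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
        segments.append(c.astype('float32'))
    return segments

masks = torch.zeros(2, 160, 160)
masks[0, 10:50, 10:50] = 1  # one blob -> one contour
masks[1, 20:40, 20:40] = 1  # two disconnected blobs: 'largest' keeps the bigger one,
masks[1, 90:140, 60:120] = 1  # while 'concat' merges both point lists
print([s.shape for s in masks2segments(masks, 'largest')])  # [(4, 2), (4, 2)]
print([s.shape for s in masks2segments(masks, 'concat')])   # [(4, 2), (8, 2)]

The 'largest' default trades completeness for clean outlines: fragmented masks lose their smaller pieces, but the polygon never jumps between disconnected regions. In predict.py, reversed(masks2segments(masks)) matches the reversed(det[:, :6]) order of the write loop, so segments[j] lines up with detection j.
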
From a6290f4fed7634bcb06c1f5660e1cfd828dbeadc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 25 Sep 2022 12:37:01 +0000
Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 segment/predict.py       | 2 +-
 utils/segment/general.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/segment/predict.py b/segment/predict.py
index 4085ed4438fd..4e6e89225576 100644
--- a/segment/predict.py
+++ b/segment/predict.py
@@ -45,7 +45,7 @@
                            increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
                            strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.segment.general import process_mask, masks2segments
+from utils.segment.general import masks2segments, process_mask
 from utils.torch_utils import select_device, smart_inference_mode


diff --git a/utils/segment/general.py b/utils/segment/general.py
index 4850bfea410e..655123bdcfeb 100644
--- a/utils/segment/general.py
+++ b/utils/segment/general.py
@@ -1,7 +1,7 @@
 import cv2
+import numpy as np
 import torch
 import torch.nn.functional as F
-import numpy as np


 def crop_mask(masks, boxes):

From d62b762c9cca6ef3fbf3325525636e981ff60811 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 25 Sep 2022 14:40:19 +0200
Subject: [PATCH 5/5] Update

---
 segment/predict.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/segment/predict.py b/segment/predict.py
index 4085ed4438fd..3a9890c4f0bd 100644
--- a/segment/predict.py
+++ b/segment/predict.py
@@ -146,7 +146,6 @@ def run(
         save_path = str(save_dir / p.name)  # im.jpg
         txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
         s += '%gx%g ' % im.shape[2:]  # print string
-        gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
         imc = im0.copy() if save_crop else im0  # for save_crop
         annotator = Annotator(im0, line_width=line_thickness, example=str(names))
         if len(det):
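
Net effect of the series on --save-txt output: each label line is now "cls x1 y1 x2 y2 ..." built from the scaled, rounded segment points in im0 pixel coordinates (plus a trailing confidence with --save-conf), rather than a normalized xywh box — which is why the now-unused gn gain tensor is dropped in this final patch. A minimal reader sketch for the new line format; the parser name and the sample line are illustrative, not part of the patch:

import numpy as np

def parse_segment_label(line, save_conf=False):
    # "cls x1 y1 x2 y2 ..." with an optional trailing confidence
    v = line.split()
    cls = int(float(v[0]))
    conf = float(v[-1]) if save_conf else None
    xy = v[1:-1] if save_conf else v[1:]
    return cls, np.array(xy, dtype=np.float32).reshape(-1, 2), conf  # (n, 2) polygon

cls, points, _ = parse_segment_label('0 104 53 104 208 341 208 341 53')
print(cls, points.shape)  # 0 (4, 2)
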