Skip to content

Commit

Permalink
New scale_segments() function (#9570)
Browse files Browse the repository at this point in the history
* Rename scale_coords to scale_boxes

* add scale_segments
  • Loading branch information
glenn-jocher committed Sep 24, 2022
1 parent d669a74 commit c8e5230
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 28 deletions.
4 changes: 2 additions & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode

Expand Down Expand Up @@ -148,7 +148,7 @@ def run(
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

# Print results
for c in det[:, 5].unique():
Expand Down
4 changes: 2 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from utils.dataloaders import exif_transpose, letterbox
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh,
increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy, xyxy2xywh,
yaml_load)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, smart_inference_mode
Expand Down Expand Up @@ -703,7 +703,7 @@ def forward(self, ims, size=640, augment=False, profile=False):
self.multi_label,
max_det=self.max_det) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])
scale_boxes(shape1, y[i][:, :4], shape0[i])

return Detections(ims, y, files, dt, self.names, x.shape)

Expand Down
4 changes: 2 additions & 2 deletions segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.segment.general import process_mask
from utils.torch_utils import select_device, smart_inference_mode
Expand Down Expand Up @@ -152,7 +152,7 @@ def run(
masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC

# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

# Print results
for c in det[:, 5].unique():
Expand Down
6 changes: 3 additions & 3 deletions segment/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from utils.callbacks import Callbacks
from utils.general import (LOGGER, NUM_THREADS, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
scale_coords, xywh2xyxy, xyxy2xywh)
scale_boxes, xywh2xyxy, xyxy2xywh)
from utils.metrics import ConfusionMatrix, box_iou
from utils.plots import output_to_target, plot_val_study
from utils.segment.dataloaders import create_dataloader
Expand Down Expand Up @@ -298,12 +298,12 @@ def run(
if single_cls:
pred[:, 5] = 0
predn = pred.clone()
scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred
scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred

# Evaluate
if nl:
tbox = xywh2xyxy(labels[:, 1:5]) # target boxes
scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels
scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
correct_bboxes = process_batch(predn, labelsn, iouv)
correct_masks = process_batch(predn, labelsn, iouv, pred_masks, gt_masks, overlap=overlap, masks=True)
Expand Down
46 changes: 36 additions & 10 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
if clip:
clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
Expand Down Expand Up @@ -769,7 +769,23 @@ def resample_segments(segments, n=1000):
return segments


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
# Rescale boxes (xyxy) from img1_shape to img0_shape
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]

boxes[:, [0, 2]] -= pad[0] # x padding
boxes[:, [1, 3]] -= pad[1] # y padding
boxes[:, :4] /= gain
clip_boxes(boxes, img0_shape)
return boxes


def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None):
# Rescale coords (xyxy) from img1_shape to img0_shape
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
Expand All @@ -778,15 +794,15 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
gain = ratio_pad[0][0]
pad = ratio_pad[1]

coords[:, [0, 2]] -= pad[0] # x padding
coords[:, [1, 3]] -= pad[1] # y padding
coords[:, :4] /= gain
clip_coords(coords, img0_shape)
return coords
segments[:, 0] -= pad[0] # x padding
segments[:, 1] -= pad[1] # y padding
segments /= gain
clip_segments(segments, img0_shape)
return segments


def clip_coords(boxes, shape):
# Clip bounding xyxy bounding boxes to image shape (height, width)
def clip_boxes(boxes, shape):
# Clip boxes (xyxy) to image shape (height, width)
if isinstance(boxes, torch.Tensor): # faster individually
boxes[:, 0].clamp_(0, shape[1]) # x1
boxes[:, 1].clamp_(0, shape[0]) # y1
Expand All @@ -797,6 +813,16 @@ def clip_coords(boxes, shape):
boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2


def clip_segments(boxes, shape):
# Clip segments (xy1,xy2,...) to image shape (height, width)
if isinstance(boxes, torch.Tensor): # faster individually
boxes[:, 0].clamp_(0, shape[1]) # x
boxes[:, 1].clamp_(0, shape[0]) # y
else: # np.array (faster grouped)
boxes[:, 0] = boxes[:, 0].clip(0, shape[1]) # x
boxes[:, 1] = boxes[:, 1].clip(0, shape[0]) # y


def non_max_suppression(
prediction,
conf_thres=0.25,
Expand Down Expand Up @@ -980,7 +1006,7 @@ def apply_classifier(x, model, img, im0):
d[:, :4] = xywh2xyxy(b).long()

# Rescale boxes from img_size to im0 size
scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)

# Classes
pred_cls1 = d[:, 5].long()
Expand Down
8 changes: 4 additions & 4 deletions utils/loggers/comet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import yaml

from utils.dataloaders import img2label_paths
from utils.general import check_dataset, scale_coords, xywh2xyxy
from utils.general import check_dataset, scale_boxes, xywh2xyxy
from utils.metrics import box_iou

COMET_PREFIX = "comet://"
Expand Down Expand Up @@ -293,14 +293,14 @@ def preprocess_prediction(self, image, labels, shape, pred):
pred[:, 5] = 0

predn = pred.clone()
scale_coords(image.shape[1:], predn[:, :4], shape[0], shape[1])
scale_boxes(image.shape[1:], predn[:, :4], shape[0], shape[1])

labelsn = None
if nl:
tbox = xywh2xyxy(labels[:, 1:5]) # target boxes
scale_coords(image.shape[1:], tbox, shape[0], shape[1]) # native-space labels
scale_boxes(image.shape[1:], tbox, shape[0], shape[1]) # native-space labels
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
scale_coords(image.shape[1:], predn[:, :4], shape[0], shape[1]) # native-space pred
scale_boxes(image.shape[1:], predn[:, :4], shape[0], shape[1]) # native-space pred

return predn, labelsn

Expand Down
4 changes: 2 additions & 2 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from PIL import Image, ImageDraw, ImageFont

from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
is_ascii, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness
from utils.segment.general import scale_image
Expand Down Expand Up @@ -565,7 +565,7 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
xyxy = xywh2xyxy(b).long()
clip_coords(xyxy, im.shape)
clip_boxes(xyxy, im.shape)
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
if save:
file.parent.mkdir(parents=True, exist_ok=True) # make directory
Expand Down
6 changes: 3 additions & 3 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from utils.dataloaders import create_dataloader
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
scale_coords, xywh2xyxy, xyxy2xywh)
scale_boxes, xywh2xyxy, xyxy2xywh)
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
from utils.plots import output_to_target, plot_images, plot_val_study
from utils.torch_utils import select_device, smart_inference_mode
Expand Down Expand Up @@ -244,12 +244,12 @@ def run(
if single_cls:
pred[:, 5] = 0
predn = pred.clone()
scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred
scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred

# Evaluate
if nl:
tbox = xywh2xyxy(labels[:, 1:5]) # target boxes
scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels
scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
correct = process_batch(predn, labelsn, iouv)
if plots:
Expand Down

0 comments on commit c8e5230

Please sign in to comment.