Skip to content

Commit

Permalink
process_mask_native() cleanup (#10366)
Browse files Browse the repository at this point in the history
* process_mask_native() cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix arg name

* cleanup anno_json

* Remove scale_image

* Remove scale_image

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update to native

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] committed Dec 3, 2022
1 parent a1b6e79 commit 9722e6f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 25 deletions.
17 changes: 9 additions & 8 deletions segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
strip_optimizer, xyxy2xywh)
strip_optimizer)
from utils.plots import Annotator, colors, save_one_box
from utils.segment.general import masks2segments, process_mask, process_mask_native
from utils.torch_utils import select_device, smart_inference_mode
Expand Down Expand Up @@ -161,26 +161,27 @@ def run(

# Segments
if save_txt:
segments = reversed(masks2segments(masks))
segments = [
scale_segments(im0.shape if retina_masks else im.shape[2:], x, im0.shape, normalize=True)
for x in segments]
for x in reversed(masks2segments(masks))]

# Print results
for c in det[:, 5].unique():
n = (det[:, 5] == c).sum() # detections per class
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string

# Mask plotting
plot_img = torch.as_tensor(im0, dtype=torch.float16).to(device).permute(2, 0, 1).flip(0).contiguous() / 255. \
if retina_masks else im[i]
annotator.masks(masks, colors=[colors(x, True) for x in det[:, 5]], im_gpu=plot_img)
annotator.masks(
masks,
colors=[colors(x, True) for x in det[:, 5]],
im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(device).permute(2, 0, 1).flip(0).contiguous() /
255 if retina_masks else im[i])

# Write results
for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
if save_txt: # Write to file
segj = segments[j].reshape(-1) # (n,2) to (n*2)
line = (cls, *segj, conf) if save_conf else (cls, *segj) # label format
seg = segments[j].reshape(-1) # (n,2) to (n*2)
line = (cls, *seg, conf) if save_conf else (cls, *seg) # label format
with open(f'{txt_path}.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')

Expand Down
10 changes: 5 additions & 5 deletions segment/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from utils.metrics import ConfusionMatrix, box_iou
from utils.plots import output_to_target, plot_val_study
from utils.segment.dataloaders import create_dataloader
from utils.segment.general import mask_iou, process_mask, process_mask_upsample, scale_image
from utils.segment.general import mask_iou, process_mask, process_mask_native, scale_image
from utils.segment.metrics import Metrics, ap_per_class_box_and_mask
from utils.segment.plots import plot_images_and_masks
from utils.torch_utils import de_parallel, select_device, smart_inference_mode
Expand Down Expand Up @@ -160,7 +160,7 @@ def run(
):
if save_json:
check_requirements(['pycocotools'])
process = process_mask_upsample # more accurate
process = process_mask_native # more accurate
else:
process = process_mask # faster

Expand Down Expand Up @@ -312,7 +312,7 @@ def run(

pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if plots and batch_i < 3:
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
plot_masks.append(pred_masks[:15]) # filter top 15 to plot

# Save/log
if save_txt:
Expand Down Expand Up @@ -367,8 +367,8 @@ def run(
# Save JSON
if save_json and len(jdict):
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json
pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
anno_json = str(Path('../datasets/coco/annotations/instances_val2017.json')) # annotations
pred_json = str(save_dir / f"{w}_predictions.json") # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
with open(pred_json, 'w') as f:
json.dump(jdict, f)
Expand Down
20 changes: 10 additions & 10 deletions utils/segment/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def crop_mask(masks, boxes):
def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
Crop after upsample.
proto_out: [mask_dim, mask_h, mask_w]
out_masks: [n, mask_dim], n is number of masks after nms
protos: [mask_dim, mask_h, mask_w]
masks_in: [n, mask_dim], n is number of masks after nms
bboxes: [n, 4], n is number of masks after nms
shape:input_image_size, (h, w)
shape: input_image_size, (h, w)
return: h, w, n
"""
Expand Down Expand Up @@ -67,25 +67,25 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
return masks.gt_(0.5)


def process_mask_native(protos, masks_in, bboxes, dst_shape):
def process_mask_native(protos, masks_in, bboxes, shape):
"""
Crop after upsample.
proto_out: [mask_dim, mask_h, mask_w]
out_masks: [n, mask_dim], n is number of masks after nms
protos: [mask_dim, mask_h, mask_w]
masks_in: [n, mask_dim], n is number of masks after nms
bboxes: [n, 4], n is number of masks after nms
shape:input_image_size, (h, w)
shape: input_image_size, (h, w)
return: h, w, n
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
gain = min(mh / dst_shape[0], mw / dst_shape[1]) # gain = old / new
pad = (mw - dst_shape[1] * gain) / 2, (mh - dst_shape[0] * gain) / 2 # wh padding
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding
top, left = int(pad[1]), int(pad[0]) # y, x
bottom, right = int(mh - pad[1]), int(mw - pad[0])
masks = masks[:, top:bottom, left:right]

masks = F.interpolate(masks[None], dst_shape, mode='bilinear', align_corners=False)[0] # CHW
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)

Expand Down
4 changes: 2 additions & 2 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def run(
# Save JSON
if save_json and len(jdict):
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json
pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
anno_json = str(Path('../datasets/coco/annotations/instances_val2017.json')) # annotations
pred_json = str(save_dir / f"{w}_predictions.json") # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
with open(pred_json, 'w') as f:
json.dump(jdict, f)
Expand Down

0 comments on commit 9722e6f

Please sign in to comment.