Skip to content

Commit

Permalink
Merge pull request ultralytics#21 from Laughing-q/instance_seg
Browse files Browse the repository at this point in the history
speed up pycocotools ops
  • Loading branch information
AyushExel committed Aug 26, 2022
2 parents 8029d56 + cb16493 commit fee17a6
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions segment/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np
import torch
from tqdm import tqdm
from multiprocessing.pool import ThreadPool

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLOv5 root directory
Expand All @@ -39,7 +40,6 @@

from models.common import DetectMultiBackend
from models.yolo import DetectionModel
from utils import threaded
from utils.callbacks import Callbacks
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
Expand All @@ -51,6 +51,7 @@
from utils.segment.metrics import Metrics, ap_per_class_box_and_mask
from utils.segment.plots import plot_images_and_masks
from utils.torch_utils import de_parallel, select_device, smart_inference_mode
from utils.general import NUM_THREADS


def save_one_txt(predn, save_conf, shape, file):
Expand All @@ -63,17 +64,20 @@ def save_one_txt(predn, save_conf, shape, file):
f.write(('%g ' * len(line)).rstrip() % line + '\n')


@threaded
def save_one_json(predn, jdict, path, class_map, pred_masks):
# Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
from pycocotools.mask import encode
def single_encode(x):
rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle

image_id = int(path.stem) if path.stem.isnumeric() else path.stem
box = xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
pred_masks = np.transpose(pred_masks, (2, 0, 1))
rles = [encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] for x in pred_masks]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
jdict.append({
'image_id': image_id,
Expand Down

0 comments on commit fee17a6

Please sign in to comment.