From cb1649309a7bf48f0f06b28e0427987a69fa3730 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Fri, 26 Aug 2022 21:54:31 +0800 Subject: [PATCH] speed up pycocotools ops --- segment/val.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/segment/val.py b/segment/val.py index 16bf9c79d0e7..8fbe13ee06a1 100644 --- a/segment/val.py +++ b/segment/val.py @@ -28,6 +28,7 @@ import numpy as np import torch from tqdm import tqdm +from multiprocessing.pool import ThreadPool FILE = Path(__file__).resolve() ROOT = FILE.parents[1] # YOLOv5 root directory @@ -39,7 +40,6 @@ from models.common import DetectMultiBackend from models.yolo import DetectionModel -from utils import threaded from utils.callbacks import Callbacks from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml, coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args, @@ -51,6 +51,7 @@ from utils.segment.metrics import Metrics, ap_per_class_box_and_mask from utils.segment.plots import plot_images_and_masks from utils.torch_utils import de_parallel, select_device, smart_inference_mode +from utils.general import NUM_THREADS def save_one_txt(predn, save_conf, shape, file): @@ -63,17 +64,20 @@ def save_one_txt(predn, save_conf, shape, file): f.write(('%g ' * len(line)).rstrip() % line + '\n') -@threaded def save_one_json(predn, jdict, path, class_map, pred_masks): # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} from pycocotools.mask import encode + def single_encode(x): + rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] + rle["counts"] = rle["counts"].decode("utf-8") + return rle + image_id = int(path.stem) if path.stem.isnumeric() else path.stem box = xyxy2xywh(predn[:, :4]) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner pred_masks = np.transpose(pred_masks, (2, 0, 1)) - rles = [encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] for x in pred_masks] - for rle in rles: - rle["counts"] = rle["counts"].decode("utf-8") + with ThreadPool(NUM_THREADS) as pool: + rles = pool.map(single_encode, pred_masks) for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): jdict.append({ 'image_id': image_id,