Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix(core): fix memory leak issue and switch to subprocess backend #216

Merged
merged 8 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ For more details, please refer to our [report on Arxiv](https://arxiv.org/abs/21
<img src="assets/git_fig.png" width="1000" >

## Updates!!
* 【2021/07/28】 We fix the fatal error of [memory leak](https://github.com/Megvii-BaseDetection/YOLOX/issues/103)
* 【2021/07/26】 We now support [MegEngine](https://github.com/Megvii-BaseDetection/YOLOX/tree/main/demo/MegEngine) deployment.
* 【2021/07/20】 We have released our technical report on [Arxiv](https://arxiv.org/abs/2107.08430).

Expand Down
101 changes: 61 additions & 40 deletions tools/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
import time
from loguru import logger

import cv2
Expand All @@ -16,20 +13,29 @@
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = ['.jpg', '.jpeg', '.webp', '.bmp', '.png']
import argparse
import os
import time

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
parser = argparse.ArgumentParser("YOLOX Demo!")
parser.add_argument('demo', default='image', help='demo type, eg. image, video and webcam')
parser.add_argument(
"demo", default="image", help="demo type, eg. image, video and webcam"
)
parser.add_argument("-expn", "--experiment-name", type=str, default=None)
parser.add_argument("-n", "--name", type=str, default=None, help="model name")

parser.add_argument('--path', default='./assets/dog.jpg', help='path to images or video')
parser.add_argument('--camid', type=int, default=0, help='webcam demo camera id')
parser.add_argument(
'--save_result', action='store_true',
help='whether to save the inference result of image/video'
"--path", default="./assets/dog.jpg", help="path to images or video"
)
parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
parser.add_argument(
"--save_result",
action="store_true",
help="whether to save the inference result of image/video",
)

# exp file
Expand All @@ -41,7 +47,12 @@ def make_parser():
help="pls input your expriment description file",
)
parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
parser.add_argument("--device", default="cpu", type=str, help="device to run our model, can either be cpu or gpu")
parser.add_argument(
"--device",
default="cpu",
type=str,
help="device to run our model, can either be cpu or gpu",
)
parser.add_argument("--conf", default=None, type=float, help="test conf")
parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
parser.add_argument("--tsize", default=None, type=int, help="test img size")
Expand Down Expand Up @@ -81,7 +92,15 @@ def get_image_list(path):


class Predictor(object):
def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None, device="cpu"):
def __init__(
self,
model,
exp,
cls_names=COCO_CLASSES,
trt_file=None,
decoder=None,
device="cpu",
):
self.model = model
self.cls_names = cls_names
self.decoder = decoder
Expand All @@ -92,6 +111,7 @@ def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=No
self.device = device
if trt_file is not None:
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load(trt_file))

Expand All @@ -102,20 +122,20 @@ def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=No
self.std = (0.229, 0.224, 0.225)

def inference(self, img):
img_info = {'id': 0}
img_info = {"id": 0}
if isinstance(img, str):
img_info['file_name'] = os.path.basename(img)
img_info["file_name"] = os.path.basename(img)
img = cv2.imread(img)
else:
img_info['file_name'] = None
img_info["file_name"] = None

height, width = img.shape[:2]
img_info['height'] = height
img_info['width'] = width
img_info['raw_img'] = img
img_info["height"] = height
img_info["width"] = width
img_info["raw_img"] = img

img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
img_info['ratio'] = ratio
img_info["ratio"] = ratio
img = torch.from_numpy(img).unsqueeze(0)
if self.device == "gpu":
img = img.cuda()
Expand All @@ -126,14 +146,14 @@ def inference(self, img):
if self.decoder is not None:
outputs = self.decoder(outputs, dtype=outputs.type())
outputs = postprocess(
outputs, self.num_classes, self.confthre, self.nmsthre
)
logger.info('Infer time: {:.4f}s'.format(time.time()-t0))
outputs, self.num_classes, self.confthre, self.nmsthre
)
logger.info("Infer time: {:.4f}s".format(time.time() - t0))
return outputs, img_info

def visual(self, output, img_info, cls_conf=0.35):
ratio = img_info['ratio']
img = img_info['raw_img']
ratio = img_info["ratio"]
img = img_info["raw_img"]
if output is None:
return img
output = output.cpu()
Expand Down Expand Up @@ -168,24 +188,26 @@ def image_demo(predictor, vis_folder, path, current_time, save_result):
logger.info("Saving detection result in {}".format(save_file_name))
cv2.imwrite(save_file_name, result_image)
ch = cv2.waitKey(0)
if ch == 27 or ch == ord('q') or ch == ord('Q'):
if ch == 27 or ch == ord("q") or ch == ord("Q"):
break


def imageflow_demo(predictor, vis_folder, current_time, args):
cap = cv2.VideoCapture(args.path if args.demo == 'video' else args.camid)
cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # float
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float
fps = cap.get(cv2.CAP_PROP_FPS)
save_folder = os.path.join(vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
save_folder = os.path.join(
vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
)
os.makedirs(save_folder, exist_ok=True)
if args.demo == "video":
save_path = os.path.join(save_folder, args.path.split('/')[-1])
save_path = os.path.join(save_folder, args.path.split("/")[-1])
else:
save_path = os.path.join(save_folder, 'camera.mp4')
logger.info(f'video save_path is {save_path}')
save_path = os.path.join(save_folder, "camera.mp4")
logger.info(f"video save_path is {save_path}")
vid_writer = cv2.VideoWriter(
save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height))
save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
)
while True:
ret_val, frame = cap.read()
Expand All @@ -195,7 +217,7 @@ def imageflow_demo(predictor, vis_folder, current_time, args):
if args.save_result:
vid_writer.write(result_frame)
ch = cv2.waitKey(1)
if ch == 27 or ch == ord('q') or ch == ord('Q'):
if ch == 27 or ch == ord("q") or ch == ord("Q"):
break
else:
break
Expand All @@ -209,11 +231,11 @@ def main(exp, args):
os.makedirs(file_name, exist_ok=True)

if args.save_result:
vis_folder = os.path.join(file_name, 'vis_res')
vis_folder = os.path.join(file_name, "vis_res")
os.makedirs(vis_folder, exist_ok=True)

if args.trt:
args.device="gpu"
args.device = "gpu"

logger.info("Args: {}".format(args))

Expand Down Expand Up @@ -247,12 +269,11 @@ def main(exp, args):
model = fuse_model(model)

if args.trt:
assert (not args.fuse),\
"TensorRT model is not support model fusing!"
assert not args.fuse, "TensorRT model is not support model fusing!"
trt_file = os.path.join(file_name, "model_trt.pth")
assert os.path.exists(trt_file), (
"TensorRT model is not found!\n Run python3 tools/trt.py first!"
)
assert os.path.exists(
trt_file
), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
model.head.decode_in_inference = False
decoder = model.head.decode_outputs
logger.info("Using TensorRT to inference")
Expand All @@ -262,9 +283,9 @@ def main(exp, args):

predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device)
current_time = time.localtime()
if args.demo == 'image':
if args.demo == "image":
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
elif args.demo == 'video' or args.demo == 'webcam':
elif args.demo == "video" or args.demo == "webcam":
imageflow_demo(predictor, vis_folder, current_time, args)


Expand Down
49 changes: 31 additions & 18 deletions tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
import random
import warnings
from loguru import logger

import torch
Expand All @@ -16,6 +12,11 @@
from yolox.exp import get_exp
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger

import argparse
import os
import random
import warnings


def make_parser():
parser = argparse.ArgumentParser("YOLOX Eval")
Expand All @@ -27,7 +28,10 @@ def make_parser():
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--dist-url", default=None, type=str, help="url used to set up distributed training"
"--dist-url",
default=None,
type=str,
help="url used to set up distributed training",
)
parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
parser.add_argument(
Expand Down Expand Up @@ -83,7 +87,11 @@ def make_parser():
help="Evaluating on test-dev set.",
)
parser.add_argument(
"--speed", dest="speed", default=False, action="store_true", help="speed test only."
"--speed",
dest="speed",
default=False,
action="store_true",
help="speed test only.",
)
parser.add_argument(
"opts",
Expand All @@ -95,7 +103,7 @@ def make_parser():


@logger.catch
def main(exp, num_gpu, args):
def main(exp, args, num_gpu):
if not args.experiment_name:
args.experiment_name = exp.exp_name

Expand All @@ -113,8 +121,8 @@ def main(exp, num_gpu, args):
configure_nccl()
cudnn.benchmark = True

# rank = args.local_rank
rank = get_local_rank()
rank = args.local_rank
# rank = get_local_rank()

if rank == 0:
if os.path.exists("./" + args.experiment_name + "ip_add.txt"):
Expand All @@ -125,9 +133,7 @@ def main(exp, num_gpu, args):
if rank == 0:
os.makedirs(file_name, exist_ok=True)

setup_logger(
file_name, distributed_rank=rank, filename="val_log.txt", mode="a"
)
setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
logger.info("Args: {}".format(args))

if args.conf is not None:
Expand Down Expand Up @@ -167,10 +173,13 @@ def main(exp, num_gpu, args):
model = fuse_model(model)

if args.trt:
assert (not args.fuse and not is_distributed and args.batch_size == 1),\
"TensorRT model is not support model fusing and distributed inferencing!"
assert (
not args.fuse and not is_distributed and args.batch_size == 1
), "TensorRT model is not support model fusing and distributed inferencing!"
trt_file = os.path.join(file_name, "model_trt.pth")
assert os.path.exists(trt_file), "TensorRT model is not found!\n Run tools/trt.py first!"
assert os.path.exists(
trt_file
), "TensorRT model is not found!\n Run tools/trt.py first!"
model.head.decode_in_inference = False
decoder = model.head.decode_outputs
else:
Expand All @@ -192,8 +201,12 @@ def main(exp, num_gpu, args):
num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
assert num_gpu <= torch.cuda.device_count()

dist_url = "auto" if args.dist_url is None else args.dist_url
launch(
main, num_gpu, args.num_machine, backend=args.dist_backend,
dist_url=dist_url, args=(exp, num_gpu, args)
main,
num_gpu,
args.num_machine,
args.machine_rank,
backend=args.dist_backend,
dist_url=args.dist_url,
args=(exp, args, num_gpu),
)
Loading