Skip to content

Commit

Permalink
make flake8 happy
Browse files Browse the repository at this point in the history
  • Loading branch information
liusongtao committed Jul 28, 2021
1 parent da45417 commit 87f31cc
Show file tree
Hide file tree
Showing 38 changed files with 946 additions and 417 deletions.
101 changes: 61 additions & 40 deletions tools/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
import time
from loguru import logger

import cv2
Expand All @@ -16,20 +13,29 @@
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = ['.jpg', '.jpeg', '.webp', '.bmp', '.png']
import argparse
import os
import time

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
parser = argparse.ArgumentParser("YOLOX Demo!")
parser.add_argument('demo', default='image', help='demo type, eg. image, video and webcam')
parser.add_argument(
"demo", default="image", help="demo type, eg. image, video and webcam"
)
parser.add_argument("-expn", "--experiment-name", type=str, default=None)
parser.add_argument("-n", "--name", type=str, default=None, help="model name")

parser.add_argument('--path', default='./assets/dog.jpg', help='path to images or video')
parser.add_argument('--camid', type=int, default=0, help='webcam demo camera id')
parser.add_argument(
'--save_result', action='store_true',
help='whether to save the inference result of image/video'
"--path", default="./assets/dog.jpg", help="path to images or video"
)
parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
parser.add_argument(
"--save_result",
action="store_true",
help="whether to save the inference result of image/video",
)

# exp file
Expand All @@ -41,7 +47,12 @@ def make_parser():
help="pls input your expriment description file",
)
parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
parser.add_argument("--device", default="cpu", type=str, help="device to run our model, can either be cpu or gpu")
parser.add_argument(
"--device",
default="cpu",
type=str,
help="device to run our model, can either be cpu or gpu",
)
parser.add_argument("--conf", default=None, type=float, help="test conf")
parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
parser.add_argument("--tsize", default=None, type=int, help="test img size")
Expand Down Expand Up @@ -81,7 +92,15 @@ def get_image_list(path):


class Predictor(object):
def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None, device="cpu"):
def __init__(
self,
model,
exp,
cls_names=COCO_CLASSES,
trt_file=None,
decoder=None,
device="cpu",
):
self.model = model
self.cls_names = cls_names
self.decoder = decoder
Expand All @@ -92,6 +111,7 @@ def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=No
self.device = device
if trt_file is not None:
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load(trt_file))

Expand All @@ -102,20 +122,20 @@ def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=No
self.std = (0.229, 0.224, 0.225)

def inference(self, img):
img_info = {'id': 0}
img_info = {"id": 0}
if isinstance(img, str):
img_info['file_name'] = os.path.basename(img)
img_info["file_name"] = os.path.basename(img)
img = cv2.imread(img)
else:
img_info['file_name'] = None
img_info["file_name"] = None

height, width = img.shape[:2]
img_info['height'] = height
img_info['width'] = width
img_info['raw_img'] = img
img_info["height"] = height
img_info["width"] = width
img_info["raw_img"] = img

img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
img_info['ratio'] = ratio
img_info["ratio"] = ratio
img = torch.from_numpy(img).unsqueeze(0)
if self.device == "gpu":
img = img.cuda()
Expand All @@ -126,14 +146,14 @@ def inference(self, img):
if self.decoder is not None:
outputs = self.decoder(outputs, dtype=outputs.type())
outputs = postprocess(
outputs, self.num_classes, self.confthre, self.nmsthre
)
logger.info('Infer time: {:.4f}s'.format(time.time()-t0))
outputs, self.num_classes, self.confthre, self.nmsthre
)
logger.info("Infer time: {:.4f}s".format(time.time() - t0))
return outputs, img_info

def visual(self, output, img_info, cls_conf=0.35):
ratio = img_info['ratio']
img = img_info['raw_img']
ratio = img_info["ratio"]
img = img_info["raw_img"]
if output is None:
return img
output = output.cpu()
Expand Down Expand Up @@ -168,24 +188,26 @@ def image_demo(predictor, vis_folder, path, current_time, save_result):
logger.info("Saving detection result in {}".format(save_file_name))
cv2.imwrite(save_file_name, result_image)
ch = cv2.waitKey(0)
if ch == 27 or ch == ord('q') or ch == ord('Q'):
if ch == 27 or ch == ord("q") or ch == ord("Q"):
break


def imageflow_demo(predictor, vis_folder, current_time, args):
cap = cv2.VideoCapture(args.path if args.demo == 'video' else args.camid)
cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # float
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float
fps = cap.get(cv2.CAP_PROP_FPS)
save_folder = os.path.join(vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
save_folder = os.path.join(
vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
)
os.makedirs(save_folder, exist_ok=True)
if args.demo == "video":
save_path = os.path.join(save_folder, args.path.split('/')[-1])
save_path = os.path.join(save_folder, args.path.split("/")[-1])
else:
save_path = os.path.join(save_folder, 'camera.mp4')
logger.info(f'video save_path is {save_path}')
save_path = os.path.join(save_folder, "camera.mp4")
logger.info(f"video save_path is {save_path}")
vid_writer = cv2.VideoWriter(
save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height))
save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
)
while True:
ret_val, frame = cap.read()
Expand All @@ -195,7 +217,7 @@ def imageflow_demo(predictor, vis_folder, current_time, args):
if args.save_result:
vid_writer.write(result_frame)
ch = cv2.waitKey(1)
if ch == 27 or ch == ord('q') or ch == ord('Q'):
if ch == 27 or ch == ord("q") or ch == ord("Q"):
break
else:
break
Expand All @@ -209,11 +231,11 @@ def main(exp, args):
os.makedirs(file_name, exist_ok=True)

if args.save_result:
vis_folder = os.path.join(file_name, 'vis_res')
vis_folder = os.path.join(file_name, "vis_res")
os.makedirs(vis_folder, exist_ok=True)

if args.trt:
args.device="gpu"
args.device = "gpu"

logger.info("Args: {}".format(args))

Expand Down Expand Up @@ -247,12 +269,11 @@ def main(exp, args):
model = fuse_model(model)

if args.trt:
assert (not args.fuse),\
"TensorRT model is not support model fusing!"
assert not args.fuse, "TensorRT model is not support model fusing!"
trt_file = os.path.join(file_name, "model_trt.pth")
assert os.path.exists(trt_file), (
"TensorRT model is not found!\n Run python3 tools/trt.py first!"
)
assert os.path.exists(
trt_file
), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
model.head.decode_in_inference = False
decoder = model.head.decode_outputs
logger.info("Using TensorRT to inference")
Expand All @@ -262,9 +283,9 @@ def main(exp, args):

predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device)
current_time = time.localtime()
if args.demo == 'image':
if args.demo == "image":
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
elif args.demo == 'video' or args.demo == 'webcam':
elif args.demo == "video" or args.demo == "webcam":
imageflow_demo(predictor, vis_folder, current_time, args)


Expand Down
44 changes: 29 additions & 15 deletions tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
import random
import warnings
from loguru import logger

import torch
Expand All @@ -16,6 +12,11 @@
from yolox.exp import get_exp
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger

import argparse
import os
import random
import warnings


def make_parser():
parser = argparse.ArgumentParser("YOLOX Eval")
Expand All @@ -27,7 +28,10 @@ def make_parser():
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--dist-url", default=None, type=str, help="url used to set up distributed training"
"--dist-url",
default=None,
type=str,
help="url used to set up distributed training",
)
parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
parser.add_argument(
Expand Down Expand Up @@ -83,7 +87,11 @@ def make_parser():
help="Evaluating on test-dev set.",
)
parser.add_argument(
"--speed", dest="speed", default=False, action="store_true", help="speed test only."
"--speed",
dest="speed",
default=False,
action="store_true",
help="speed test only.",
)
parser.add_argument(
"opts",
Expand Down Expand Up @@ -114,7 +122,7 @@ def main(exp, args, num_gpu):
cudnn.benchmark = True

rank = args.local_rank
#rank = get_local_rank()
# rank = get_local_rank()

if rank == 0:
if os.path.exists("./" + args.experiment_name + "ip_add.txt"):
Expand All @@ -125,9 +133,7 @@ def main(exp, args, num_gpu):
if rank == 0:
os.makedirs(file_name, exist_ok=True)

setup_logger(
file_name, distributed_rank=rank, filename="val_log.txt", mode="a"
)
setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
logger.info("Args: {}".format(args))

if args.conf is not None:
Expand Down Expand Up @@ -167,10 +173,13 @@ def main(exp, args, num_gpu):
model = fuse_model(model)

if args.trt:
assert (not args.fuse and not is_distributed and args.batch_size == 1),\
"TensorRT model is not support model fusing and distributed inferencing!"
assert (
not args.fuse and not is_distributed and args.batch_size == 1
), "TensorRT model is not support model fusing and distributed inferencing!"
trt_file = os.path.join(file_name, "model_trt.pth")
assert os.path.exists(trt_file), "TensorRT model is not found!\n Run tools/trt.py first!"
assert os.path.exists(
trt_file
), "TensorRT model is not found!\n Run tools/trt.py first!"
model.head.decode_in_inference = False
decoder = model.head.decode_outputs
else:
Expand All @@ -193,6 +202,11 @@ def main(exp, args, num_gpu):
assert num_gpu <= torch.cuda.device_count()

launch(
main, num_gpu, args.num_machine, args.machine_rank, backend=args.dist_backend,
dist_url=args.dist_url, args=(exp, args, num_gpu)
main,
num_gpu,
args.num_machine,
args.machine_rank,
backend=args.dist_backend,
dist_url=args.dist_url,
args=(exp, args, num_gpu),
)
52 changes: 32 additions & 20 deletions tools/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
from loguru import logger

import torch
Expand All @@ -13,28 +11,41 @@
from yolox.models.network_blocks import SiLU
from yolox.utils import replace_module

import argparse
import os


def make_parser():
    """Build the CLI argument parser for YOLOX ONNX export.

    Returns:
        argparse.ArgumentParser: parser with export options (output/input node
        names, opset version, onnxsim toggle, experiment/checkpoint selection).
    """
    parser = argparse.ArgumentParser("YOLOX onnx deploy")
    # NOTE: each option is registered exactly once — the scraped diff showed
    # both old and new formatting of these calls; duplicates would raise
    # argparse.ArgumentError at parser construction time.
    parser.add_argument(
        "--output-name", type=str, default="yolox.onnx", help="output name of models"
    )
    parser.add_argument(
        "--input", default="images", type=str, help="input node name of onnx model"
    )
    parser.add_argument(
        "--output", default="output", type=str, help="output node name of onnx model"
    )
    parser.add_argument(
        "-o", "--opset", default=11, type=int, help="onnx opset version"
    )
    parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="experiment description file",  # typo fix: was "expriment"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    # REMAINDER collects trailing "KEY VALUE" pairs used to override exp config.
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    return parser

Expand Down Expand Up @@ -80,6 +91,7 @@ def main():

if not args.no_onnxsim:
import onnx

from onnxsim import simplify

# use onnxsimplify to reduce reduent model.
Expand Down
Loading

0 comments on commit 87f31cc

Please sign in to comment.