Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
LSH9832 committed Dec 13, 2022
1 parent 9a7ca30 commit 2b98bd6
Show file tree
Hide file tree
Showing 8 changed files with 854 additions and 0 deletions.
120 changes: 120 additions & 0 deletions detect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from datetime import datetime as date
from glob import glob
import os
from loguru import logger
import argparse
import cv2

from edgeyolo.detect import Detector, TRTDetector, draw


if __name__ == '__main__':

parser = argparse.ArgumentParser("EdgeYOLO Detect parser")
parser.add_argument("-w", "--weights", type=str, default="edgeyolo_coco.pth", help="weight file")
parser.add_argument("-c", "--conf-thres", type=float, default=0.25, help="confidence threshold")
parser.add_argument("-n", "--nms-thres", type=float, default=0.55, help="nms threshold")
parser.add_argument("--fp16", action="store_true", help="fp16")
parser.add_argument("--no-fuse", action="store_true", help="do not fuse model")
parser.add_argument("--input-size", type=int, nargs="+", default=[640, 640], help="input size: [height, width]")
parser.add_argument("-s", "--source", type=str, default="E:/videos/test.avi", help="video source or image dir")
parser.add_argument("--trt", action="store_true", help="is trt model")
parser.add_argument("--legacy", action="store_true", help="if img /= 255 while training, add this command.")
parser.add_argument("--use-decoder", action="store_true", help="support original yolox model v0.2.0")
parser.add_argument("--batch-size", type=int, default=1, help="batch size")
parser.add_argument("--no-label", action="store_true", help="do not draw label")
parser.add_argument("--save-dir", type=str, default="./imgs/coco", help="image result save dir")

args = parser.parse_args()
exist_save_dir = os.path.isdir(args.save_dir)

# detector setup
detector = TRTDetector if args.trt else Detector
detect = detector(
weight_file=args.weights,
conf_thres=args.conf_thres,
nms_thres=args.nms_thres,
input_size=args.input_size,
fuse=not args.no_fuse,
fp16=args.fp16,
use_decoder=args.use_decoder
)
if args.trt:
args.batch_size = detect.batch_size

# source loader setup
if os.path.isdir(args.source):

class DirCapture:

def __init__(self, dir_name):
self.imgs = []
for img_type in ["jpg", "png", "jpeg", "bmp", "webp"]:
self.imgs += sorted(glob(os.path.join(dir_name, f"*.{img_type}")))

def isOpened(self):
return bool(len(self.imgs))

def read(self):
print(self.imgs[0])
now_img = cv2.imread(self.imgs[0])
self.imgs = self.imgs[1:]
return now_img is not None, now_img

source = DirCapture(args.source)
delay = 0
else:
source = cv2.VideoCapture(int(args.source) if args.source.isdigit() else args.source)
delay = 1

all_dt = []
dts_len = 300 // args.batch_size
success = True

# start inference
while source.isOpened() and success:

frames = []
for _ in range(args.batch_size):
success, frame = source.read()
if not success:
if not len(frames):
cv2.destroyAllWindows()
break
else:
while len(frames) < args.batch_size:
frames.append(frames[-1])
else:
frames.append(frame)

if not len(frames):
break

results = detect(frames, args.legacy)
dt = detect.dt
all_dt.append(dt)
if len(all_dt) > dts_len:
all_dt = all_dt[-dts_len:]
print(f"\r{dt * 1000 / args.batch_size:.1f}ms "
f"average:{sum(all_dt) / len(all_dt) / args.batch_size * 1000:.1f}ms", end=" ")

key = -1
imgs = draw(frames, results, detect.class_names, 2, draw_label=not args.no_label)
# print([im.shape for im in frames])
for img in imgs:
# print(img.shape)
cv2.imshow("EdgeYOLO result", img)
key = cv2.waitKey(delay)
if key in [ord("q"), 27]:
break
elif key == ord(" "):
delay = 1 - delay
elif key == ord("s"):
if not exist_save_dir:
os.makedirs(args.save_dir, exist_ok=True)
file_name = f"{str(date.now()).split('.')[0].replace(':', '').replace('-', '').replace(' ', '')}.jpg"
cv2.imwrite(os.path.join(args.save_dir, file_name), img)
logger.info(f"image saved to {file_name}.")
if key in [ord("q"), 27]:
cv2.destroyAllWindows()
break
74 changes: 74 additions & 0 deletions onnx2trt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import yaml
import argparse
import os.path as osp
import os
from loguru import logger
import torch


def get_args():
parser = argparse.ArgumentParser("EdgeYOLO onnx2tensorrt parser")
parser.add_argument("-o", "--onnx", type=str, default="yolov7.onnx", help="ONNX file")
parser.add_argument("-y", "--yaml", type=str, default="yolov7.yaml", help="export params file")
parser.add_argument("-w", "--workspace", type=int, default=8, help="export memory workspace(GB)")
parser.add_argument("--fp16", action="store_true", help="fp16")
parser.add_argument("--int8", action="store_true", help="int8")
parser.add_argument("--best", action="store_true", help="best")
parser.add_argument("-d", "--dist-path", type=str, default="export_output/tensorrt")
parser.add_argument("--batch-size", type=int, default=0, help="batch-size")
return parser.parse_args()


def main():
args = get_args()

assert osp.isfile(args.onnx), f"No such file named {args.onnx}."
assert osp.isfile(args.yaml), f"No such file named {args.yaml}."

os.makedirs(args.dist_path, exist_ok=True)

name = args.onnx.replace("\\", "/").split("/")[-1][:-len(args.onnx.split(".")[-1])]

engine_file = osp.join(args.dist_path, name + "engine").replace("\\", "/")
pt_file = osp.join(args.dist_path, name + "pt").replace("\\", "/")
cls_file = osp.join(args.dist_path, name + "txt").replace("\\", "/")
params = yaml.load(open(args.yaml).read(), yaml.Loader)
command = f"trtexec --onnx={args.onnx}" \
f"{' --fp16' if args.fp16 else ' --int8' if args.int8 else ' --best' if args.best else ''} " \
f"--saveEngine={engine_file} --workspace={args.workspace*1024} " \
f"--batch={args.batch_size if not args.batch_size > 0 else params['batch_size'] if 'batch_size' in params else 1}"

logger.info("start converting onnx to tensorRT engine file.")
os.system(command)

if not osp.isfile(engine_file):
logger.error("tensorRT engine file convertion failed.")
return

logger.info(f"tensorRT engine saved to {engine_file}")

try:
data = {
"model": {
"engine": bytearray(open(engine_file, "rb").read()),
"input_names": params["input_name"],
"output_names": params["output_name"]
},
"names": params["names"],
"img_size": params["img_size"],
"batch_size": params["batch_size"]
}
class_str = ""
for name in params["names"]:
class_str += name + "\n"
with open(cls_file, "w") as cls_f:
cls_f.write(class_str[:-1])
logger.info(f"class names txt pt saved to {cls_file}")
torch.save(data, pt_file)
logger.info(f"tensorRT pt saved to {pt_file}")
except Exception as e:
logger.error(f"convert2pt error: {e}")


if __name__ == '__main__':
main()
140 changes: 140 additions & 0 deletions params/model/edgeyolo-visdrone.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# parameters
nc: 10 # number of classes
depth_multiple: 1.0 # models depth multiple
width_multiple: 1.0 # layer channel multiple

# anchor-box-free
anchors:
- [8, 8] # P3/8
- [16, 16] # P4/16
- [32, 32] # P5/32

# edgeyolo backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [32, 3, 1]], # 0

[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
[-1, 1, Conv, [64, 3, 1]],

[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
[-1, 1, Conv, [64, 1, 1]],
[-2, 1, Conv, [64, 1, 1]],
[-1, 1, Conv, [64, 3, 1]],
[-1, 1, Conv, [64, 3, 1]],
[-1, 1, Conv, [64, 3, 1]],
[-1, 1, Conv, [64, 3, 1]],
[[-1, -3, -5, -6], 1, Concat, [1]],
[-1, 1, Conv, [256, 1, 1]], # 11

[-1, 1, MP, []],
[-1, 1, Conv, [128, 1, 1]],
[-3, 1, Conv, [128, 1, 1]],
[-1, 1, Conv, [128, 3, 2]],
[[-1, -3], 1, Concat, [1]], # 16-P3/8
[-1, 1, Conv, [128, 1, 1]],
[-2, 1, Conv, [128, 1, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[[-1, -3, -5, -6], 1, Concat, [1]],
[-1, 1, Conv, [512, 1, 1]], # 24

[-1, 1, MP, []],
[-1, 1, Conv, [256, 1, 1]],
[-3, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [256, 3, 2]],
[[-1, -3], 1, Concat, [1]], # 29-P4/16
[-1, 1, Conv, [256, 1, 1]],
[-2, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[[-1, -3, -5, -6], 1, Concat, [1]],
[-1, 1, Conv, [1024, 1, 1]], # 37

[-1, 1, MP, []],
[-1, 1, Conv, [512, 1, 1]],
[-3, 1, Conv, [512, 1, 1]],
[-1, 1, Conv, [512, 3, 2]],
[[-1, -3], 1, Concat, [1]], # 42-P5/32
[-1, 1, Conv, [256, 1, 1]],
[-2, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[[-1, -3, -5, -6], 1, Concat, [1]],
[-1, 1, Conv, [1024, 1, 1]], # 50
]

# edgeyolo head
head:
[[-1, 1, SPPCSPC, [512]], # 51

[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[37, 1, Conv, [256, 1, 1]], # route backbone P4
[[-1, -2], 1, Concat, [1]],

[-1, 1, Conv, [256, 1, 1]],
[-2, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
[-1, 1, Conv, [256, 1, 1]], # 63

[-1, 1, Conv, [128, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[24, 1, Conv, [128, 1, 1]], # route backbone P3
[[-1, -2], 1, Concat, [1]],

[-1, 1, Conv, [128, 1, 1]],
[-2, 1, Conv, [128, 1, 1]],
[-1, 1, Conv, [64, 3, 1]],
[-1, 1, Conv, [64, 3, 1]],
[-1, 1, Conv, [64, 3, 1]],
[-1, 1, Conv, [64, 3, 1]],
[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
[-1, 1, Conv, [128, 1, 1]], # 75

[-1, 1, MP, []],
[-1, 1, Conv, [128, 1, 1]],
[-3, 1, Conv, [128, 1, 1]],
[-1, 1, Conv, [128, 3, 2]],
[[-1, -3, 63], 1, Concat, [1]],

[-1, 1, Conv, [256, 1, 1]],
[-2, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, Conv, [128, 3, 1]],
[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
[-1, 1, Conv, [256, 1, 1]], # 88

[-1, 1, MP, []],
[-1, 1, Conv, [256, 1, 1]],
[-3, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [256, 3, 2]],
[[-1, -3, 51], 1, Concat, [1]],

[-1, 1, Conv, [512, 1, 1]],
[-2, 1, Conv, [512, 1, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, Conv, [256, 3, 1]],
[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
[-1, 1, Conv, [512, 1, 1]], # 101

[75, 1, RepConv, [256, 3, 1]], # 102
[88, 1, RepConv, [512, 3, 1]], # 103
[101, 1, RepConv, [1024, 3, 1]], # 104

[[102,103,104], 1, YOLOXDetect, [nc, anchors, Conv]],
]
Loading

0 comments on commit 2b98bd6

Please sign in to comment.