'''
Author: Oaktree.AI
Date: 2021-06-01 09:49:22
LastEditors: Oaktree.AI
LastEditTime: 2021-06-02 00:32:22
Description:
'''
"""trt_yolo_cv2.py

This script runs object detection on one or more video streams with a
TensorRT optimized YOLO engine and displays the BBox-overlaid frames.

"cv" means "create video"
made by BigJoon (ref. jkjung-avt)
"""

import os
import argparse
from threading import Thread

import cv2
import numpy as np
import pycuda.autoinit  # This is needed for initializing CUDA driver

from utils.yolo_classes import get_cls_dict
from utils.camera import add_camera_args
from utils.visualization import BBoxVisualization
from utils.yolo_with_plugins import TrtYOLO


def parse_args():
    """Parse input arguments."""
    desc = ('Run the TensorRT optimized object detection model on input '
            'video streams and display the BBox-overlaid frames.')
    parser = argparse.ArgumentParser(description=desc)
    parser = add_camera_args(parser)
    parser.add_argument(
        '-s', '--source', type=str, required=True,
        help='input camera source: a device id, a stream URL, or a text '
             'file listing one source per line')
    parser.add_argument(
        '-c', '--category_num', type=int, default=5,
        help='number of object categories [5]')
    parser.add_argument(
        '-m', '--model', type=str, required=True,
        help=('[yolov3-tiny|yolov3|yolov3-spp|yolov4-tiny|yolov4|'
              'yolov4-csp|yolov4x-mish]-[{dimension}], where '
              '{dimension} could be either a single number (e.g. '
              '288, 416, 608) or 2 numbers, WxH (e.g. 416x256)'))
    parser.add_argument(
        '-l', '--letter_box', action='store_true',
        help='inference with letterboxed image [False]')
    args = parser.parse_args()
    return args


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114),
              auto=True, scaleFill=False, scaleup=True):
    """Resize image to a 32-pixel-multiple rectangle.

    See https://github.com/ultralytics/yolov3/issues/232
    """
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right,
                             cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
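
# A minimal usage sketch for letterbox(), assuming a BGR frame read with
# OpenCV ('sample.jpg' is only an illustration, not a file in this repo);
# auto=False pads to the exact target square instead of a 64-multiple rect:
#
#   frame = cv2.imread('sample.jpg')                     # e.g. 1280x720 BGR
#   padded, ratio, (dw, dh) = letterbox(frame, new_shape=416, auto=False)
#   # padded.shape == (416, 416, 3), with (114, 114, 114) gray bars;
#   # ratio and (dw, dh) map padded-image coords back to the original:
#   #   x_orig = (x_padded - dw) / ratio[0]
#   #   y_orig = (y_padded - dh) / ratio[1]
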
class LoadStreams:
    """Load multiple IP or RTSP camera streams."""

    def __init__(self, sources='streams.txt', img_size=640):
        self.mode = 'images'
        self.img_size = img_size
        self.cam_id = None

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start a daemon thread to read frames from each video stream
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            cap = cv2.VideoCapture(int(s) if s.isnumeric() else s)
            assert cap.isOpened(), 'Failed to open %s' % s
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=(i, cap), daemon=True)
            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
            thread.start()
        print('')  # newline

        # Check for common shapes
        s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape
                      for x in self.imgs], 0)  # inference shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. '
                  'For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        """Read the next stream frame in a daemon thread."""
        self.cam_id = index
        n = 0
        while cap.isOpened():
            n += 1
            cap.grab()
            if n == 1:  # retrieve every grabbed frame
                _, self.imgs[index] = cap.retrieve()
                n = 0

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert BGR to RGB, to bsx3x416x416
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)
        img = np.ascontiguousarray(img)

        return self.sources, self.imgs, img0, None, self.cam_id, img

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years


def loop_and_detect(trt_yolo, conf_th, vis, dataset):
    """Continuously capture images from the streams and do object detection.

    # Arguments
      trt_yolo: the TRT YOLO object detector instance.
      conf_th: confidence/score threshold for object detection.
      vis: for visualization.
      dataset: the LoadStreams instance providing the frames.
    """
    for _, raw_images, _, _, cam_id, _ in dataset:
        for i, frame in enumerate(raw_images):
            boxes, confs, clss = trt_yolo.detect(frame, conf_th)
            frame = vis.draw_bboxes(frame, boxes, confs, clss)
            img_show = cv2.resize(frame, (348, 348))
            cv2.imshow(str(i), img_show)


def main():
    args = parse_args()
    if args.category_num <= 0:
        raise SystemExit('ERROR: bad category_num (%d)!' % args.category_num)
    if not os.path.isfile('yolo/%s.trt' % args.model):
        raise SystemExit('ERROR: file (yolo/%s.trt) not found!' % args.model)

    dataset = LoadStreams(args.source)
    cls_dict = get_cls_dict(args.category_num)
    vis = BBoxVisualization(cls_dict)
    trt_yolo = TrtYOLO(args.model, args.category_num, args.letter_box)
    loop_and_detect(trt_yolo, conf_th=0.5, vis=vis, dataset=dataset)


if __name__ == '__main__':
    main()
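
# Usage sketch (assumptions: the TensorRT engine has already been built at
# yolo/<model>.trt, e.g. with the jkjung-avt tensorrt_demos tooling, and
# streams.txt lists one device id or stream URL per line):
#
#   $ cat streams.txt
#   rtsp://192.168.1.10:554/stream1
#   0
#
#   $ python3 trt_yolo_cv2.py --source streams.txt --model yolov4-416
#
# Press 'q' in any display window to quit.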