-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
183 lines (151 loc) · 5.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# This code is modified from detectron2 by facebook research
# Link to github repo: https://github.com/facebookresearch/detectron2
import argparse
import glob
import multiprocessing as mp
import os
import time
import cv2
import tqdm
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
from lib.predict.predictor import VisualizationDemo
from lib.predict.predictor_MLP import VisualizationDemoMLP
from lib.predict.predictor_MLP_nonlocalized import VisualizationDemoMLPNonLocalized
from utils.add_custom_config import *
import json
import copy
# constants
WINDOW_NAME = "Human-to-robots handovers"
def setup_cfg(args):
    """Build and freeze the two Detectron2 configs used by the demo.

    Loads the library defaults twice, merges the YAML files supplied on the
    command line, applies the confidence threshold and weight paths, and
    returns both configs frozen.

    Args:
        args: parsed command-line namespace (see ``get_parser``).

    Returns:
        Tuple ``(cfg_object, cfg_keypoint)`` — frozen configs for the object
        detector and the keypoint detector, respectively.
    """
    object_cfg = get_cfg()
    keypoint_cfg = get_cfg()

    # Extra options required by the head-pose estimation model.
    add_custom_config(keypoint_cfg)

    object_cfg.merge_from_file(args.cfg_object)
    keypoint_cfg.merge_from_file(args.cfg_keypoint)

    # Same confidence cut-off for both detectors; the object model detects a
    # single class.
    threshold = args.confidence_threshold
    object_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    object_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    keypoint_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold

    object_cfg.MODEL.WEIGHTS = args.obj_weights
    keypoint_cfg.MODEL.WEIGHTS = args.keypoint_weights

    for cfg in (object_cfg, keypoint_cfg):
        cfg.freeze()
    return object_cfg, keypoint_cfg
def get_parser():
    """Construct the command-line parser for the demo.

    Returns:
        argparse.ArgumentParser: parser exposing config paths, input source
        (webcam or video file), output destinations, model weights, the
        confidence threshold, and mode flags (``--train``,
        ``--non-localized``).
    """
    p = argparse.ArgumentParser(description="Detectron2 demo for builtin models")

    # Config files for the two models.
    p.add_argument(
        "--cfg-keypoint",
        metavar="FILE",
        default="./configs/keypoint_rcnn_R_101_FPN_3x.yaml",
        help="path to keypoint config file",
    )
    p.add_argument(
        "--cfg-object",
        metavar="FILE",
        default="./configs/object_faster_rcnn_R_101_FPN_3x.yaml",
        help="path to object detection config file",
    )

    # Input source and output destinations.
    p.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    p.add_argument("--video-input", help="Path to video file.")
    p.add_argument(
        "--output",
        help=(
            "A file or directory to save output visualizations. "
            "If not given, will show output in an OpenCV window."
        ),
    )
    p.add_argument(
        "--out-json",
        metavar="FILE",
        default="./default.json",
        help="A file or directory to save output json. ",
    )

    # Detection threshold and model weights.
    p.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.85,
        help="Minimum score for instance predictions to be shown",
    )
    p.add_argument(
        "--obj-weights",
        type=str,
        default="./pretrained-weights/Apple_Faster_RCNN_R_101_FPN_3x.pth",
        help="Path to the object detection weights",
    )
    p.add_argument(
        "--keypoint-weights",
        type=str,
        default="detectron2://COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x/138363331/model_final_997cc7.pkl",
        help="Path to the keypoint detection weights",
    )

    # Mode flags.
    p.add_argument("--train", action="store_true", help="Run training.")
    p.add_argument("--non-localized", action="store_true", help="Run alternate modality.")
    return p
if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))
    cfg_object, cfg_keypoint = setup_cfg(args)

    # Per-frame prediction dicts accumulated here and dumped to --out-json
    # (video branch only).
    database_json = {}
    database_json['annotation'] = {}
    database_arr = []

    # Select the demo pipeline: training-data mode, the non-localized MLP
    # variant, or the default MLP predictor.
    if args.train:
        demo = VisualizationDemo(cfg_object, cfg_keypoint)
    elif args.non_localized:
        demo = VisualizationDemoMLPNonLocalized(cfg_object, cfg_keypoint)
    else:
        demo = VisualizationDemoMLP(cfg_object, cfg_keypoint)

    frame = 0
    if args.webcam:
        cam = cv2.VideoCapture(0)
        try:
            # NOTE(review): this branch unpacks one value per iteration while
            # the video branch unpacks (frame, json) pairs — confirm which
            # demo classes yield which shape.
            for vis in tqdm.tqdm(demo.run_on_video(cam)):
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, vis)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        finally:
            # Fix: the capture device was never released in the original.
            cam.release()
            cv2.destroyAllWindows()
    elif args.video_input:
        # Fix: validate the path *before* opening it (the original asserted
        # after the VideoWriter was already configured, and `assert` is
        # stripped under `python -O`).
        if not os.path.isfile(args.video_input):
            raise FileNotFoundError(args.video_input)
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)

        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + ".mkv"
            else:
                output_fname = args.output
            # Refuse to overwrite an existing file (explicit error instead of
            # an assert, which is stripped under `python -O`).
            if os.path.isfile(output_fname):
                raise FileExistsError(output_fname)
            # NOTE(review): default extension is .mkv but the fourcc is
            # "mp4v" — OpenCV/FFmpeg generally tolerates this pairing, but
            # confirm the intended container/codec combination.
            output_file = cv2.VideoWriter(
                filename=output_fname,
                fourcc=cv2.VideoWriter_fourcc(*"mp4v"),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        for vis_frame, data_json in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
            # Deep copy so later mutation/reuse of the dict by the demo
            # cannot corrupt earlier entries.
            database_arr.append(copy.deepcopy(data_json))
        database_json["annotation"] = database_arr
        with open(args.out_json, 'w') as json_file:
            json.dump(database_json, json_file)
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()