-
Notifications
You must be signed in to change notification settings - Fork 3
/
mobilevit_track.py
193 lines (161 loc) · 8.34 KB
/
mobilevit_track.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import math
from lib.models.mobilevit_track.mobilevit_track import build_mobilevit_track
from lib.test.tracker.basetracker import BaseTracker
import torch
from lib.test.tracker.vis_utils import gen_visualization
from lib.test.utils.hann import hann2d
from lib.train.data.processing_utils import sample_target
# for debug
import cv2
import os
import numpy as np
from lib.test.tracker.data_utils_mobilevit import Preprocessor
from lib.utils.box_ops import clip_box
from lib.utils.ce_utils import generate_mask_cond
class MobileViTTrack(BaseTracker):
    """Test-time tracker built around the MobileViT-Track network.

    ``initialize`` pushes the template crop through the first backbone stages
    once and caches the result. ``track`` then, for every frame, crops a search
    region around the last known box, runs the network on the cached template
    features plus the search crop, and converts the predicted box back to image
    coordinates.
    """

    def __init__(self, params, dataset_name):
        """Build the network, load its checkpoint and reset the tracking state.

        Args:
            params: tracker parameters. This constructor reads ``cfg``,
                ``checkpoint``, ``debug`` and ``save_all_boxes``; an optional
                ``save_state_dict`` attribute (default ``True``) controls
                whether the bare state dict is written next to the checkpoint.
            dataset_name: accepted for interface compatibility; not used here.
        """
        super().__init__(params)

        network = build_mobilevit_track(params.cfg, training=False)
        checkpoint = torch.load(self.params.checkpoint, map_location='cpu')
        network.load_state_dict(checkpoint['net'], strict=True)

        self.cfg = params.cfg
        # Anything other than an explicit 'cpu' request selects CUDA.
        self.device = 'cpu' if self.cfg.TEST.DEVICE == 'cpu' else 'cuda'
        self.network = network.to(self.device)
        self.network.eval()

        self.preprocessor = Preprocessor()
        self.state = None

        # Side length of the score map: search size divided by backbone stride.
        self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
        # Motion constraint: Hann window multiplied into the score map in `track`.
        window_shape = torch.tensor([self.feat_sz, self.feat_sz]).long()
        self.output_window = hann2d(window_shape, centered=True).to(self.device)

        # Debug output. NOTE(review): `use_visdom` mirrors `debug`, so the
        # save-to-disk branch below is only reachable if a caller flips
        # `use_visdom` afterwards -- confirm whether that is intended.
        self.debug = params.debug
        self.use_visdom = params.debug
        self.frame_id = 0
        if self.debug:
            if not self.use_visdom:
                self.save_dir = "debug"
                os.makedirs(self.save_dir, exist_ok=True)
            else:
                self._init_visdom(None, 1)

        # Whether `initialize`/`track` also return the boxes from all queries.
        self.save_all_boxes = params.save_all_boxes
        self.z_dict1 = {}

        # Save the model state dictionary only (to verify the actual model size).
        # Defaults to True to keep the historical behaviour; set
        # `params.save_state_dict = False` to skip the extra file.
        self.save_state_dict = getattr(params, 'save_state_dict', True)
        if self.save_state_dict:
            model_name = self.params.checkpoint
            torch.save(network.state_dict(), model_name.split('.pth.tar')[0] + '_state_dict.pt')

    def initialize(self, image, info: dict):
        """Encode the template and store the initial box.

        Args:
            image: first frame, passed straight to ``sample_target``.
            info: must contain ``'init_bbox'``, the initial target box.

        Returns:
            ``{"all_boxes": ...}`` when ``save_all_boxes`` is set (the initial
            box repeated ``cfg.MODEL.NUM_OBJECT_QUERIES`` times), else ``None``.
        """
        # Forward the template once.
        z_patch_arr, resize_factor, z_amask_arr = sample_target(
            image, info['init_bbox'], self.params.template_factor,
            output_sz=self.params.template_size)
        self.z_patch_arr = z_patch_arr
        template = self.preprocessor.process(z_patch_arr, z_amask_arr)

        backbone = self.network.backbone
        with torch.no_grad():
            # conv_1 (i.e., the first conv3x3 layer) output for the template
            z = backbone.conv_1.forward(template.tensors.to(self.device))
            # layer_1 (i.e., MobileNetV2 block) output
            z = backbone.layer_1.forward(z)
            # layer_2 (i.e., MobileNetV2 with down-sampling + 2 x MobileNetV2) output
            z = backbone.layer_2.forward(z)
        self.z_dict1 = z
        self.box_mask_z = None

        # Save states.
        self.state = info['init_bbox']
        self.frame_id = 0

        if self.save_all_boxes:
            # Save all predicted boxes: the initial box repeated once per query.
            all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES
            return {"all_boxes": all_boxes_save}
        return None

    def track(self, image, info: dict = None):
        """Locate the target in ``image`` and update ``self.state``.

        Args:
            image: current frame; ``image.shape`` must unpack to ``(H, W, _)``.
            info: unused.

        Returns:
            ``{"target_bbox": state}`` and, when ``save_all_boxes`` is set, an
            additional ``"all_boxes"`` entry holding every predicted box
            flattened to a list of length 4N.
        """
        H, W, _ = image.shape
        self.frame_id += 1

        # The search crop is taken around the current state, (x1, y1, w, h).
        x_patch_arr, resize_factor, x_amask_arr = sample_target(
            image, self.state, self.params.search_factor,
            output_sz=self.params.search_size)
        search = self.preprocessor.process(x_patch_arr, x_amask_arr)

        with torch.no_grad():
            # Merge the template and the search, then run the transformer.
            out_dict = self.network.forward(
                template=self.z_dict1.to(self.device),
                search=search.tensors.to(self.device))

        # Add Hann windows.
        pred_score_map = out_dict['score_map']
        response = self.output_window * pred_score_map
        pred_boxes = self.network.box_head.cal_bbox(
            response, out_dict['size_map'], out_dict['offset_map'])
        pred_boxes = pred_boxes.view(-1, 4)

        # Rescale from crop-normalised units to image pixels.
        scale = self.params.search_size / resize_factor
        # Baseline: take the mean of all predicted boxes as the final result.
        pred_box = (pred_boxes.mean(dim=0) * self.params.search_size / resize_factor).tolist()

        all_boxes_save = None
        if self.save_all_boxes:
            # Map every candidate while self.state is still the box the search
            # crop was taken around. map_box_back_batch reads self.state, so
            # calling it after the update below would shift all boxes by the
            # frame-to-frame displacement.
            all_boxes = self.map_box_back_batch(
                pred_boxes * self.params.search_size / resize_factor, resize_factor)
            all_boxes_save = all_boxes.view(-1).tolist()  # (4N, )

        # Get the final box result.
        self.state = clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        if self.debug:
            self._show_debug_frame(image, x_patch_arr, pred_score_map, response, out_dict)

        if self.save_all_boxes:
            return {"target_bbox": self.state,
                    "all_boxes": all_boxes_save}
        return {"target_bbox": self.state}

    def _show_debug_frame(self, image, x_patch_arr, pred_score_map, response, out_dict):
        """Write the current result to disk or push it to visdom."""
        if not self.use_visdom:
            x1, y1, w, h = self.state
            image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.rectangle(image_BGR, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)),
                          color=(0, 0, 255), thickness=2)
            save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
            cv2.imwrite(save_path, image_BGR)
            return

        self.visdom.register((image, self.state), 'Tracking', 1, 'Tracking')
        self.visdom.register(torch.from_numpy(x_patch_arr).permute(2, 0, 1), 'image', 1, 'search_region')
        self.visdom.register(torch.from_numpy(self.z_patch_arr).permute(2, 0, 1), 'image', 1, 'template')
        self.visdom.register(pred_score_map.view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map')
        # `response` is the Hann-weighted score map already computed in `track`.
        self.visdom.register(response.view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map_hann')

        if 'removed_indexes_s' in out_dict and out_dict['removed_indexes_s']:
            removed_indexes_s = [idx.cpu().numpy() for idx in out_dict['removed_indexes_s']]
            masked_search = gen_visualization(x_patch_arr, removed_indexes_s)
            self.visdom.register(torch.from_numpy(masked_search).permute(2, 0, 1), 'image', 1, 'masked_search')

        # Spin while paused; a single-step request releases exactly one frame.
        # `pause_mode` and `step` are expected to be set outside this class
        # (presumably by the visdom UI handler in the base class).
        while self.pause_mode:
            if self.step:
                self.step = False
                break

    def _crop_to_image_offset(self, resize_factor: float):
        """Return the (x, y) offset added to crop-relative centres.

        It is the centre of ``self.state`` minus half the crop side, where the
        crop side in image pixels is ``search_size / resize_factor``.
        """
        cx_prev = self.state[0] + 0.5 * self.state[2]
        cy_prev = self.state[1] + 0.5 * self.state[3]
        half_side = 0.5 * self.params.search_size / resize_factor
        return cx_prev - half_side, cy_prev - half_side

    def map_box_back(self, pred_box: list, resize_factor: float):
        """Convert one crop-relative (cx, cy, w, h) box to image (x, y, w, h).

        Must be called while ``self.state`` is still the box the crop was
        sampled around.
        """
        off_x, off_y = self._crop_to_image_offset(resize_factor)
        cx, cy, w, h = pred_box
        cx_real = cx + off_x
        cy_real = cy + off_y
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        """Batched ``map_box_back``: the last dimension of ``pred_box`` is (cx, cy, w, h).

        Must be called while ``self.state`` is still the box the crop was
        sampled around.
        """
        off_x, off_y = self._crop_to_image_offset(resize_factor)
        cx, cy, w, h = pred_box.unbind(-1)  # (N,4) --> (N,)
        cx_real = cx + off_x
        cy_real = cy + off_y
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)

    def change(self, r):
        """Return the element-wise maximum of ``r`` and ``1 / r``."""
        return np.maximum(r, 1. / r)

    def sz(self, w, h):
        """Return ``sqrt((w + p) * (h + p))`` with padding ``p = (w + h) / 2``."""
        pad = (w + h) * 0.5
        sz2 = (w + pad) * (h + pad)
        return np.sqrt(sz2)

    def sz_wh(self, wh):
        """Same as ``sz`` for a single indexable ``wh`` holding (w, h)."""
        return self.sz(wh[0], wh[1])

    def add_hook(self):
        """Register forward hooks that collect attention outputs.

        The second element of each hooked module's output is appended to
        ``self.enc_attn_weights`` on every forward pass.
        NOTE(review): assumes ``network.backbone`` exposes ``blocks`` with at
        least 12 entries, each having an ``attn`` module -- confirm this holds
        for the backbone in use.
        """
        enc_attn_weights = []
        for i in range(12):
            self.network.backbone.blocks[i].attn.register_forward_hook(
                lambda _module, _inputs, output: enc_attn_weights.append(output[1])
            )
        self.enc_attn_weights = enc_attn_weights
def get_tracker_class():
    """Return the tracker class defined in this module."""
    tracker_cls = MobileViTTrack
    return tracker_cls