-
Notifications
You must be signed in to change notification settings - Fork 0
/
trmot.py
75 lines (67 loc) · 3.23 KB
/
trmot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from collections import defaultdict
import torch
from .base import Tracker
from .tr_mot.multitracker import JDETracker, STrack
class TRMOTTracker(Tracker):
    """Multi-object tracker that runs one JDE tracker per object type.

    Wraps ``JDETracker`` (Towards-Realtime-MOT) behind the project's
    ``Tracker`` interface: detections come in as a batched ``detection``
    object, are split by object type, fed to the per-type trackers, and the
    tracking results are written back onto the ``detection`` in place.
    """

    def __init__(self, type_names, frame_rate,
                 max_age=2, min_iou=0.2,
                 feature_thres=0.7, feature_buffer_size=1):
        """
        Args:
            type_names: sequence of object-type names; the numeric type id
                stored in ``detection.object_types`` indexes into it.
            frame_rate: video frame rate, used to convert the time-based
                parameters below into frame counts.
            max_age: seconds a lost track is kept before being removed.
            min_iou: minimum IoU for association; JDETracker takes distance
                thresholds, hence the ``1 - min_iou`` conversions.
            feature_thres: appearance-feature distance threshold.
            feature_buffer_size: seconds of appearance features to buffer
                per track.
        """
        super(TRMOTTracker, self).__init__(type_names, frame_rate)
        # Restart track-id numbering so ids are unique per tracker instance.
        STrack.reset_id()
        self.trackers = {}
        for obj_type in self.type_names:
            self.trackers[obj_type] = JDETracker(
                int(max_age * frame_rate), feature_thres,
                1 - min_iou, 1 - min_iou / 2)
        self.feature_buffer_size = int(feature_buffer_size * frame_rate)
        # Ids of tracks seen as active at any point and not yet finished.
        self.active_tracks = set()

    def convert_to_tracks(self, detection):
        """Convert a batched detection into per-type lists of ``STrack``.

        Returns:
            dict mapping object-type name -> list of ``STrack``; each track
            remembers its row index (``obj_i``) in ``detection`` so results
            can be scattered back later.
        """
        grouped_tracks = defaultdict(list)
        for obj_i in range(len(detection)):
            # BUG FIX: type_names is a sequence (iterated in __init__), so
            # the numeric type id must index it, not call it.
            obj_type = self.type_names[detection.object_types[obj_i].item()]
            bbox = detection.image_boxes[obj_i].numpy()
            # JDETracker works in tlwh; detection boxes are tlbr.
            tlwh = STrack.tlbr_to_tlwh(bbox)
            score = detection.detection_scores[obj_i].item()
            # Copy so the track does not alias the detection's tensor memory.
            feature = detection.image_features[obj_i].numpy().copy()
            track = STrack(
                tlwh, score, feature, obj_i, self.feature_buffer_size)
            grouped_tracks[obj_type].append(track)
        return grouped_tracks

    def get_tracked_detection(self, detection):
        """Scatter tracker state back onto ``detection`` in place.

        Fills ``detection.track_ids``, ``track_states``, ``track_boxes``
        (tlbr) and ``image_speeds`` (pixels/second) for every object, then
        computes which previously-active tracks have finished.

        Returns:
            1-D int tensor of track ids that are neither tracked nor lost
            anymore (finished since the last call).

        Raises:
            AssertionError: if any detection row was not matched to a track.
        """
        track_ids = torch.full((len(detection),), -1, dtype=torch.int)
        states = torch.zeros((len(detection),), dtype=torch.int)
        track_boxes = torch.zeros((len(detection), 4))
        image_speeds = torch.zeros((len(detection), 2))
        for tracker in self.trackers.values():
            for track in tracker.tracked_stracks:
                self.active_tracks.add(track.track_id)
                obj_i = track.obj_index
                track_ids[obj_i] = track.track_id
                states[obj_i] = track.state
                track_boxes[obj_i] = torch.as_tensor(
                    track.tlbr, dtype=torch.float)
                # Kalman mean[4:6] look like per-frame (x, y) velocities and
                # mean[7] a height rate; the y term is shifted by half the
                # height rate — presumably converting a top-edge velocity to
                # a center velocity. NOTE(review): confirm against the
                # Kalman filter state layout in tr_mot.
                speed = torch.as_tensor([
                    track.mean[4], track.mean[5] + track.mean[7] / 2])
                # Per-frame displacement -> pixels per second.
                image_speeds[obj_i] = speed * self.frame_rate
        assert (track_ids >= 0).all(), 'Not all objects are tracked'
        detection.track_ids = track_ids
        detection.track_states = states
        detection.track_boxes = track_boxes
        detection.image_speeds = image_speeds
        # A track is finished once it is neither tracked nor lost in any
        # per-type tracker.
        ongoing_track_ids = set()
        for tracker in self.trackers.values():
            ongoing_track_ids.update(
                t.track_id for t in tracker.tracked_stracks)
            ongoing_track_ids.update(
                t.track_id for t in tracker.lost_stracks)
        finished_track_ids = self.active_tracks - ongoing_track_ids
        self.active_tracks -= finished_track_ids
        finished_track_ids = torch.as_tensor(
            [*finished_track_ids], dtype=torch.int)
        return finished_track_ids

    def __call__(self, detection):
        """Run one tracking step on a frame's detections.

        Updates ``detection`` in place with tracking fields and returns the
        tensor of track ids that finished this step.
        """
        grouped_tracks = self.convert_to_tracks(detection)
        for obj_type, tracker in self.trackers.items():
            tracker.update(grouped_tracks[obj_type])
        finished_track_ids = self.get_tracked_detection(detection)
        return finished_track_ids