Skip to content

Commit

Permalink
update form TaksAligned code on utils/ tal.py
Browse files Browse the repository at this point in the history
  • Loading branch information
positive666 committed May 22, 2023
1 parent 0463a7d commit aba2878
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions utils/tal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9, roll_out=False):
Return:
(Tensor): shape(b, n_boxes, h*w)
"""
"""该函数的作用是通过码头坐标和真实框的位置信息,在所有anchor中选择位于真实框内部或者与其IoU大于阈值的anchor点,并返回一个(b, n_boxes, h*w)的张量表示所选择的anchor点
下面对该函数的每句话代码进行注释和讲解:"""
"""该函数的作用是通过bbox坐标和真实框的位置信息,在所有anchor中选择位于真实框内部或者与其IoU大于阈值的anchor点,并返回一个(b, n_boxes, h*w)的张量表示所选择的anchor点"""

n_anchors = xy_centers.shape[0]
bs, n_boxes, _ = gt_bboxes.shape # 每个GT有多少个Anchor
if roll_out:
Expand Down Expand Up @@ -103,7 +103,7 @@ def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_g
return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device))
# 获取正样本掩码、匹配度、重叠度
# 获取正样本掩码、匹配度、重叠度
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
mask_gt)
# get target IOU match:解决一个anchor和多个GT框匹配问题
Expand Down Expand Up @@ -141,8 +141,8 @@ def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
overlaps = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
ind_0 = torch.empty(self.n_max_boxes, dtype=torch.long)
for b in range(self.bs):
""" gt_labes info --->improve (使用 roll_out 策略时,只计算那些被标签所覆盖的边框与 GT 之间的 CIoU,减少了计算量
而对于那些不被 ground truth 标签所覆盖的边框,将被舍弃,避免了计算冗余和过多内存消耗) """
""" gt_labes info - (使用roll_out 策略时,只计算那些被标签所覆盖的边框与 GT 之间的 CIoU,减少了计算量
而对于那些不被 ground truth 标签所覆盖的边框,将被舍弃,避免计算冗余和过多内存消耗) """
# form gt_label
ind_0[:], ind_2 = b, gt_labels[b].squeeze(-1).long()
# get the scores of each grid for each gt cls
Expand Down

0 comments on commit aba2878

Please sign in to comment.