Question about attaching a segmentation head to CenterNet #13

Open

NeuZhangQiang opened this issue Nov 26, 2020 · 1 comment
Comments

@NeuZhangQiang

NeuZhangQiang commented Nov 26, 2020

Thank you very much for sharing the code.

Since I haven't managed to get the code set up yet, I can only read it without being able to run or debug it.

I have a question. The original CenterNet has three output branches: heatmap (W*H*C), offset (W*H*2), and size (W*H*2), and here you add a seg_feat branch. How is this branch added? Could you briefly explain, and point out where it is in the code? What is the shape of this seg_feat? How is a mask assigned to each center point? Is it, like offset and size, predicting a W*H*W*H seg_feat?
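
Judging from the loss code quoted below, seg_feat is not W*H*W*H. The design looks like CondInst: seg_feat is a single W*H*feat_channel feature map shared by all objects, and a second extra branch, conv_weight, predicts at every location the parameters of a tiny per-instance mask network; the mask for each center is obtained by gathering those parameters at the center's index and running them over seg_feat. Below is a minimal sketch of how such heads could sit next to hm/wh/reg (every channel size here is an assumption, not the repository's actual configuration):

    import torch
    import torch.nn as nn

    class ExtraHeadsSketch(nn.Module):
        """hm/wh/reg are the usual CenterNet heads; seg_feat and conv_weight
        are the two extra branches the question asks about."""

        def __init__(self, in_ch=64, num_classes=80, feat_channel=8):
            super().__init__()
            # number of dynamic parameters, matching the torch.split in the loss:
            # conv1 (feat_channel+2 -> feat_channel), conv2 (feat_channel -> feat_channel),
            # conv3 (feat_channel -> 1), each a 1x1 conv with bias
            weight_dim = ((feat_channel + 2) * feat_channel + feat_channel
                          + feat_channel ** 2 + feat_channel
                          + feat_channel + 1)

            def head(out_ch):
                return nn.Sequential(
                    nn.Conv2d(in_ch, 256, 3, padding=1), nn.ReLU(inplace=True),
                    nn.Conv2d(256, out_ch, 1))

            self.hm = head(num_classes)          # heatmap:  B x C x H x W
            self.wh = head(2)                    # size:     B x 2 x H x W
            self.reg = head(2)                   # offset:   B x 2 x H x W
            self.seg_feat = head(feat_channel)   # shared mask features: B x feat_channel x H x W
            self.conv_weight = head(weight_dim)  # dynamic conv params per location

        def forward(self, x):
            return {'hm': self.hm(x), 'wh': self.wh(x), 'reg': self.reg(x),
                    'seg_feat': self.seg_feat(x), 'conv_weight': self.conv_weight(x)}

With feat_channel = 8, weight_dim works out to (8+2)*8 + 8 + 8*8 + 8 + 8 + 1 = 169, which matches the split sizes used in the loss below.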

Also, I don't quite understand the dice loss computation in the code:

    def forward(self, seg_feat, conv_weight, mask, ind, target):
        mask_loss = 0.
        batch_size = seg_feat.size(0)
        # gather, at every object's center index `ind`, the dynamic conv
        # parameters predicted by the conv_weight head
        weight = _tranpose_and_gather_feat(conv_weight, ind)
        h, w = seg_feat.size(-2), seg_feat.size(-1)
        # recover each center's (x, y) grid position from its flattened index
        # (`ind / w` assumes integer floor division; newer PyTorch would use `ind // w`)
        x, y = ind % w, ind / w
        x_range = torch.arange(w).float().to(device=seg_feat.device)
        y_range = torch.arange(h).float().to(device=seg_feat.device)
        y_grid, x_grid = torch.meshgrid([y_range, x_range])
        for i in range(batch_size):
            num_obj = target[i].size(0)
            # split the gathered vector into the weights and biases of three 1x1 convs
            conv1w, conv1b, conv2w, conv2b, conv3w, conv3b = \
                torch.split(weight[i, :num_obj], [(self.feat_channel + 2) * self.feat_channel, self.feat_channel,
                                                  self.feat_channel ** 2, self.feat_channel,
                                                  self.feat_channel, 1], dim=-1)
            # per-object coordinate maps, relative to each object's center
            y_rel_coord = (y_grid[None, None] - y[i, :num_obj].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).float()) / 128.
            x_rel_coord = (x_grid[None, None] - x[i, :num_obj].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).float()) / 128.
            # every object reuses the same seg_feat map plus its own relative coords
            feat = seg_feat[i][None].repeat([num_obj, 1, 1, 1])
            feat = torch.cat([feat, x_rel_coord, y_rel_coord], dim=1).view(1, -1, h, w)

            # apply each object's private three-layer 1x1 conv head in one call
            # per layer, using grouped convolution (groups=num_obj)
            conv1w = conv1w.contiguous().view(-1, self.feat_channel + 2, 1, 1)
            conv1b = conv1b.contiguous().flatten()
            feat = F.conv2d(feat, conv1w, conv1b, groups=num_obj).relu()

            conv2w = conv2w.contiguous().view(-1, self.feat_channel, 1, 1)
            conv2b = conv2b.contiguous().flatten()
            feat = F.conv2d(feat, conv2w, conv2b, groups=num_obj).relu()

            conv3w = conv3w.contiguous().view(-1, self.feat_channel, 1, 1)
            conv3b = conv3b.contiguous().flatten()
            feat = F.conv2d(feat, conv3w, conv3b, groups=num_obj).sigmoid().squeeze()

            # only centers that correspond to real objects contribute to the loss
            true_mask = mask[i, :num_obj, None, None].float()
            mask_loss += dice_loss(feat * true_mask, target[i] * true_mask)

        return mask_loss / batch_size
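
For reference, dice_loss is called above but not shown in the excerpt; a typical soft-Dice implementation looks roughly like this sketch (the epsilon and the per-object mean are assumptions, the repository's version may differ):

    def dice_loss(pred, target, eps=1e-5):
        # pred: (num_obj, H, W) sigmoid probabilities; target: (num_obj, H, W) binary masks
        pred = pred.contiguous().view(pred.size(0), -1)
        target = target.contiguous().view(target.size(0), -1).float()
        inter = (pred * target).sum(dim=1)
        union = pred.sum(dim=1) + target.sum(dim=1)
        # 1 - Dice coefficient, averaged over the objects in the image
        return (1. - (2. * inter + eps) / (union + eps)).mean()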

There are also convolution operations inside. Could you explain the idea behind them?
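
What those convolutions do: for every object, the vector gathered from conv_weight at its center index is reshaped into the weights and biases of a three-layer 1x1 conv network, and all objects' private networks are evaluated in a single F.conv2d call per layer by stacking the objects along the channel axis and setting groups=num_obj, so each object's channels only meet its own filters. A small self-contained demo of that grouped-convolution trick (all shapes here are illustrative only):

    import torch
    import torch.nn.functional as F

    num_obj, feat_channel, h, w = 3, 8, 32, 32
    seg_feat = torch.randn(feat_channel, h, w)             # shared mask features for one image

    # every object reuses the same seg_feat and appends its own relative coordinates
    feat = seg_feat[None].repeat(num_obj, 1, 1, 1)          # (num_obj, C, h, w)
    rel_coords = torch.randn(num_obj, 2, h, w)              # stand-in for the x/y relative coords
    feat = torch.cat([feat, rel_coords], dim=1).view(1, -1, h, w)  # (1, num_obj*(C+2), h, w)

    # dynamic weights and bias of the first 1x1 conv, one set per object
    conv1w = torch.randn(num_obj * feat_channel, feat_channel + 2, 1, 1)
    conv1b = torch.randn(num_obj * feat_channel)

    # groups=num_obj: object i's input channels are convolved only with object i's filters
    out = F.conv2d(feat, conv1w, conv1b, groups=num_obj)    # (1, num_obj*feat_channel, h, w)
    print(out.shape)

The same pattern is applied twice more in the loss, with the last 1x1 layer producing a single sigmoid mask per object.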

Looking forward to your reply.

@BloodLemonS

Hello, have you figured this code out? I'm also confused about the question you raised. Would you mind explaining it if you have?
