Skip to content

Commit

Permalink
[Fix] fix loss computation in MSPNHead (#2993)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored Mar 26, 2024
1 parent de67839 commit e0eb5d4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mmpose/models/heads/heatmap_heads/mspn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,9 @@ def loss(self,

keypoint_weights = torch.cat([
d.gt_instance_labels.keypoint_weights for d in batch_data_samples
]) # shape: [B*N, L, K]
],
dim=1)
keypoint_weights = keypoint_weights.transpose(0, 1) # [B*N, L, K]

# calculate losses over multiple stages and multiple units
losses = dict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _get_data_samples(self,
with_heatmap=True,
with_reg_label=False,
num_levels=num_levels)['data_samples']

return batch_data_samples

def test_init(self):
Expand Down Expand Up @@ -153,6 +154,10 @@ def test_loss(self):
(unit_channels, 32, 24), (unit_channels, 64, 48)])
batch_data_samples = self._get_data_samples(
batch_size=2, heatmap_size=(48, 64), num_levels=4)
for ds in batch_data_samples:
ds.gt_instance_labels = InstanceData(
keypoint_weights=ds.gt_instance_labels.keypoint_weights.
transpose(0, 1))
losses = head.loss(feats, batch_data_samples)

self.assertIsInstance(losses['loss_kpt'], torch.Tensor)
Expand Down Expand Up @@ -189,6 +194,10 @@ def test_loss(self):
(unit_channels, 32, 24), (unit_channels, 64, 48)])
batch_data_samples = self._get_data_samples(
batch_size=2, heatmap_size=(48, 64), num_levels=16)
for ds in batch_data_samples:
ds.gt_instance_labels = InstanceData(
keypoint_weights=ds.gt_instance_labels.keypoint_weights.
transpose(0, 1))
losses = head.loss(feats, batch_data_samples)

self.assertIsInstance(losses['loss_kpt'], torch.Tensor)
Expand Down

0 comments on commit e0eb5d4

Please sign in to comment.