Skip to content

Commit

Permalink
Fix YoloX loss to handle negative batch correctly and added a test to…
Browse files Browse the repository at this point in the history
… test this. (#812)
  • Loading branch information
BloodAxe authored Mar 31, 2023
1 parent 0256c7d commit be15981
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/super_gradients/training/losses/yolox_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,15 +684,16 @@ def _compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor)
bbox_preds, cls_preds, obj_preds, expanded_strides, x_shifts, y_shifts, targets
)

num_gts = max(flattened_gts.shape[0], 1)
num_gts = flattened_gts.shape[0]
num_gts_clamped = max(flattened_gts.shape[0], 1)
num_fg = max(matched_gt_ids.shape[0], 1)
total_num_anchors = max(transformed_outputs.shape[0] * transformed_outputs.shape[1], 1)

cls_targets = F.one_hot(matched_gt_classes.to(torch.int64), self.num_classes) * matched_ious.unsqueeze(dim=1)
obj_targets = transformed_outputs.new_zeros((transformed_outputs.shape[0], transformed_outputs.shape[1]))
obj_targets[matched_img_ids, matched_fg_ids] = 1
reg_targets = flattened_gts[matched_gt_ids][:, 1:]
if self.use_l1:
if self.use_l1 and num_gts > 0:
l1_targets = self.get_l1_target(
transformed_outputs.new_zeros((num_fg, 4)),
flattened_gts[matched_gt_ids][:, 1:],
Expand All @@ -707,7 +708,8 @@ def _compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor)
loss_iou = self.iou_loss(bbox_preds[matched_img_ids, matched_fg_ids], reg_targets).sum() / num_fg
loss_obj = self.bcewithlog_loss(obj_preds.squeeze(-1), obj_targets).sum() / (total_num_anchors if self.obj_loss_fix else num_fg)
loss_cls = self.bcewithlog_loss(cls_preds[matched_img_ids, matched_fg_ids], cls_targets).sum() / num_fg
if self.use_l1:

if self.use_l1 and num_gts > 0:
loss_l1 = self.l1_loss(raw_outputs[matched_img_ids, matched_fg_ids], l1_targets).sum() / num_fg
else:
loss_l1 = 0.0
Expand All @@ -723,7 +725,7 @@ def _compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor)
loss_obj.unsqueeze(0),
loss_cls.unsqueeze(0),
torch.tensor(loss_l1).unsqueeze(0).to(transformed_outputs.device),
torch.tensor(num_fg / max(num_gts, 1)).unsqueeze(0).to(transformed_outputs.device),
torch.tensor(num_fg / num_gts_clamped).unsqueeze(0).to(transformed_outputs.device),
loss.unsqueeze(0),
)
).detach(),
Expand Down
29 changes: 29 additions & 0 deletions tests/unit_tests/yolox_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import torch

from super_gradients.training.losses import YoloXDetectionLoss, YoloXFastDetectionLoss
from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X
from super_gradients.training.utils.detection_utils import DetectionCollateFN
from super_gradients.training.utils.utils import HpmStruct


Expand Down Expand Up @@ -39,6 +41,33 @@ def test_yolox_creation(self):
output_augment = yolo_model(dummy_input)
self.assertIsNotNone(output_augment)

def test_yolox_loss(self):
samples = [
(torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
(torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
(torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
(torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
(torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
]
collate = DetectionCollateFN()
_, targets = collate(samples)

predictions = [
torch.randn((5, 1, 256 // 8, 256 // 8, 4 + 1 + 10)),
torch.randn((5, 1, 256 // 16, 256 // 16, 4 + 1 + 10)),
torch.randn((5, 1, 256 // 32, 256 // 32, 4 + 1 + 10)),
]

for loss in [
YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="giou"),
YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="iou"),
YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True),
YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
]:
result = loss(predictions, targets)
print(result)


if __name__ == "__main__":
unittest.main()

0 comments on commit be15981

Please sign in to comment.