Skip to content

Commit

Permalink
Weighted Loss ultralytics#9
Browse files Browse the repository at this point in the history
  • Loading branch information
manole-alexandru committed Mar 26, 2023
1 parent 2e74b05 commit 25b23f6
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def forward(self, pred, true):
else: # 'none'
return loss

def weighted_bce(y_pred, y_true, BETA=2):
weights = (y_true * (BETA - 1)) + 1
bce = nn.BCELoss(reduction='none')(y_pred, y_true)
wbce = torch.mean(bce * weights)
return wbce

class ComputeLoss:
sort_obj_iou = False
Expand Down Expand Up @@ -172,7 +177,8 @@ def __call__(self, preds, targets, seg_masks): # predictions, targets
# Mask Loss
# print('\n----------- PRED VALID: ', torch.all(pred_mask >= 0), '-----------------\n')
# print('\n----------- SEG MASK VALID: ', torch.all(seg_masks >= 0), '-----------------\n')
seg_loss = nn.functional.binary_cross_entropy_with_logits(pred_mask, seg_masks, reduction='none').mean()
# seg_loss = nn.functional.binary_cross_entropy_with_logits(pred_mask, seg_masks, reduction='none').mean()
seg_loss = weighted_bce(pred_mask, seg_masks)
# print('SEG_LOSS', seg_loss)
if torch.isnan(seg_loss):
print(pred_mask)
Expand Down

0 comments on commit 25b23f6

Please sign in to comment.