Skip to content

Commit

Permalink
Add QFocalLoss() (ultralytics#1482)
Browse files Browse the repository at this point in the history
* Update loss.py

implement the quality focal loss which is a more general case of focal loss
more detail in https://arxiv.org/abs/2006.04388 

In the obj loss (or the case cls loss with label smooth), the targets is no long barely be 0 or 1 (can be 0.7), in this case, the normal focal loss is not work accurately
quality focal loss in behave the same as focal loss when the target is equal to 0 or 1, and work accurately when targets in (0, 1)

example:

targets:
tensor([[0.6225, 0.0000, 0.0000],
        [0.9000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000]])
___________________________
pred_prob:
tensor([[0.6225, 0.2689, 0.1192],
        [0.7773, 0.5000, 0.2227],
        [0.8176, 0.8808, 0.1978]])
____________________________
focal_loss
tensor([[0.0937, 0.0328, 0.0039],
        [0.0166, 0.1838, 0.0199],
        [0.0039, 1.3186, 0.0145]])
______________
qfocal_loss
tensor([[7.5373e-08, 3.2768e-02, 3.9179e-03],
        [4.8601e-03, 1.8380e-01, 1.9857e-02],
        [3.9233e-03, 1.3186e+00, 1.4545e-02]])
 
we can see that targets[0][0] = 0.6255 is almost the same as pred_prob[0][0] = 0.6225, 
the targets[1][0] = 0.9 is greater then pred_prob[1][0] = 0.7773 by 0.1227
however, the focal loss[0][0] = 0.0937 larger then focal loss[1][0] = 0.0166 (which against the purpose of focal loss)

for the quality focal loss , it implement the case of targets not equal to 0 or 1

* Update loss.py
  • Loading branch information
yxNONG committed Nov 25, 2020
1 parent 78a8465 commit 489edb8
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,32 @@ def forward(self, pred, true):
return loss.sum()
else: # 'none'
return loss


class QFocalLoss(nn.Module):
# Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
super(QFocalLoss, self).__init__()
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
self.gamma = gamma
self.alpha = alpha
self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = 'none' # required to apply FL to each element

def forward(self, pred, true):
loss = self.loss_fcn(pred, true)

pred_prob = torch.sigmoid(pred) # prob from logits
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
modulating_factor = torch.abs(true - pred_prob) ** self.gamma
loss *= alpha_factor * modulating_factor

if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss


def compute_loss(p, targets, model): # predictions, targets, model
Expand Down

0 comments on commit 489edb8

Please sign in to comment.