-
Notifications
You must be signed in to change notification settings - Fork 8
/
losses.py
25 lines (21 loc) · 919 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# ------------------------------------------------------------------------
# Modified from UniMoCo (https://github.com/dddzg/unimoco)
# Copyright (c) Tencent, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""Definition of the Supervised Contrastive Loss
"""
from torch import nn
import torch
class SupContrastive(nn.Module):
    """Supervised contrastive loss with multiple positives per sample.

    For each sample with logits ``s`` and label weights ``y`` along
    dimension 1, the per-sample loss is::

        S    = sum_n (1 - y_n) * exp(s_n)
        loss = sum_p log(1 + S * y_p * exp(-s_p)) / sum_j y_j

    i.e. every positive ``p`` is contrasted against the pooled negatives.
    The expression is evaluated in log space (``logsumexp`` + ``softplus``)
    so that large-magnitude logits cannot overflow ``exp`` and turn the
    loss into NaN/inf.

    Args:
        reduction: ``'mean'`` averages the per-sample losses, ``'sum'``
            adds them; any other value (e.g. ``'none'``) returns the
            per-sample loss tensor unreduced.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, y_pred, y_true):
        """Compute the loss.

        Args:
            y_pred: Floating-point logits; dimension 1 indexes the
                candidates each sample is scored against. Presumably shaped
                ``(batch, num_candidates)`` -- TODO confirm with callers.
            y_true: Label weights broadcast-compatible with ``y_pred``.
                Presumably a 0/1 multi-hot mask (1 = positive); values are
                used as weights and are expected to lie in ``[0, 1]``.

        Returns:
            A scalar for ``'mean'``/``'sum'`` reduction, otherwise the
            per-sample losses (``y_pred`` with dimension 1 reduced away).
            Samples without any positive label contribute 0.
        """
        y_true = y_true.to(dtype=y_pred.dtype)
        # Finite stand-in for log(0): a literal -inf would make the
        # logsumexp backward NaN for a sample that has no negatives.
        log_zero = torch.finfo(y_pred.dtype).min

        # log S = log(sum_n (1 - y_n) * exp(s_n)); entries with y_n == 1
        # carry zero negative weight and are masked out.
        neg_logits = y_pred + torch.log1p(-y_true)
        neg_logits = neg_logits.masked_fill(y_true >= 1, log_zero)
        log_sum_neg = torch.logsumexp(neg_logits, dim=1, keepdim=True)

        # log(1 + S * y_p * exp(-s_p)) == softplus(log S + log y_p - s_p).
        # Entries with y_p == 0 become softplus(log_zero) == 0, matching the
        # original log(1 + 0) contribution.
        pos_margin = log_sum_neg - y_pred + torch.log(y_true)
        pos_margin = pos_margin.masked_fill(y_true <= 0, log_zero)
        pair_loss = nn.functional.softplus(pos_margin)

        # A sample without positives yields 0 instead of 0 / 0 = NaN.
        num_pos = y_true.sum(1)
        num_pos = num_pos.masked_fill(num_pos == 0, 1)
        loss = pair_loss.sum(1) / num_pos

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss