Faster Accuracy metric #2775

Merged: 14 commits, Aug 6, 2020
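In short, the PR replaces the per-class Python loop over `stat_scores` inside `stat_scores_multiple_classes` with a single vectorized pass built on `scatter_add_`, and adds a `reduction` argument (`'none'`, `'sum'`, `'elementwise_mean'`) that `accuracy` now forwards. Below is a minimal sketch of the counting idea only; it is not the code from the diff (the name `per_class_counts` is made up here), it assumes integer labels in `[0, num_classes)`, and it omits the clamp-to-extra-bucket handling the PR uses for out-of-range values.

```python
import torch

def per_class_counts(pred: torch.Tensor, target: torch.Tensor, num_classes: int):
    # Flatten once; every position is either a hit (pred == target) or a miss.
    pred = pred.view(-1).long()
    target = target.view(-1).long()
    hit = (pred == target).float()

    # One scatter_add_ per statistic instead of one stat_scores call per class.
    tps = torch.zeros(num_classes).scatter_add_(0, pred, hit)
    fps = torch.zeros(num_classes).scatter_add_(0, pred, 1 - hit)
    fns = torch.zeros(num_classes).scatter_add_(0, target, 1 - hit)
    tns = pred.numel() - (tps + fps + fns)
    sups = torch.zeros(num_classes).scatter_add_(0, target, torch.ones_like(hit))
    return tps, fps, tns, fns, sups
```

With `pred = torch.tensor([0, 2, 4, 4])`, `target = torch.tensor([0, 4, 3, 4])` and `num_classes=5` (the tensors used in the updated tests), this yields `tps = [1, 0, 0, 0, 1]` and `sups = [1, 0, 0, 1, 2]`, matching the `'none'` expectations further down.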
73 changes: 55 additions & 18 deletions pytorch_lightning/metrics/functional/classification.py
@@ -138,10 +138,10 @@ def stat_scores_multiple_classes(
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        argmax_dim: int = 1,
        reduction: str = 'none'
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calls the stat_scores function iteratively for all classes, thus
calculating the number of true postive, false postive, true negative
Calculates the number of true postive, false postive, true negative
and false negative for each class

Args:
@@ -173,16 +173,56 @@ def stat_scores_multiple_classes(
    if pred.ndim == target.ndim + 1:
        pred = to_categorical(pred, argmax_dim=argmax_dim)

    num_classes = get_num_classes(pred=pred, target=target,
                                  num_classes=num_classes)
    num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)

    tps = torch.zeros((num_classes,), device=pred.device)
    fps = torch.zeros((num_classes,), device=pred.device)
    tns = torch.zeros((num_classes,), device=pred.device)
    fns = torch.zeros((num_classes,), device=pred.device)
    sups = torch.zeros((num_classes,), device=pred.device)
    for c in range(num_classes):
        tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c)
    pred = pred.view((-1, )).long()
    target = target.view((-1, )).long()

    pred.clamp_max_(max=num_classes)
    target.clamp_max_(max=num_classes)

    possible_reductions = ('none', 'sum', 'elementwise_mean')
    if reduction not in possible_reductions:
        raise ValueError("reduction type %s not supported" % reduction)

    if reduction == 'none':
        tps = torch.zeros((num_classes + 1,), device=pred.device)
        fps = torch.zeros((num_classes + 1,), device=pred.device)
        tns = torch.zeros((num_classes + 1,), device=pred.device)
        fns = torch.zeros((num_classes + 1,), device=pred.device)
        sups = torch.zeros((num_classes + 1,), device=pred.device)

        match_true = (pred == target).float()
        match_false = 1 - match_true

        tps.scatter_add_(0, pred, match_true)
        fps.scatter_add_(0, pred, match_false)
        fns.scatter_add_(0, target, match_false)
        tns = pred.size(0) - (tps + fps + fns)
        sups.scatter_add_(0, target, torch.ones_like(match_true))

        tps = tps[:num_classes]
        fps = fps[:num_classes]
        tns = tns[:num_classes]
        fns = fns[:num_classes]
        sups = sups[:num_classes]

    elif reduction == 'sum' or reduction == 'elementwise_mean':
        count_match_true = (pred == target).sum().float()
        oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim)

        tps = count_match_true - oob_tp
        fps = pred.nelement() - count_match_true - oob_fp
        fns = pred.nelement() - count_match_true - oob_fn
        tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn)
        sups = pred.nelement() - oob_sup.float()

        if reduction == 'elementwise_mean':
            tps /= num_classes
            fps /= num_classes
            fns /= num_classes
            tns /= num_classes
            sups /= num_classes

    return tps, fps, tns, fns, sups

@@ -218,16 +258,13 @@ def accuracy(
        tensor(0.7500)

    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes)

    if not (target > 0).any() and num_classes is None:
        raise RuntimeError("cannot infer num_classes when target is all zero")

    if reduction in ('elementwise_mean', 'sum'):
        return reduce(sum(tps) / sum(sups), reduction=reduction)
    if reduction == 'none':
        return reduce(tps / sups, reduction=reduction)
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes, reduction=reduction)

    return tps / sups


def confusion_matrix(
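As a usage sketch for the updated function (a hypothetical call, but the module path matches the file above and the commented values follow the test expectations below):

```python
import torch
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes

pred = torch.tensor([0., 2., 4., 4.])
target = torch.tensor([0., 4., 3., 4.])

# Per-class counts; five classes are inferred from the labels 0..4.
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction='none')
# tp  -> tensor([1., 0., 0., 0., 1.])
# sup -> tensor([1., 0., 0., 1., 2.])

# Counts summed over all classes.
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction='sum')
# tp -> tensor(2.), tn -> tensor(14.), sup -> tensor(4.)
```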
16 changes: 10 additions & 6 deletions tests/metrics/functional/test_classification.py
@@ -121,15 +121,19 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect
    assert sup.item() == expected_support


@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp',
                          'expected_tn', 'expected_fn', 'expected_support'], [
    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none',
                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none',
                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2])
    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum',
                 torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)),
    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
                 torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
])
def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target)
def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)

    assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
    assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
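As a quick check on the parametrized values above, the new `'elementwise_mean'` expectations are simply the `'sum'` counts divided by the five classes: 2 / 5 = 0.4 for tp, fp and fn, 14 / 5 = 2.8 for tn, and 4 / 5 = 0.8 for support, which is where `tensor(0.4)`, `tensor(2.8)` and `tensor(0.8)` come from.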