Skip to content

Commit

Permalink
Faster Accuracy metric (#2775)
Browse files Browse the repository at this point in the history
* Faster classfication stats

* Faster accuracy metric

* minor change on cls metric

* Add out-of-bound class clamping

* Add more tests and minor fixes

* Resolve code style warning

* Update for #2781

* hotfix

* Update pytorch_lightning/metrics/functional/classification.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update about conversation

* Add docstring on stat_scores_multiple_classes

Co-authored-by: Younghun Roh <yhunroh@mindslab.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 6, 2020
1 parent dd78be5 commit ac4a215
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 24 deletions.
81 changes: 63 additions & 18 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ def stat_scores_multiple_classes(
target: torch.Tensor,
num_classes: Optional[int] = None,
argmax_dim: int = 1,
reduction: str = 'none',
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calls the stat_scores function iteratively for all classes, thus
calculating the number of true postive, false postive, true negative
Calculates the number of true postive, false postive, true negative
and false negative for each class
Args:
Expand All @@ -150,6 +150,12 @@ def stat_scores_multiple_classes(
num_classes: number of classes if known
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over
reduction: method for reducing result values (default: none)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
Return:
True Positive, False Positive, True Negative, False Negative, Support
Expand All @@ -173,16 +179,58 @@ def stat_scores_multiple_classes(
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)

num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)

tps = torch.zeros((num_classes,), device=pred.device)
fps = torch.zeros((num_classes,), device=pred.device)
tns = torch.zeros((num_classes,), device=pred.device)
fns = torch.zeros((num_classes,), device=pred.device)
sups = torch.zeros((num_classes,), device=pred.device)
for c in range(num_classes):
tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c)
if pred.dtype != torch.bool:
pred.clamp_max_(max=num_classes)
if target.dtype != torch.bool:
target.clamp_max_(max=num_classes)

possible_reductions = ('none', 'sum', 'elementwise_mean')
if reduction not in possible_reductions:
raise ValueError("reduction type %s not supported" % reduction)

if reduction == 'none':
pred = pred.view((-1, )).long()
target = target.view((-1, )).long()

tps = torch.zeros((num_classes + 1,), device=pred.device)
fps = torch.zeros((num_classes + 1,), device=pred.device)
tns = torch.zeros((num_classes + 1,), device=pred.device)
fns = torch.zeros((num_classes + 1,), device=pred.device)
sups = torch.zeros((num_classes + 1,), device=pred.device)

match_true = (pred == target).float()
match_false = 1 - match_true

tps.scatter_add_(0, pred, match_true)
fps.scatter_add_(0, pred, match_false)
fns.scatter_add_(0, target, match_false)
tns = pred.size(0) - (tps + fps + fns)
sups.scatter_add_(0, target, torch.ones_like(match_true))

tps = tps[:num_classes]
fps = fps[:num_classes]
tns = tns[:num_classes]
fns = fns[:num_classes]
sups = sups[:num_classes]

elif reduction == 'sum' or reduction == 'elementwise_mean':
count_match_true = (pred == target).sum().float()
oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim)

tps = count_match_true - oob_tp
fps = pred.nelement() - count_match_true - oob_fp
fns = pred.nelement() - count_match_true - oob_fn
tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn)
sups = pred.nelement() - oob_sup.float()

if reduction == 'elementwise_mean':
tps /= num_classes
fps /= num_classes
fns /= num_classes
tns /= num_classes
sups /= num_classes

return tps, fps, tns, fns, sups

Expand Down Expand Up @@ -218,16 +266,13 @@ def accuracy(
tensor(0.7500)
"""
tps, fps, tns, fns, sups = stat_scores_multiple_classes(
pred=pred, target=target, num_classes=num_classes)

if not (target > 0).any() and num_classes is None:
raise RuntimeError("cannot infer num_classes when target is all zero")

if reduction in ('elementwise_mean', 'sum'):
return reduce(sum(tps) / sum(sups), reduction=reduction)
if reduction == 'none':
return reduce(tps / sups, reduction=reduction)
tps, fps, tns, fns, sups = stat_scores_multiple_classes(
pred=pred, target=target, num_classes=num_classes, reduction=reduction)

return tps / sups


def confusion_matrix(
Expand Down
16 changes: 10 additions & 6 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,19 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect
assert sup.item() == expected_support


@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp',
'expected_tn', 'expected_fn', 'expected_support'], [
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none',
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none',
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2])
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum',
torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
])
def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target)
def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)

assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
Expand Down

0 comments on commit ac4a215

Please sign in to comment.