diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index e551046a38a43..096569a81f493 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -171,12 +171,12 @@ def stat_scores_multiple_classes( >>> sups tensor([1., 0., 1., 1.]) """ - num_classes = get_num_classes(pred=pred, target=target, - num_classes=num_classes) - if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) + num_classes = get_num_classes(pred=pred, target=target, + num_classes=num_classes) + tps = torch.zeros((num_classes,), device=pred.device) fps = torch.zeros((num_classes,), device=pred.device) tns = torch.zeros((num_classes,), device=pred.device)