diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index d0f4b1b62333d..f723e3fd614d0 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -171,11 +171,11 @@ def stat_scores_multiple_classes( >>> sups tensor([1., 0., 1., 1.]) """ - num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) - if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) + pred = pred.view((-1, )).long() target = target.view((-1, )).long()