From c14cd67bc83f25f941e1d0ae636d8da2a89049f4 Mon Sep 17 00:00:00 2001 From: Younghun Roh Date: Mon, 3 Aug 2020 10:36:00 +0900 Subject: [PATCH] Update for #2781 --- pytorch_lightning/metrics/functional/classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index d0f4b1b62333d..f723e3fd614d0 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -171,11 +171,11 @@ def stat_scores_multiple_classes( >>> sups tensor([1., 0., 1., 1.]) """ - num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) - if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) + pred = pred.view((-1, )).long() target = target.view((-1, )).long()