From 543361207f3a1bc751383ef6e1609d48ce01e740 Mon Sep 17 00:00:00 2001 From: pwwang <1188067+pwwang@users.noreply.github.com> Date: Fri, 31 Jul 2020 10:33:43 -0700 Subject: [PATCH 1/2] Fix num_classes warning Put to_categorical before get_num_classes in metrics/functional/classification.py --- pytorch_lightning/metrics/functional/classification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index e551046a38a43..1a9b0729d7919 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -171,11 +171,11 @@ def stat_scores_multiple_classes( >>> sups tensor([1., 0., 1., 1.]) """ - num_classes = get_num_classes(pred=pred, target=target, - num_classes=num_classes) - if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) + + num_classes = get_num_classes(pred=pred, target=target, + num_classes=num_classes) tps = torch.zeros((num_classes,), device=pred.device) fps = torch.zeros((num_classes,), device=pred.device) From e0ee0a26f818a38e5c7028436a968b37de00df7f Mon Sep 17 00:00:00 2001 From: pwwang <1188067+pwwang@users.noreply.github.com> Date: Fri, 31 Jul 2020 10:39:39 -0700 Subject: [PATCH 2/2] Update classification.py Remove whitespaces in blank line. --- pytorch_lightning/metrics/functional/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 1a9b0729d7919..096569a81f493 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -173,7 +173,7 @@ def stat_scores_multiple_classes( """ if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) - + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)