Skip to content

Commit

Permalink
Fix false num_classes warning in metrics (#2781)
Browse files Browse the repository at this point in the history
* Fix num_classes warning

Put to_categorical before get_num_classes in metrics/functional/classification.py

* Update classification.py

Remove whitespaces in blank line.
  • Loading branch information
pwwang committed Aug 2, 2020
1 parent 8baec1a commit c600ca6
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ def stat_scores_multiple_classes(
>>> sups
tensor([1., 0., 1., 1.])
"""
num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)

if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)

num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)

tps = torch.zeros((num_classes,), device=pred.device)
fps = torch.zeros((num_classes,), device=pred.device)
tns = torch.zeros((num_classes,), device=pred.device)
Expand Down

0 comments on commit c600ca6

Please sign in to comment.