Skip to content

Commit

Permalink
Update for Lightning-AI#2781
Browse files Browse the repository at this point in the history
  • Loading branch information
Younghun Roh committed Aug 3, 2020
1 parent 6df2018 commit c14cd67
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ def stat_scores_multiple_classes(
>>> sups
tensor([1., 0., 1., 1.])
"""
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)

if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)

num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)

pred = pred.view((-1, )).long()
target = target.view((-1, )).long()

Expand Down

0 comments on commit c14cd67

Please sign in to comment.