Skip to content

Commit

Permalink
Resolve code style warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Younghun Roh committed Jul 31, 2020
1 parent b157cfa commit 6df2018
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,7 @@ def stat_scores_multiple_classes(
fns = fns[:num_classes]
sups = sups[:num_classes]

return tps, fps, tns, fns, sups

elif reduction == 'sum' or reduction == 'elementwise_mean':
if reduction == 'sum' or reduction == 'elementwise_mean':
count_match_true = (pred == target).sum().float()
oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim)

Expand All @@ -224,11 +222,11 @@ def stat_scores_multiple_classes(
tns /= num_classes
sups /= num_classes

return tps, fps, tns, fns, sups

else:
raise ValueError("reduction type %s not supported" % reduction)

return tps, fps, tns, fns, sups


def accuracy(
pred: torch.Tensor,
Expand Down

0 comments on commit 6df2018

Please sign in to comment.