Skip to content

Commit

Permalink
IoU: avoid recomputing class presence in target and pred
Browse files Browse the repository at this point in the history
Use already-computed support, true positives, and false positives to
determine if a class is not present in either target or pred.
  • Loading branch information
abrahambotros committed Sep 8, 2020
1 parent f1be1e7 commit 172a7c1
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,15 +1009,19 @@ def iou(
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)

scores = torch.zeros(num_classes - min_class_idx, device=pred.device, dtype=torch.float32)
for class_idx in range(min_class_idx, num_classes):
# If this class is not present in either the pred or the target, then use the not_present_score for this class.
if not (target == class_idx).any() and not (pred == class_idx).any():
scores[class_idx - min_class_idx] = not_present_score
continue

for class_idx in range(min_class_idx, num_classes):
tp = tps[class_idx]
fp = fps[class_idx]
fn = fns[class_idx]
sup = sups[class_idx]

# If this class is not present in either the target (no support) or the pred (no true or false positives), then
# use the not_present_score for this class.
if sup + tp + fp == 0:
scores[class_idx - min_class_idx] = not_present_score
continue

denom = tp + fp + fn
score = tp.to(torch.float) / denom
scores[class_idx - min_class_idx] = score
Expand Down

0 comments on commit 172a7c1

Please sign in to comment.