From d9d7e91a3b68fb7bbb966c73745a932ea95a2e6b Mon Sep 17 00:00:00 2001 From: Younghun Roh <9127047+Diuven@users.noreply.github.com> Date: Sat, 8 Aug 2020 19:01:38 +0900 Subject: [PATCH] hotfix on classification metrics (#2878) * Faster classfication stats * Faster accuracy metric * minor change on cls metric * Add out-of-bound class clamping * Add more tests and minor fixes * Resolve code style warning * Update for #2781 * hotfix * Update pytorch_lightning/metrics/functional/classification.py Co-authored-by: Jirka Borovec * Update about conversation * Add docstring on stat_scores_multiple_classes * Fixing #2862 Co-authored-by: Younghun Roh Co-authored-by: Jirka Borovec --- pytorch_lightning/metrics/functional/classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 83b27a7ab2d53..0a77dd6b67682 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -181,9 +181,9 @@ def stat_scores_multiple_classes( num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) if pred.dtype != torch.bool: - pred.clamp_max_(max=num_classes) + pred = pred.clamp_max(max=num_classes) if target.dtype != torch.bool: - target.clamp_max_(max=num_classes) + target = target.clamp_max(max=num_classes) possible_reductions = ('none', 'sum', 'elementwise_mean') if reduction not in possible_reductions: