diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 83b27a7ab2d53..0a77dd6b67682 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -181,9 +181,9 @@ def stat_scores_multiple_classes( num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) if pred.dtype != torch.bool: - pred.clamp_max_(max=num_classes) + pred = pred.clamp_max(max=num_classes) if target.dtype != torch.bool: - target.clamp_max_(max=num_classes) + target = target.clamp_max(max=num_classes) possible_reductions = ('none', 'sum', 'elementwise_mean') if reduction not in possible_reductions: