Fix label checking in classification (#2427)
* Implementation
* tests
* changelog

(cherry picked from commit 5980744)
SkafteNicki authored and Borda committed Mar 18, 2024
1 parent ef3e473 commit d3e891e
Showing 5 changed files with 85 additions and 28 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))


- Fixed case where label prediction tensors in classification metrics were not validated correctly ([#2427](https://github.com/Lightning-AI/torchmetrics/pull/2427))


- Fixed how auc scores are calculated in `PrecisionRecallCurve.plot` methods ([#2437](https://github.com/Lightning-AI/torchmetrics/pull/2437))

## [1.3.1] - 2024-02-12
20 changes: 6 additions & 14 deletions src/torchmetrics/functional/classification/confusion_matrix.py
@@ -285,21 +285,13 @@ def _multiclass_confusion_matrix_tensor_validation(
             " and `preds` should be (N, C, ...)."
         )

-    num_unique_values = len(torch.unique(target))
-    check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1
-    if check:
-        raise RuntimeError(
-            "Detected more unique values in `target` than `num_classes`. Expected only "
-            f"{num_classes if ignore_index is None else num_classes + 1} but found "
-            f"{num_unique_values} in `target`."
-        )
-
-    if not preds.is_floating_point():
-        num_unique_values = len(torch.unique(preds))
-        if num_unique_values > num_classes:
+    check_value = num_classes if ignore_index is None else num_classes + 1
+    for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else ():  # noqa: RUF005
+        num_unique_values = len(torch.unique(t))
+        if num_unique_values > check_value:
             raise RuntimeError(
-                "Detected more unique values in `preds` than `num_classes`. Expected only "
-                f"{num_classes} but found {num_unique_values} in `preds`."
+                f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
+                f" {num_unique_values} in `target`."
             )


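The rewritten validation treats `target` and integer `preds` symmetrically: both now run through the same loop and are compared against `check_value`, which allows one extra label when `ignore_index` is set. A minimal sketch of how the check surfaces to a caller of the functional API (illustrative values only; it mirrors the tests added further down):

import torch
from torchmetrics.functional.classification import multiclass_confusion_matrix

num_classes = 5
target = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
preds = torch.tensor([0, 1, 2, 3, 4, 5, 0, 1, 2, 3])  # six distinct labels, one more than num_classes allows

try:
    multiclass_confusion_matrix(preds, target, num_classes=num_classes)
except RuntimeError as err:
    print(err)  # "Detected more unique values in `preds` than expected. Expected only 5 ..."
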
20 changes: 6 additions & 14 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -304,21 +304,13 @@ def _multiclass_stat_scores_tensor_validation(
             " and `preds` should be (N, C, ...)."
         )

-    num_unique_values = len(torch.unique(target))
-    check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1
-    if check:
-        raise RuntimeError(
-            "Detected more unique values in `target` than `num_classes`. Expected only"
-            f" {num_classes if ignore_index is None else num_classes + 1} but found"
-            f" {num_unique_values} in `target`."
-        )
-
-    if not preds.is_floating_point():
-        unique_values = torch.unique(preds)
-        if len(unique_values) > num_classes:
+    check_value = num_classes if ignore_index is None else num_classes + 1
+    for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else ():  # noqa: RUF005
+        num_unique_values = len(torch.unique(t))
+        if num_unique_values > check_value:
             raise RuntimeError(
-                "Detected more unique values in `preds` than `num_classes`. Expected only"
-                f" {num_classes} but found {len(unique_values)} in `preds`."
+                f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
+                f" {num_unique_values} in `target`."
             )


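The same consolidation is applied here: integer `preds` are now checked against `check_value` rather than bare `num_classes`, so with `ignore_index` set, up to `num_classes + 1` distinct values are tolerated in either tensor before validation fails. A hedged sketch of that `ignore_index` case (hypothetical values, mirroring the parametrized tests below):

import torch
from torchmetrics.functional.classification import multiclass_stat_scores

num_classes = 3
preds = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])  # five distinct labels
target = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])

try:
    # With ignore_index set, num_classes + 1 = 4 unique values would still pass; five is too many.
    multiclass_stat_scores(preds, target, num_classes=num_classes, ignore_index=1)
except RuntimeError as err:
    print(err)  # "Detected more unique values in `preds` than expected. Expected only 4 ..."
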
35 changes: 35 additions & 0 deletions tests/unittests/classification/test_confusion_matrix.py
@@ -239,6 +239,41 @@ def test_multiclass_confusion_matrix_dtype_gpu(self, inputs, dtype):
         )


+@pytest.mark.parametrize(
+    ("preds", "target", "ignore_index", "error_message"),
+    [
+        (
+            torch.randint(NUM_CLASSES + 1, (100,)),
+            torch.randint(NUM_CLASSES, (100,)),
+            None,
+            f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*",
+        ),
+        (
+            torch.randint(NUM_CLASSES, (100,)),
+            torch.randint(NUM_CLASSES + 1, (100,)),
+            None,
+            f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*",
+        ),
+        (
+            torch.randint(NUM_CLASSES + 2, (100,)),
+            torch.randint(NUM_CLASSES, (100,)),
+            1,
+            f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*",
+        ),
+        (
+            torch.randint(NUM_CLASSES, (100,)),
+            torch.randint(NUM_CLASSES + 2, (100,)),
+            1,
+            f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*",
+        ),
+    ],
+)
+def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_message):
+    """Test that an error is raised if the number of classes in preds or target is larger than expected."""
+    with pytest.raises(RuntimeError, match=error_message):
+        multiclass_confusion_matrix(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index)
+
+
 def test_multiclass_overflow():
     """Test that multiclass computations does not overflow even on byte inputs."""
     preds = torch.randint(20, (100,)).byte()
35 changes: 35 additions & 0 deletions tests/unittests/classification/test_stat_scores.py
@@ -325,6 +325,41 @@ def test_multiclass_stat_scores_dtype_gpu(self, inputs, dtype):
         )


+@pytest.mark.parametrize(
+    ("preds", "target", "ignore_index", "error_message"),
+    [
+        (
+            torch.randint(NUM_CLASSES + 1, (100,)),
+            torch.randint(NUM_CLASSES, (100,)),
+            None,
+            f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*",
+        ),
+        (
+            torch.randint(NUM_CLASSES, (100,)),
+            torch.randint(NUM_CLASSES + 1, (100,)),
+            None,
+            f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*",
+        ),
+        (
+            torch.randint(NUM_CLASSES + 2, (100,)),
+            torch.randint(NUM_CLASSES, (100,)),
+            1,
+            f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*",
+        ),
+        (
+            torch.randint(NUM_CLASSES, (100,)),
+            torch.randint(NUM_CLASSES + 2, (100,)),
+            1,
+            f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*",
+        ),
+    ],
+)
+def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_message):
+    """Test that an error is raised if the number of classes in preds or target is larger than expected."""
+    with pytest.raises(RuntimeError, match=error_message):
+        multiclass_stat_scores(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index)
+
+
 _mc_k_target = torch.tensor([0, 1, 2])
 _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
