Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for half precision in Perplexity metric #2235

Merged
merged 11 commits into from
Nov 25, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed numerical stability issue in `UniversalImageQualityIndex` metric ([#2222](https://github.com/Lightning-AI/torchmetrics/pull/2222))


- Fixed support for half precision in Perplexity metric ([#2235](https://github.com/Lightning-AI/torchmetrics/pull/2235))

## [1.2.0] - 2023-09-22

### Added
Expand Down
11 changes: 3 additions & 8 deletions src/torchmetrics/functional/text/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

import torch
from torch import Tensor
from torch.nn import functional as F # noqa: N812

_TORCH_FLOAT_OR_DOUBLE = (torch.float32, torch.float64)


def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None:
Expand Down Expand Up @@ -59,10 +56,8 @@ def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None:
"Input tensors `preds` and `target` are expected to have equaling first two dimensions,"
f" [batch_size, seq_len], but got {preds.shape[:2]} and {target.shape}."
)
if preds.dtype not in _TORCH_FLOAT_OR_DOUBLE:
raise TypeError(
f"Input tensor `preds` is expected to be of a type one of {_TORCH_FLOAT_OR_DOUBLE} but got {preds.dtype}."
)
if not preds.is_floating_point():
raise TypeError(f"Input tensor `preds` is expected to be of floating point type but got {preds.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Input tensor `target` is expected to be of a type {torch.int64} but got {target.dtype}.")

Expand All @@ -87,7 +82,7 @@ def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int
"""
_check_shape_and_type_consistency(preds, target)

probs = F.softmax(preds.reshape(-1, preds.shape[-1]), dim=1)
probs = torch.nn.functional.softmax(preds.reshape(-1, preds.shape[-1]), dim=1)
target = target.reshape(-1)

if ignore_index is not None:
Expand Down
19 changes: 18 additions & 1 deletion tests/unittests/text/test_perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_perplexity_fn(self, preds, target, ignore_index):
metric_args={"ignore_index": ignore_index},
)

def test_accuracy_differentiability(self, preds, target, ignore_index):
def test_perplexity_differentiability(self, preds, target, ignore_index):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
self.run_differentiability_test(
preds=preds,
Expand All @@ -80,3 +80,20 @@ def test_accuracy_differentiability(self, preds, target, ignore_index):
metric_functional=perplexity,
metric_args={"ignore_index": ignore_index},
)

@pytest.mark.parametrize("dtype", [torch.half, torch.double])
def test_perplexity_dtypes_cpu(self, preds, target, ignore_index, dtype):
"""Test dtype support of the metric on CPU."""
if dtype == torch.half:
pytest.skip("`softmax_lastdim_kernel_impl` is not support for half precision on CPU.")
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
self.run_precision_test_cpu(
preds, target, Perplexity, perplexity, metric_args={"ignore_index": ignore_index}, dtype=dtype
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.parametrize("dtype", [torch.half, torch.double])
def test_perplexity_dtypes_gpu(self, preds, target, ignore_index, dtype):
"""Test dtype support of the metric on GPU."""
self.run_precision_test_gpu(
preds, target, Perplexity, perplexity, metric_args={"ignore_index": ignore_index}, dtype=dtype
)
Loading