Fix ROC metric for CUDA tensors (#2304)
* Fix ROC metric for CUDA tensors

Previously the roc metric (and auroc) raised an error when passed CUDA tensors,
because an intermediate tensor was constructed with torch.tensor without specifying a device.
This fixes the error by using F.pad instead (a short device-handling sketch follows the
classification.py diff below).

* Update test_classification.py

* Update test_classification.py

* chlog

* Update test_classification.py

* Update test_classification.py

* Update tests/metrics/functional/test_classification.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update test_classification.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
4 people authored Jun 23, 2020
1 parent 92f122e commit 29179db
Showing 3 changed files with 18 additions and 12 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -20,8 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))

-- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
+- Fixed ROC metric for CUDA tensors ([#2304](https://github.com/PyTorchLightning/pytorch-lightning/pull/2304))

+- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))

 ## [0.8.1] - 2020-06-19

4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/classification.py
@@ -3,6 +3,7 @@
 from typing import Optional, Tuple, Callable

 import torch
+from torch.nn import functional as F

 from pytorch_lightning.metrics.functional.reduction import reduce
 from pytorch_lightning.utilities import rank_zero_warn
@@ -500,8 +501,7 @@ def _binary_clf_curve(
     # the indices associated with the distinct values. We also
     # concatenate a value for the end of the curve.
     distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
-    threshold_idxs = torch.cat([distinct_value_indices,
-                                torch.tensor([target.size(0) - 1])])
+    threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)

     target = (target == pos_label).to(torch.long)
     tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
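
The device issue fixed here can be illustrated in isolation. The snippet below is a minimal sketch rather than the library code, and it assumes a CUDA-capable machine: torch.tensor([...]) defaults to the CPU, so concatenating it with a CUDA tensor fails with a device mismatch, while F.pad appends the sentinel index on the input tensor's own device.

```python
import torch
from torch.nn import functional as F

# Minimal sketch of the device behaviour, assuming CUDA is available.
distinct_value_indices = torch.tensor([0, 2, 5], device="cuda")
last_idx = 9  # stands in for target.size(0) - 1

# Old approach: torch.tensor([last_idx]) is created on the CPU by default, so
# torch.cat mixes CPU and CUDA tensors and raises a device-mismatch RuntimeError.
# threshold_idxs = torch.cat([distinct_value_indices, torch.tensor([last_idx])])

# Fixed approach: F.pad appends one element with the given value while keeping
# the dtype and device of distinct_value_indices.
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=last_idx)

assert threshold_idxs.device == distinct_value_indices.device
print(threshold_idxs)  # tensor([0, 2, 5, 9], device='cuda:0')
```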
23 changes: 14 additions & 9 deletions tests/metrics/functional/test_classification.py
@@ -45,25 +45,30 @@
 ])
 def test_against_sklearn(sklearn_metric, torch_metric):
     """Compare PL metrics to sklearn version."""
-    pred = torch.randint(10, (500,))
-    target = torch.randint(10, (500,))
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    pred = torch.randint(10, (500,), device=device)
+    target = torch.randint(10, (500,), device=device)

     assert torch.allclose(
-        torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
+        torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
+                                    pred.cpu().detach().numpy()), dtype=torch.float, device=device),
         torch_metric(pred, target))

-    pred = torch.randint(10, (200,))
-    target = torch.randint(5, (200,))
+    pred = torch.randint(10, (200,), device=device)
+    target = torch.randint(5, (200,), device=device)

     assert torch.allclose(
-        torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
+        torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
+                                    pred.cpu().detach().numpy()), dtype=torch.float, device=device),
         torch_metric(pred, target))

-    pred = torch.randint(5, (200,))
-    target = torch.randint(10, (200,))
+    pred = torch.randint(5, (200,), device=device)
+    target = torch.randint(10, (200,), device=device)

     assert torch.allclose(
-        torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
+        torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
+                                    pred.cpu().detach().numpy()), dtype=torch.float, device=device),
         torch_metric(pred, target))
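
For completeness, a hedged usage sketch of what the fix enables: calling the functional ROC/AUROC metrics directly on CUDA tensors. The pytorch_lightning.metrics.functional import path reflects this release and may differ in later versions; the random data is purely illustrative.

```python
import torch
from pytorch_lightning.metrics.functional import auroc, roc

# Illustrative only: binary scores/labels placed on the GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
pred = torch.rand(100, device=device)             # predicted probabilities
target = torch.randint(2, (100,), device=device)  # binary ground-truth labels

fpr, tpr, thresholds = roc(pred, target)  # returned tensors stay on `device` after this fix
score = auroc(pred, target)               # scalar AUROC tensor
print(score)
```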

