Skip to content

Commit

Permalink
New argument compute_on_cpu (#867)
Browse files Browse the repository at this point in the history
* implementation

* bit of testing

* map

* docs

* improve testing

* input validation

* mypy

* map testing

* changelog

* update done

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Apr 11, 2022
1 parent d587be9 commit 15fb10b
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `reset_real_features` argument image quality assesment metrics ([#722](https://github.com/PyTorchLightning/metrics/pull/722))


- Added new keyword argument `compute_on_cpu` to all metrics ([#867](https://github.com/PyTorchLightning/metrics/pull/867))


### Changed

- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914))
Expand Down
16 changes: 13 additions & 3 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,19 @@ A functional metric is differentiable if its corresponding modular metric is dif

.. _Metric kwargs:

*****************************
Advanced distributed settings
*****************************
************************
Advanced metric settings
************************

The following is a list of additional arguments that can be given to any metric class (in the ``**kwargs`` argument)
that will alter how metric states are stored and synced.

If you are running metrics on GPU and are encountering that you are running out of GPU VRAM then the following
argument can help:

- ``compute_on_cpu`` will automatically move the metric states to cpu after calling ``update``, making sure that
GPU memory is not filling up. The consequence will be that the ``compute`` method will be called on CPU instead
of GPU. Only applies to metric states that are lists.

If you are running in a distributed environment, ``TorchMetrics`` will automatically take care of the distributed
synchronization for you. However, the following three keyword arguments can be given to any metric class for
Expand Down
3 changes: 3 additions & 0 deletions tests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def test_error_on_wrong_input():
with pytest.raises(ValueError, match="Expected keyword argument `dist_sync_fn` to be an callable function.*"):
DummyMetric(dist_sync_fn=[2, 3])

with pytest.raises(ValueError, match="Expected keyword argument `compute_on_cpu` to be an `bool` bu.*"):
DummyMetric(compute_on_cpu=None)


def test_inherit():
"""Test that metric that inherits can be instanciated."""
Expand Down
5 changes: 3 additions & 2 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def _compare_fn(preds, target) -> dict:


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
@pytest.mark.parametrize("compute_on_cpu", [True, False])
class TestMAP(MetricTester):
"""Test the MAP metric for object detection predictions.
Expand All @@ -175,7 +176,7 @@ class TestMAP(MetricTester):
atol = 1e-1

@pytest.mark.parametrize("ddp", [False, True])
def test_map(self, ddp):
def test_map(self, compute_on_cpu, ddp):
"""Test modular implementation for correctness."""
self.run_class_metric_test(
ddp=ddp,
Expand All @@ -185,7 +186,7 @@ def test_map(self, ddp):
sk_metric=_compare_fn,
dist_sync_on_step=False,
check_batch=False,
metric_args={"class_metrics": True},
metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu},
)


Expand Down
5 changes: 3 additions & 2 deletions tests/image/test_inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,12 @@ def __len__(self):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
def test_compare_is(tmpdir):
@pytest.mark.parametrize("compute_on_cpu", [True, False])
def test_compare_is(tmpdir, compute_on_cpu):
"""check that the hole pipeline give the same result as torch-fidelity."""
from torch_fidelity import calculate_metrics

metric = InceptionScore(splits=1).cuda()
metric = InceptionScore(splits=1, compute_on_cpu=compute_on_cpu).cuda()

# Generate some synthetic data
img1 = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
Expand Down
4 changes: 3 additions & 1 deletion tests/regression/test_pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,17 @@ def _sk_pearsonr(preds, target):
class TestPearsonCorrcoef(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("compute_on_cpu", [True, False])
@pytest.mark.parametrize("ddp", [True, False])
def test_pearson_corrcoef(self, preds, target, ddp):
def test_pearson_corrcoef(self, preds, target, compute_on_cpu, ddp):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=PearsonCorrCoef,
sk_metric=_sk_pearsonr,
dist_sync_on_step=False,
metric_args={"compute_on_cpu": compute_on_cpu},
)

def test_pearson_corrcoef_functional(self, preds, target):
Expand Down
20 changes: 6 additions & 14 deletions torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,10 @@ def _evaluate_image(
nb_iou_thrs = len(self.iou_thresholds)
nb_gt = len(gt)
nb_det = len(det)
gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool)
det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool)
gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=gt.device)
det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=gt.device)
gt_ignore = ignore_area_sorted
det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool)
det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=gt.device)

if torch.numel(ious) > 0:
for idx_iou, t in enumerate(self.iou_thresholds):
Expand Down Expand Up @@ -723,13 +723,13 @@ def __calculate_recall_precision_scores(
recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if nd else 0

# Remove zigzags for AUC
diff_zero = torch.zeros((1,))
diff = torch.ones((1,))
diff_zero = torch.zeros((1,), device=pr.device)
diff = torch.ones((1,), device=pr.device)
while not torch.all(diff == 0):
diff = torch.clamp(torch.cat((pr[1:] - pr[:-1], diff_zero), 0), min=0)
pr += diff

inds = torch.searchsorted(rc, rec_thresholds, right=False)
inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
num_inds = inds.argmax() if inds.max() >= nd else nb_rec_thrs
inds = inds[:num_inds]
prec[:num_inds] = pr[inds]
Expand Down Expand Up @@ -766,14 +766,6 @@ def compute(self) -> dict:
- map_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
- mar_100_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
"""

# move everything to CPU, as we are faster here
self.detection_boxes = [box.cpu() for box in self.detection_boxes]
self.detection_labels = [label.cpu() for label in self.detection_labels]
self.detection_scores = [score.cpu() for score in self.detection_scores]
self.groundtruth_boxes = [box.cpu() for box in self.groundtruth_boxes]
self.groundtruth_labels = [label.cpu() for label in self.groundtruth_labels]

classes = self._get_classes()
precisions, recalls = self._calculate(classes)
map_val, mar_val = self._summarize_results(precisions, recalls)
Expand Down
24 changes: 22 additions & 2 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class Metric(Module, ABC):
kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.
- compute_on_cpu: If metric state should be stored on CPU during computations. Only works
for list states.
- dist_sync_on_step: If metric state should synchronize on ``forward()``
- process_group: The process group on which the synchronization is called
- dist_sync_fn: function that performs the allgather option on the metric state
Expand Down Expand Up @@ -120,6 +122,11 @@ def __init__(
warnings.warn(
"Argument `compute_on_step` is deprecated in v0.8 and will be removed in v0.9", DeprecationWarning
)
self.compute_on_cpu = kwargs.pop("compute_on_cpu", False)
if not isinstance(self.compute_on_cpu, bool):
raise ValueError(
f"Expected keyword argument `compute_on_cpu` to be an `bool` but got {self.compute_on_cpu}"
)

self.dist_sync_on_step = kwargs.pop("dist_sync_on_step", False)
if not isinstance(self.dist_sync_on_step, bool):
Expand Down Expand Up @@ -243,6 +250,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
self._to_sync = self.dist_sync_on_step # type: ignore
# skip restore cache operation from compute as cache is stored below.
self._should_unsync = False
# skip computing on cpu for the batch
_temp_compute_on_cpu = self.compute_on_cpu
self.compute_on_cpu = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}
Expand All @@ -262,6 +272,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
self._to_sync = True
self._computed = None
self._enable_grad = False
self.compute_on_cpu = _temp_compute_on_cpu

return self._forward_cache

Expand Down Expand Up @@ -294,14 +305,23 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:

def _wrap_update(self, update: Callable) -> Callable:
@functools.wraps(update)
def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
def wrapped_func(*args: Any, **kwargs: Any) -> None:
self._computed = None
self._update_called = True
with torch.set_grad_enabled(self._enable_grad):
return update(*args, **kwargs)
update(*args, **kwargs)
if self.compute_on_cpu:
self._move_list_states_to_cpu()

return wrapped_func

def _move_list_states_to_cpu(self) -> None:
"""Move list states to cpu to save GPU memory."""
for key in self._defaults.keys():
current_val = getattr(self, key)
if isinstance(current_val, Sequence):
setattr(self, key, [cur_v.to("cpu") for cur_v in current_val])

def sync(
self,
dist_sync_fn: Optional[Callable] = None,
Expand Down

0 comments on commit 15fb10b

Please sign in to comment.