Skip to content

Commit

Permalink
Avoid false positive warning about using sync_dist when using torch…
Browse files Browse the repository at this point in the history
…metrics (#14143)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
rohitgr7 and carmocca authored Aug 12, 2022
1 parent 2d9e00f commit 6789a06
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the device placement when `LightningModule.cuda()` gets called without specifying a device index and the current cuda device was not 0 ([#14128](https://github.com/Lightning-AI/lightning/pull/14128))


- Avoided false positive warning about using `sync_dist` when using torchmetrics ([#14143](https://github.com/Lightning-AI/lightning/pull/14143))


- Avoid `metadata.entry_points` deprecation warning on Python 3.10 ([#14052](https://github.com/Lightning-AI/lightning/pull/14052))


Expand All @@ -79,7 +82,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))



## [1.7.1] - 2022-08-09

### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
elif not on_step and result_metric.meta.on_epoch:
if result_metric._computed is None:
should = result_metric.meta.sync.should
if not result_metric.meta.sync.should and distributed_available():
if not should and distributed_available() and result_metric.is_tensor:
# ensure sync happens for FT since during a failure, the metrics are synced and saved to the
# checkpoint, so during restart, metrics on rank 0 are from the accumulated ones from the previous
# run, and on other ranks, they are 0. So we need to make sure they are synced in further training
Expand Down
22 changes: 16 additions & 6 deletions tests/tests_pytorch/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchmetrics
from torch.nn import ModuleDict, ModuleList
from torchmetrics import Metric, MetricCollection

import pytorch_lightning as pl
import tests_pytorch.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -666,19 +668,27 @@ def on_train_start(self):


@pytest.mark.parametrize("distributed_env", [True, False])
def test_logger_sync_dist(distributed_env):
# self.log('bar', 7, ..., sync_dist=False)
@pytest.mark.parametrize("log_val", [torch.tensor(0.5), torchmetrics.Accuracy()])
def test_logger_sync_dist(distributed_env, log_val):
pl.trainer.connectors.logger_connector.result.warning_cache.clear()

# self.log('bar', 0.5, ..., sync_dist=False)
meta = _Metadata("foo", "bar")
meta.sync = _Sync(_should=False)
result_metric = _ResultMetric(metadata=meta, is_tensor=True)
result_metric.update(torch.tensor(7.0), 10)
is_tensor = isinstance(log_val, torch.Tensor)

if not is_tensor:
log_val.update(torch.tensor([0, 1]), torch.tensor([0, 0], dtype=torch.long))

result_metric = _ResultMetric(metadata=meta, is_tensor=is_tensor)
result_metric.update(log_val, 10)

warning_ctx = pytest.warns if distributed_env else no_warning_call
warning_ctx = pytest.warns if distributed_env and is_tensor else no_warning_call

with mock.patch(
"pytorch_lightning.trainer.connectors.logger_connector.result.distributed_available",
return_value=distributed_env,
):
with warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"):
value = _ResultCollection._get_cache(result_metric, on_step=False)
assert value == 7.0
assert value == 0.5

0 comments on commit 6789a06

Please sign in to comment.