diff --git a/CHANGELOG.md b/CHANGELOG.md
index 780a8790b9fdd..79b253061ddb2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added a warning when a non-metric value logged with multiple processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))
+
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 3961586f4946a..bd5d7fb3b0dc9 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -633,6 +633,12 @@ def rename_keys(self, map_dict: dict):
             meta[dest] = meta[source]
             del meta[source]
 
+    def get_non_metrics_keys(self):
+        """
+        Return the keys whose logged value is not a ``Metric``.
+        """
+        return [k for k, v in self.items() if not isinstance(v, Metric)]
+
 
 def choose_last(x):
     if isinstance(x, (torch.Tensor, list)):
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 4a57b14efd89b..e2ce66c86ecff 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from weakref import proxy
 
 import torch
@@ -21,6 +22,19 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import DistributedType, LightningEnum
+from pytorch_lightning.utilities.warnings import WarningCache
+
+log = logging.getLogger(__name__)
+
+
+class MetricWarningCache(WarningCache):
+
+    def __init__(self):
+        super().__init__()
+        self.warned_metrics = []
+
+
+warning_cache = MetricWarningCache()
 
 
 class ResultStoreType(LightningEnum):
@@ -52,8 +66,10 @@ class HookResultStore:
     Those data structures enables us to reduce properly Result object when batch loop is finished.
     """
 
-    def __init__(self, fx_name: str) -> None:
+    def __init__(self, fx_name: str, all_gather_fn: Callable, should_warn: bool) -> None:
         self._fx_name = fx_name
+        self._all_gather_fn = all_gather_fn
+        self._should_warn = should_warn
         self._internals = {}
         self._internals_reduced = {}
         self._internal_type = None
@@ -109,6 +125,20 @@ def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> Non
         func = getattr(opt_metric, func_name)
         metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs)
 
+        if self._should_warn:
+            for non_metric_key in opt_metric.get_non_metrics_keys():
+                if non_metric_key in metrics_to_log and non_metric_key not in warning_cache.warned_metrics:
+                    metric = self._all_gather_fn(metrics_to_log[non_metric_key])
+                    if any(metric[0] != m for m in metric[1:]):
+                        warning_cache.warn(
+                            f"The value associated with the key {non_metric_key}: {metric.cpu().tolist()} "
+                            "doesn't appear to be the same across all processes. "
+                            "HINT: One could either do: `self.log(..., sync_dist=True, sync_fn=torch.mean)`"
+                            " to force mean reduction across processes, which can be inaccurate, or implement"
+                            " a `torchmetrics.Metric`"
+                        )
+                        warning_cache.warned_metrics.append(non_metric_key)
+
         results.append(metrics_to_log)
 
     def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
@@ -227,6 +257,12 @@ class EpochResultStore:
 
     def __init__(self, trainer: 'pl.Trainer') -> None:
         self.trainer = proxy(trainer)
+
+        # Only warn in distributed mode (except RPC, where the main worker runs the code).
+        _should_warn = trainer.accelerator_connector.is_distributed
+        _should_warn &= not trainer.training_type_plugin.rpc_enabled
+        self._should_warn = _should_warn
+
         self.reset()
 
     def __getitem__(self, key: str) -> Any:
@@ -278,7 +314,8 @@ def cache_result(self) -> None:
             info = self.info
             fx_name = info["fx_name"]
 
-            self._internals.setdefault(fx_name, HookResultStore(fx_name))
+            all_gather_fn = self.trainer.lightning_module.all_gather
+            self._internals.setdefault(fx_name, HookResultStore(fx_name, all_gather_fn, self._should_warn))
 
             # attach capture batch_size
             Result.attach_batch_size(self._batch_size, hook_result)
diff --git a/tests/trainer/logging_/test_train_loop_logging_1_0.py b/tests/trainer/logging_/test_train_loop_logging_1_0.py
index f8672eb4ec51e..393aaacb72328 100644
--- a/tests/trainer/logging_/test_train_loop_logging_1_0.py
+++ b/tests/trainer/logging_/test_train_loop_logging_1_0.py
@@ -743,6 +743,7 @@ class TestLoggingSyncDistModel(BoringModel):
         def training_step(self, batch, batch_idx):
             acc = self.step(batch[0])
             self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
+            self.log('cho', acc, on_step=False, on_epoch=True)
             return acc
 
         def validation_step(self, batch, batch_idx):
@@ -763,8 +764,12 @@ def validation_step(self, batch, batch_idx):
         gpus=2,
         profiler="pytorch"
     )
-    trainer.fit(model)
 
+    if os.getenv("LOCAL_RANK") == '0':
+        with pytest.warns(UserWarning, match="The value associated with the key cho:"):
+            trainer.fit(model)
+    else:
+        trainer.fit(model)
     assert trainer.logged_metrics['foo'] == 2
     assert trainer.logged_metrics['bar'] == 2