From 45091293ffddc9dc401efadadc7df0f71c881a74 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 25 Jan 2021 13:57:06 +0100
Subject: [PATCH] fix error when logging to progress bar with reserved name
 (#5620)

* warn about duplicate metrics

* update changelog

* suggestions from rohit

Co-authored-by: Rohit Gupta

* multiple values in message

* Apply suggestions from code review

Co-authored-by: Rohit Gupta
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 CHANGELOG.md                                  |  3 ++
 pytorch_lightning/trainer/properties.py       | 39 ++++++++++++++-----
 .../utilities/xla_device_utils.py             | 32 +++++++++------
 .../legacy/test_multi_nodes_gpu.py            |  8 +++-
 .../trainer/logging_/test_logger_connector.py | 21 +++++++++-
 tests/utilities/test_xla_device_utils.py      |  4 +-
 6 files changed, 81 insertions(+), 26 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7c4255af8e3d79..6a8cf96c8e1984 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))
 
+- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
+
+
 - Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))
 
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 760a621db69140..aca37802ca90e3 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -14,11 +14,15 @@
 import inspect
 import os
 from abc import ABC
-from argparse import ArgumentParser, Namespace
+from argparse import ArgumentParser
+from argparse import Namespace
 from typing import cast, List, Optional, Type, TypeVar, Union
 
 from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
-from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
@@ -26,13 +30,15 @@
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType
-from pytorch_lightning.utilities.argparse import (
-    add_argparse_args,
-    from_argparse_args,
-    parse_argparser,
-    parse_env_variables,
-)
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import DeviceType
+from pytorch_lightning.utilities import DistributedType
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.argparse import add_argparse_args
+from pytorch_lightning.utilities.argparse import from_argparse_args
+from pytorch_lightning.utilities.argparse import parse_argparser
+from pytorch_lightning.utilities.argparse import parse_env_variables
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -193,7 +199,20 @@ def progress_bar_dict(self) -> dict:
         """ Read-only for progress bar metrics. """
         ref_model = self.get_model()
         ref_model = cast(LightningModule, ref_model)
-        return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
+
+        standard_metrics = ref_model.get_progress_bar_dict()
+        logged_metrics = self.progress_bar_metrics
+        duplicates = list(standard_metrics.keys() & logged_metrics.keys())
+        if duplicates:
+            rank_zero_warn(
+                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
+                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value."
+                f" If this is undesired, change the name or override `get_progress_bar_dict()`"
+                f" in `LightningModule`.", UserWarning
+            )
+        all_metrics = dict(**standard_metrics)
+        all_metrics.update(**logged_metrics)
+        return all_metrics
 
     @property
     def disable_validation(self) -> bool:
diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py
index 22d3455ce49a78..204a2433e757fc 100644
--- a/pytorch_lightning/utilities/xla_device_utils.py
+++ b/pytorch_lightning/utilities/xla_device_utils.py
@@ -1,9 +1,24 @@
-from warnings import warn
-
-warn(
-    "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4",
-    DeprecationWarning
-)
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+import importlib
+import queue as q
+import traceback
+from multiprocessing import Process
+from multiprocessing import Queue
+
+import torch
 
 XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
 #: define waiting time got checking TPU available in sec
@@ -47,10 +62,8 @@ class XLADeviceUtils:
     def _fetch_xla_device_type(device: torch.device) -> str:
         """
         Returns XLA device type
-
         Args:
             device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
-
         Return:
             Returns a str of the device hardware type. i.e TPU
         """
@@ -61,7 +74,6 @@ def _fetch_xla_device_type(device: torch.device) -> str:
     def _is_device_tpu() -> bool:
         """
         Check if device is TPU
-
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
@@ -74,7 +86,6 @@ def _is_device_tpu() -> bool:
     def xla_available() -> bool:
         """
         Check if XLA library is installed
-
         Return:
             A boolean value indicating if a XLA is installed
         """
@@ -84,7 +95,6 @@ def xla_available() -> bool:
     def tpu_device_exists() -> bool:
         """
         Runs XLA device check within a separate process
-
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
diff --git a/tests/accelerators/legacy/test_multi_nodes_gpu.py b/tests/accelerators/legacy/test_multi_nodes_gpu.py
index f17ac42fcbd041..d9387df2b99032 100644
--- a/tests/accelerators/legacy/test_multi_nodes_gpu.py
+++ b/tests/accelerators/legacy/test_multi_nodes_gpu.py
@@ -26,7 +26,9 @@
 from tests.base.boring_model import BoringModel  # noqa: E402
 
 
-@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test_logging_sync_dist_true_ddp(tmpdir):
     """
     Tests to ensure that the sync_dist flag works with CPU (should just return the original value)
@@ -62,7 +64,9 @@ def validation_step(self, batch, batch_idx):
     assert trainer.logged_metrics['bar'] == fake_result
 
 
-@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test__validation_step__log(tmpdir):
     """
     Tests that validation_step can log
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 106a299e436f10..ffdaea8c5203bf 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -28,7 +28,8 @@
 from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
 from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base.boring_model import BoringModel, RandomDataset
+from tests.base.boring_model import BoringModel
+from tests.base.boring_model import RandomDataset
 
 
 def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable:
@@ -454,3 +455,21 @@ def is_float(value: Any) -> bool:
     assert excepted_function(metrics["x"])
     assert excepted_function(metrics["y"])
     assert excepted_function(metrics["z"])
+
+
+def test_logging_to_progress_bar_with_reserved_key(tmpdir):
+    """ Test that logging a metric with a reserved name to the progress bar raises a warning. """
""" + class TestModel(BoringModel): + + def training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + self.log("loss", output["loss"], prog_bar=True) + return output + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + ) + with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"): + trainer.fit(model) diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py index dcafa8509266a8..471792da9ccabe 100644 --- a/tests/utilities/test_xla_device_utils.py +++ b/tests/utilities/test_xla_device_utils.py @@ -16,13 +16,13 @@ import pytest -import pytorch_lightning.utilities.xla_device as xla_utils +import pytorch_lightning.utilities.xla_device_utils as xla_utils from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities import _XLA_AVAILABLE from tests.base.develop_utils import pl_multi_process_test -@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent") +@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent") def test_tpu_device_absence(): """Check tpu_device_exists returns None when torch_xla is not available""" assert xla_utils.XLADeviceUtils.tpu_device_exists() is None