fix error when logging to progress bar with reserved name (#5620)
* warn about duplicate metrics

* update changelog

* suggestions from rohit

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* multiple values in message

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored and Borda committed Feb 4, 2021
1 parent 7dc869f commit 4509129
Showing 6 changed files with 81 additions and 26 deletions.
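For context, here is a minimal reproduction sketch (not part of this commit) of the situation the fix addresses. It mirrors the test added further down in this diff and assumes the `BoringModel` helper from Lightning's test suite: logging a metric under a name the progress bar already reserves, such as `loss`, previously raised an error and now emits a `UserWarning` while overwriting the default value.

# Hypothetical reproduction, mirroring the test added in this commit.
import pytorch_lightning as pl
from tests.base.boring_model import BoringModel  # test-suite helper used throughout this diff


class CollidingModel(BoringModel):
    def training_step(self, batch, batch_idx):
        output = super().training_step(batch, batch_idx)
        # "loss" is a name the progress bar already tracks (the running loss),
        # so logging it with prog_bar=True collides with the reserved entry.
        self.log("loss", output["loss"], prog_bar=True)
        return output


if __name__ == "__main__":
    trainer = pl.Trainer(max_steps=2)
    trainer.fit(CollidingModel())  # with this fix: a UserWarning instead of an error
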
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))


- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))


- Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))


39 changes: 29 additions & 10 deletions pytorch_lightning/trainer/properties.py
@@ -14,25 +14,31 @@
import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from argparse import ArgumentParser
from argparse import Namespace
from typing import cast, List, Optional, Type, TypeVar, Union

from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType
from pytorch_lightning.utilities.argparse import (
add_argparse_args,
from_argparse_args,
parse_argparser,
parse_env_variables,
)
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.argparse import add_argparse_args
from pytorch_lightning.utilities.argparse import from_argparse_args
from pytorch_lightning.utilities.argparse import parse_argparser
from pytorch_lightning.utilities.argparse import parse_env_variables
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -193,7 +199,20 @@ def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.get_model()
ref_model = cast(LightningModule, ref_model)
return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)

standard_metrics = ref_model.get_progress_bar_dict()
logged_metrics = self.progress_bar_metrics
duplicates = list(standard_metrics.keys() & logged_metrics.keys())
if duplicates:
rank_zero_warn(
f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
f" If this is undesired, change the name or override `get_progress_bar_dict()`"
f" in `LightingModule`.", UserWarning
)
all_metrics = dict(**standard_metrics)
all_metrics.update(**logged_metrics)
return all_metrics

@property
def disable_validation(self) -> bool:
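The merge logic in the new `progress_bar_dict` body above leans on two standard-Python behaviours: `dict.keys()` views support set operations, so `&` yields the names present in both dicts, and `dict.update()` is applied last, so a logged metric overwrites the standard one when names clash. A standalone sketch with illustrative values (not taken from the diff):

# Standalone illustration of the merge behaviour used by `progress_bar_dict`.
standard_metrics = {"loss": 0.123, "v_num": 3}       # as returned by get_progress_bar_dict()
logged_metrics = {"loss": 0.456, "train_acc": 0.9}   # as collected from self.log(..., prog_bar=True)

# dict.keys() views behave like sets, so `&` gives the colliding names.
duplicates = list(standard_metrics.keys() & logged_metrics.keys())
assert duplicates == ["loss"]

# update() runs last, so the logged value wins on a name clash.
all_metrics = dict(**standard_metrics)
all_metrics.update(**logged_metrics)
assert all_metrics == {"loss": 0.456, "v_num": 3, "train_acc": 0.9}
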
32 changes: 21 additions & 11 deletions pytorch_lightning/utilities/xla_device_utils.py
@@ -1,9 +1,24 @@
from warnings import warn

warn(
"`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4",
DeprecationWarning
)
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import importlib
import queue as q
import traceback
from multiprocessing import Process
from multiprocessing import Queue

import torch

XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
#: define waiting time for checking TPU availability, in seconds
@@ -47,10 +62,8 @@ class XLADeviceUtils:
def _fetch_xla_device_type(device: torch.device) -> str:
"""
Returns XLA device type
Args:
device: (:class:`~torch.device`): Accepts a torch.device type with an XLA device format, i.e. xla:0
Return:
Returns a str of the device hardware type, e.g. TPU
"""
@@ -61,7 +74,6 @@ def _fetch_xla_device_type(device: torch.device) -> str:
def _is_device_tpu() -> bool:
"""
Check if device is TPU
Return:
A boolean value indicating if the xla device is a TPU device or not
"""
@@ -74,7 +86,6 @@ def _is_device_tpu() -> bool:
def xla_available() -> bool:
"""
Check if XLA library is installed
Return:
A boolean value indicating if the XLA library is installed
"""
@@ -84,7 +95,6 @@ def xla_available() -> bool:
def tpu_device_exists() -> bool:
"""
Runs XLA device check within a separate process
Return:
A boolean value indicating if a TPU device exists on the system
"""
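The `tpu_device_exists()` docstring above describes running the XLA device check in a separate process. The general pattern, sketched here under assumptions rather than copied from the library, is to push the probe result through a `multiprocessing.Queue` and give up after a timeout so a wedged TPU runtime cannot hang the caller:

# Generic "probe in a subprocess with a timeout" sketch, similar in spirit to
# XLADeviceUtils.tpu_device_exists() but not its exact implementation.
import queue as q
from multiprocessing import Process, Queue


def _probe(result_queue):
    try:
        import torch_xla.core.xla_model as xm  # only importable when torch_xla is installed
        result_queue.put(xm.xla_device() is not None)
    except Exception:
        result_queue.put(False)


def device_probe(timeout_s: float = 25.0) -> bool:
    result_queue = Queue()
    proc = Process(target=_probe, args=(result_queue,))
    proc.start()
    proc.join(timeout_s)  # do not wait forever on a wedged runtime
    if proc.is_alive():
        proc.terminate()
        return False
    try:
        return result_queue.get_nowait()
    except q.Empty:
        return False
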
8 changes: 6 additions & 2 deletions tests/accelerators/legacy/test_multi_nodes_gpu.py
@@ -26,7 +26,9 @@
from tests.base.boring_model import BoringModel # noqa: E402


@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_logging_sync_dist_true_ddp(tmpdir):
"""
Tests to ensure that the sync_dist flag works with CPU (should just return the original value)
@@ -62,7 +64,9 @@ def validation_step(self, batch, batch_idx):
assert trainer.logged_metrics['bar'] == fake_result


@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test__validation_step__log(tmpdir):
"""
Tests that validation_step can log
21 changes: 20 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -28,7 +28,8 @@
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDataset
from tests.base.boring_model import BoringModel
from tests.base.boring_model import RandomDataset


def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable:
@@ -454,3 +455,21 @@ def is_float(value: Any) -> bool:
assert excepted_function(metrics["x"])
assert excepted_function(metrics["y"])
assert excepted_function(metrics["z"])


def test_logging_to_progress_bar_with_reserved_key(tmpdir):
""" Test that logging a metric with a reserved name to the progress bar raises a warning. """
class TestModel(BoringModel):

def training_step(self, *args, **kwargs):
output = super().training_step(*args, **kwargs)
self.log("loss", output["loss"], prog_bar=True)
return output

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=2,
)
with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
trainer.fit(model)
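As the new warning message suggests, the collision exercised by this test can be avoided by renaming the logged metric or by overriding `get_progress_bar_dict()`. A hedged sketch of the second option, again built on the `BoringModel` helper; dropping the default `loss` entry is an illustrative choice, not something this commit prescribes:

# Sketch: keep only the logged "loss" by removing the progress bar's default entry.
from tests.base.boring_model import BoringModel


class NoDefaultLossModel(BoringModel):
    def training_step(self, batch, batch_idx):
        output = super().training_step(batch, batch_idx)
        self.log("loss", output["loss"], prog_bar=True)  # our own "loss" value
        return output

    def get_progress_bar_dict(self):
        items = super().get_progress_bar_dict()
        items.pop("loss", None)  # drop the reserved default so the duplicate-name warning does not fire
        return items
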
4 changes: 2 additions & 2 deletions tests/utilities/test_xla_device_utils.py
@@ -16,13 +16,13 @@

import pytest

import pytorch_lightning.utilities.xla_device as xla_utils
import pytorch_lightning.utilities.xla_device_utils as xla_utils
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import _XLA_AVAILABLE
from tests.base.develop_utils import pl_multi_process_test


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
"""Check tpu_device_exists returns None when torch_xla is not available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
