fix error when logging to progress bar with reserved name (#5620)
* warn about duplicate metrics

* update changelog

* suggestions from rohit

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* multiple values in message

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored and Borda committed Feb 4, 2021
1 parent 062b9ba commit 1bc063e
Showing 5 changed files with 60 additions and 15 deletions.
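For context, the situation this commit addresses arises when user code logs a metric under a name the progress bar already shows on its own, most commonly `loss`. A minimal sketch of such a module, not taken from this repository; the model, layer sizes, and dataset below are hypothetical and only serve to illustrate the name collision:

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class CollidingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        # "loss" is a name the progress bar already tracks by default;
        # logging it again with prog_bar=True is what used to error and,
        # with this commit, emits a UserWarning and overwrites the value.
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


if __name__ == "__main__":
    pl.Trainer(max_steps=2).fit(CollidingModel())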
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))


- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))


- Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))


39 changes: 29 additions & 10 deletions pytorch_lightning/trainer/properties.py
@@ -14,25 +14,31 @@
import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from argparse import ArgumentParser
from argparse import Namespace
from typing import cast, List, Optional, Type, TypeVar, Union

from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType
from pytorch_lightning.utilities.argparse import (
    add_argparse_args,
    from_argparse_args,
    parse_argparser,
    parse_env_variables,
)
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.argparse import add_argparse_args
from pytorch_lightning.utilities.argparse import from_argparse_args
from pytorch_lightning.utilities.argparse import parse_argparser
from pytorch_lightning.utilities.argparse import parse_env_variables
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -193,7 +199,20 @@ def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.get_model()
ref_model = cast(LightningModule, ref_model)
return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)

standard_metrics = ref_model.get_progress_bar_dict()
logged_metrics = self.progress_bar_metrics
duplicates = list(standard_metrics.keys() & logged_metrics.keys())
if duplicates:
rank_zero_warn(
f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
f" If this is undesired, change the name or override `get_progress_bar_dict()`"
f" in `LightingModule`.", UserWarning
)
all_metrics = dict(**standard_metrics)
all_metrics.update(**logged_metrics)
return all_metrics

@property
def disable_validation(self) -> bool:
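If the overwrite is not wanted, the warning points at two ways out: log the metric under a different name, or override `get_progress_bar_dict()` in the module. A rough sketch of the override route, assuming the default dict carries a "loss" entry (the exact default keys depend on the training setup); the class name is hypothetical:

import pytorch_lightning as pl


class QuietModel(pl.LightningModule):

    def get_progress_bar_dict(self):
        # Start from the standard progress bar entries ...
        items = super().get_progress_bar_dict()
        # ... and drop the built-in "loss" so a value logged via
        # self.log("loss", ..., prog_bar=True) no longer collides with it.
        items.pop("loss", None)
        return items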
8 changes: 6 additions & 2 deletions tests/accelerators/legacy/test_multi_nodes_gpu.py
@@ -26,7 +26,9 @@
from tests.base.boring_model import BoringModel # noqa: E402


@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
@pytest.mark.skipif(
    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_logging_sync_dist_true_ddp(tmpdir):
    """
    Tests to ensure that the sync_dist flag works with CPU (should just return the original value)
@@ -62,7 +64,9 @@ def validation_step(self, batch, batch_idx):
    assert trainer.logged_metrics['bar'] == fake_result


@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
@pytest.mark.skipif(
    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test__validation_step__log(tmpdir):
    """
    Tests that validation_step can log
21 changes: 20 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -28,7 +28,8 @@
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDataset
from tests.base.boring_model import BoringModel
from tests.base.boring_model import RandomDataset


def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable:
@@ -454,3 +455,21 @@ def is_float(value: Any) -> bool:
assert excepted_function(metrics["x"])
assert excepted_function(metrics["y"])
assert excepted_function(metrics["z"])


def test_logging_to_progress_bar_with_reserved_key(tmpdir):
    """ Test that logging a metric with a reserved name to the progress bar raises a warning. """
    class TestModel(BoringModel):

        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            self.log("loss", output["loss"], prog_bar=True)
            return output

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=2,
    )
    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
        trainer.fit(model)
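As an aside, a user who is fine with the overwrite could silence just this warning instead of renaming the metric; a small sketch using only the standard library, with the message pattern taken from the warning text added in this commit:

import warnings

# Hide only the progress-bar name-collision warning; other UserWarnings stay visible.
warnings.filterwarnings(
    "ignore",
    message=r"The progress bar already tracks a metric with the name\(s\)",
    category=UserWarning,
)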
4 changes: 2 additions & 2 deletions tests/utilities/test_xla_device_utils.py
@@ -16,13 +16,13 @@

import pytest

import pytorch_lightning.utilities.xla_device as xla_utils
import pytorch_lightning.utilities.xla_device_utils as xla_utils
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import _XLA_AVAILABLE
from tests.base.develop_utils import pl_multi_process_test


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
"""Check tpu_device_exists returns None when torch_xla is not available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
