
fix error when logging to progress bar with reserved name #5620

Merged (13 commits) on Jan 25, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))

18 changes: 16 additions & 2 deletions pytorch_lightning/trainer/properties.py
@@ -183,9 +183,23 @@ def progress_bar_callback(self):
    @property
    def progress_bar_dict(self) -> dict:
        """ Read-only for progress bar metrics. """
-       ref_model = self.model if not self.data_parallel else self.model.module
+       ref_model = self.get_model()
+       ref_model = cast(LightningModule, ref_model)
-       return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
+
+       standard_metrics = ref_model.get_progress_bar_dict()
+       logged_metrics = self.progress_bar_metrics
+       duplicates = list(standard_metrics.keys() & logged_metrics.keys())
+       if duplicates:
+           rank_zero_warn(
+               f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
+               f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
+               f" If this is undesired, change the name or override `get_progress_bar_dict()`"
+               f" in `LightningModule`.",
+               UserWarning,
+           )
+       all_metrics = dict(**standard_metrics)
+       all_metrics.update(**logged_metrics)
+       return all_metrics

@property
def disable_validation(self) -> bool:
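The merge-and-warn logic in the hunk above can be sketched outside the `Trainer` with plain dicts. This is a minimal, hypothetical standalone version: `rank_zero_warn` is replaced by the standard-library `warnings.warn`, and the function name is illustrative, not part of the Lightning API.

```python
import warnings


def merge_progress_bar_metrics(standard_metrics: dict, logged_metrics: dict) -> dict:
    """Merge user-logged metrics into the standard progress bar metrics.

    Mirrors the patched ``progress_bar_dict`` property: a duplicate key
    triggers a UserWarning, and the logged value wins in the merged result.
    """
    # dict.keys() views support set operations, so `&` yields the shared keys.
    duplicates = list(standard_metrics.keys() & logged_metrics.keys())
    if duplicates:
        warnings.warn(
            f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
            f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value.",
            UserWarning,
        )
    # Copy the standard metrics first, then let logged metrics overwrite.
    all_metrics = dict(**standard_metrics)
    all_metrics.update(**logged_metrics)
    return all_metrics
```

With `standard_metrics={"loss": 0.5, "v_num": 3}` and `logged_metrics={"loss": 0.25}`, this warns about the duplicate `"loss"` key and returns a dict where the logged `0.25` has overwritten the standard `0.5`.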
18 changes: 18 additions & 0 deletions tests/trainer/logging/test_logger_connector.py
@@ -425,3 +425,21 @@ def test_dataloader(self):
    )
    trainer.fit(model)
    trainer.test(model, ckpt_path=None)


def test_logging_to_progress_bar_with_reserved_key(tmpdir):
    """ Test that logging a metric with a reserved name to the progress bar raises a warning. """

    class TestModel(BoringModel):

        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            self.log("loss", output["loss"], prog_bar=True)
            return output

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=2,
    )
    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
        trainer.fit(model)

Review discussion on where this test should live:

Contributor (Author): @tchaton @Borda Where should I put this test? I'm not familiar with how the logging tests are structured.

Member: tests/trainer/logging_process ?

Contributor (Author): this folder does not exist on master, only dev branch

Member: oh, I see.... :]

tchaton (Contributor, Jan 24, 2021): Maybe callback/test_progress ?
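For reference, the `match` argument of `pytest.warns` is applied with `re.search` against the warning message, so the pattern in the test only needs to match a substring. The same check can be sketched with the standard library alone; the function below is hypothetical and its message is abbreviated from the patch:

```python
import re
import warnings


def emit_duplicate_metric_warning():
    # Abbreviated version of the message emitted by the patched progress_bar_dict.
    warnings.warn(
        "The progress bar already tracks a metric with the name(s) 'loss' and"
        " `self.log('loss', ..., prog_bar=True)` will overwrite this value.",
        UserWarning,
    )


# Record the warning instead of printing it, so we can inspect the message.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    emit_duplicate_metric_warning()

message = str(caught[0].message)
# Same pattern the test passes to pytest.warns(..., match=...); `.*` absorbs "name(s)".
assert re.search("The progress bar already tracks a metric with the .* 'loss'", message)
```

Because the semantics are `re.search`, not `re.fullmatch`, the pattern matches even though the message continues past `'loss'`.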