From 5775ca823604dcc560fbf9d2d33d651174771e5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 23 Jan 2021 01:45:30 +0100 Subject: [PATCH 1/7] warn about duplicate metrics --- pytorch_lightning/trainer/properties.py | 16 +++++++++++++++- tests/trainer/logging/test_logger_connector.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 3fa2af79e5530..c0dd34af898a0 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -185,7 +185,21 @@ def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. """ ref_model = self.model if not self.data_parallel else self.model.module ref_model = cast(LightningModule, ref_model) - return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics) + + standard_metrics = ref_model.get_progress_bar_dict() + logged_metrics = self.logger_connector.progress_bar_metrics + duplicates = list(standard_metrics.keys() & logged_metrics.keys()) + if duplicates: + rank_zero_warn( + f"The progress bar already tracks a metric with the name '{duplicates[0]}' and" + f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. " + f" If this is undesired, change the name or override `get_progress_bar_dict()`." 
+ f" in LightingModule.", + UserWarning + ) + all_metrics = dict(**standard_metrics) + all_metrics.update(**logged_metrics) + return all_metrics @property def disable_validation(self) -> bool: diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index 56e5765c7f4b8..fcf6f7ff013ac 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -425,3 +425,21 @@ def test_dataloader(self): ) trainer.fit(model) trainer.test(model, ckpt_path=None) + + +def test_logging_to_progress_bar_with_reserved_key(tmpdir): + """ Test that logging a metric with a reserved name to the progress bar raises a warning. """ + class TestModel(BoringModel): + + def training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + self.log("loss", output["loss"], prog_bar=True) + return output + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + ) + with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the name 'loss'"): + trainer.fit(model) From f1d124b130d7d41d8714e7aed89901a5be158313 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 23 Jan 2021 02:01:53 +0100 Subject: [PATCH 2/7] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fd70e3583c01..389078239901b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed +- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620)) + From a900f8191d73d642211893cf28ec3dac88adfbbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 23 Jan 2021 22:44:34 +0100 Subject: [PATCH 3/7] suggestions from rohit Co-authored-by: Rohit Gupta --- pytorch_lightning/trainer/properties.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index c0dd34af898a0..37c5bad77faed 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -183,18 +183,18 @@ def progress_bar_callback(self): @property def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. """ - ref_model = self.model if not self.data_parallel else self.model.module + ref_model = self.get_model() ref_model = cast(LightningModule, ref_model) standard_metrics = ref_model.get_progress_bar_dict() - logged_metrics = self.logger_connector.progress_bar_metrics + logged_metrics = self.progress_bar_metrics duplicates = list(standard_metrics.keys() & logged_metrics.keys()) if duplicates: rank_zero_warn( f"The progress bar already tracks a metric with the name '{duplicates[0]}' and" f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. " - f" If this is undesired, change the name or override `get_progress_bar_dict()`." 
-            f" in LightingModule.",
+            f" If this is undesired, change the name or override `get_progress_bar_dict()`"
+            f" in `LightningModule`.",
             UserWarning
         )
         all_metrics = dict(**standard_metrics)

From 5f26f11ec757bfa22c872dc8b2b1326280b9b118 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sun, 24 Jan 2021 03:06:07 +0100
Subject: [PATCH 4/7] multiple values in message

---
 pytorch_lightning/trainer/properties.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 37c5bad77faed..64b5c5ae2ad0d 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -191,7 +191,7 @@ def progress_bar_dict(self) -> dict:
         duplicates = list(standard_metrics.keys() & logged_metrics.keys())
         if duplicates:
             rank_zero_warn(
-                f"The progress bar already tracks a metric with the name '{duplicates[0]}' and"
+                f"The progress bar already tracks a metric with the name '{', '.join(duplicates)}' and"
                 f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
                 f" If this is undesired, change the name or override `get_progress_bar_dict()`"
                 f" in `LightningModule`.",

From 71d4d8bd56af66aa7fdc9266e96c5f52adefa0bc Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Sun, 24 Jan 2021 16:24:09 +0530
Subject: [PATCH 5/7] Apply suggestions from code review

---
 pytorch_lightning/trainer/properties.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 64b5c5ae2ad0d..f16b12f69e221 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -191,7 +191,7 @@ def progress_bar_dict(self) -> dict:
         duplicates = list(standard_metrics.keys() & logged_metrics.keys())
         if duplicates:
             rank_zero_warn(
-                f"The progress bar already tracks a metric with the name '{', '.join(duplicates)}' and"
+                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
                 f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
                 f" If this is undesired, change the name or override `get_progress_bar_dict()`"
                 f" in `LightningModule`.",

From 8a3e1f87accad20dd3ae9ecc73711de9b5f2226e Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Sun, 24 Jan 2021 17:40:34 +0530
Subject: [PATCH 6/7] test

---
 tests/trainer/logging/test_logger_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py
index fcf6f7ff013ac..10808454e7f4a 100644
--- a/tests/trainer/logging/test_logger_connector.py
+++ b/tests/trainer/logging/test_logger_connector.py
@@ -441,5 +441,5 @@ def training_step(self, *args, **kwargs):
         default_root_dir=tmpdir,
         max_steps=2,
     )
-    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the name 'loss'"):
+    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the name(s) 'loss'"):
         trainer.fit(model)

From 0c442f0a2571c9775cde04ac253211ebeadb39dc Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Mon, 25 Jan 2021 12:15:33 +0530
Subject: [PATCH 7/7] test

---
 tests/trainer/logging/test_logger_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py
index 10808454e7f4a..a06d61a82de65 100644
--- a/tests/trainer/logging/test_logger_connector.py
+++ b/tests/trainer/logging/test_logger_connector.py
@@ -441,5 +441,5 @@ def training_step(self, *args, **kwargs):
         default_root_dir=tmpdir,
         max_steps=2,
     )
-    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the name(s) 'loss'"):
+    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
         trainer.fit(model)