diff --git a/CHANGELOG.md b/CHANGELOG.md index 80e40457ffa0f..8bb0d31169b87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277)) +- Fixed `_flatten_dict` to cast non-string dict keys to strings ([#5354](https://github.com/PyTorchLightning/pytorch-lightning/pull/5354)) + ## [1.1.2] - 2020-12-23 diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index a27998366b671..ac7ab3e023bdb 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -207,7 +207,7 @@ def _sanitize_callable(val): return {key: _sanitize_callable(val) for key, val in params.items()} @staticmethod - def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]: + def _flatten_dict(params: Dict[Any, Any], delimiter: str = '/') -> Dict[str, Any]: """ Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. 
@@ -223,12 +223,15 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any {'a/b': 'c'} >>> LightningLoggerBase._flatten_dict({'a': {'b': 123}}) {'a/b': 123} + >>> LightningLoggerBase._flatten_dict({5: {'a': 123}}) + {'5/a': 123} """ def _dict_generator(input_dict, prefixes=None): prefixes = prefixes[:] if prefixes else [] if isinstance(input_dict, MutableMapping): for key, value in input_dict.items(): + key = str(key) if isinstance(value, (MutableMapping, Namespace)): value = vars(value) if isinstance(value, Namespace) else value for d in _dict_generator(value, prefixes + [key]): diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 15a024003ebf0..fa5c711357ba3 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -22,7 +22,7 @@ from omegaconf import OmegaConf from tensorboard.backend.event_processing.event_accumulator import EventAccumulator -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.loggers import TensorBoardLogger from tests.base import BoringModel, EvalModelTemplate @@ -102,7 +102,7 @@ def test_tensorboard_named_version(tmpdir): expected_version = "2020-02-05-162402" logger = TensorBoardLogger(save_dir=tmpdir, name=name, version=expected_version) - logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written + logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written assert logger.version == expected_version assert os.listdir(tmpdir / name) == [expected_version] @@ -113,7 +113,7 @@ def test_tensorboard_named_version(tmpdir): def test_tensorboard_no_name(tmpdir, name): """Verify that None or empty name works""" logger = TensorBoardLogger(save_dir=tmpdir, name=name) - logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written + logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written assert 
logger.root_dir == tmpdir assert os.listdir(tmpdir / "version_0")