Skip to content

Commit

Permalink
FIX-5311: Cast to string _flatten_dict (#5354)
Browse files Browse the repository at this point in the history
* fix

* params

* add test

* add another types

* chlog

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 5, 2021
1 parent ec0fb7a commit 6536ea4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))

- Fixed casted key to string in `_flatten_dict` ([#5354](https://github.com/PyTorchLightning/pytorch-lightning/pull/5354))


## [1.1.2] - 2020-12-23

Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _sanitize_callable(val):
return {key: _sanitize_callable(val) for key, val in params.items()}

@staticmethod
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
def _flatten_dict(params: Dict[Any, Any], delimiter: str = '/') -> Dict[str, Any]:
"""
Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
Expand All @@ -223,12 +223,15 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any
{'a/b': 'c'}
>>> LightningLoggerBase._flatten_dict({'a': {'b': 123}})
{'a/b': 123}
>>> LightningLoggerBase._flatten_dict({5: {'a': 123}})
{'5/a': 123}
"""

def _dict_generator(input_dict, prefixes=None):
prefixes = prefixes[:] if prefixes else []
if isinstance(input_dict, MutableMapping):
for key, value in input_dict.items():
key = str(key)
if isinstance(value, (MutableMapping, Namespace)):
value = vars(value) if isinstance(value, Namespace) else value
for d in _dict_generator(value, prefixes + [key]):
Expand Down
6 changes: 3 additions & 3 deletions tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from omegaconf import OmegaConf
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import BoringModel, EvalModelTemplate

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_tensorboard_named_version(tmpdir):
expected_version = "2020-02-05-162402"

logger = TensorBoardLogger(save_dir=tmpdir, name=name, version=expected_version)
logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written

assert logger.version == expected_version
assert os.listdir(tmpdir / name) == [expected_version]
Expand All @@ -113,7 +113,7 @@ def test_tensorboard_named_version(tmpdir):
def test_tensorboard_no_name(tmpdir, name):
"""Verify that None or empty name works"""
logger = TensorBoardLogger(save_dir=tmpdir, name=name)
logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
assert logger.root_dir == tmpdir
assert os.listdir(tmpdir / "version_0")

Expand Down

0 comments on commit 6536ea4

Please sign in to comment.