diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 50c531f75afb62..000d1d26680f1d 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed signature inspection of decorated hooks ([#17507](https://github.com/Lightning-AI/lightning/pull/17507))
 
+- The `WandbLogger` no longer flattens dictionaries in the hyperparameters logged to the dashboard ([#17574](https://github.com/Lightning-AI/lightning/pull/17574))
+
+
 ## [2.0.2] - 2023-04-24
 
diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py
index b303ddc5957e03..fd49501b13c164 100644
--- a/src/lightning/pytorch/loggers/wandb.py
+++ b/src/lightning/pytorch/loggers/wandb.py
@@ -24,7 +24,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 
-from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
+from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
 from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
@@ -418,7 +418,6 @@ def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, l
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = _convert_params(params)
-        params = _flatten_dict(params)
         params = _sanitize_callable_params(params)
         self.experiment.config.update(params, allow_val_change=True)
 
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index bbdae2b0de0df8..0446d929a642ca 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -103,10 +103,9 @@ def test_wandb_logger_init(wandb, monkeypatch):
     wandb.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})
 
     # log hyper parameters
-    logger.log_hyperparams({"test": None, "nested": {"a": 1}, "b": [2, 3, 4]})
-    wandb.init().config.update.assert_called_once_with(
-        {"test": None, "nested/a": 1, "b": [2, 3, 4]}, allow_val_change=True
-    )
+    hparams = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}
+    logger.log_hyperparams(hparams)
+    wandb.init().config.update.assert_called_once_with(hparams, allow_val_change=True)
 
     # watch a model
     logger.watch("model", "log", 10, False)
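For reference, a minimal sketch of the resulting behavior (illustrative only, assuming `lightning` is installed; the sample `params` dict is borrowed from the updated test): with the `_flatten_dict` call removed, `log_hyperparams` now passes nested dictionaries through to `wandb.config` unchanged.

```python
# Sketch of the preprocessing WandbLogger.log_hyperparams applies after this change.
from lightning.fabric.utilities.logger import _convert_params, _sanitize_callable_params

params = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}
params = _convert_params(params)            # Namespace/dict -> plain dict
params = _sanitize_callable_params(params)  # replace callables with their names

# Previously _flatten_dict would have produced:
#   {"test": None, "nested/a": 1, "b": [2, 3, 4]}
# Now the nested structure is preserved for the wandb dashboard:
assert params == {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}
```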