Avoid flattening hyperparameters in WandbLogger (#17574)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit 6c8c02d)
awaelchli authored and Borda committed Jun 1, 2023
1 parent dbb983e commit 4ebffd7
Showing 3 changed files with 7 additions and 6 deletions.
src/lightning/pytorch/CHANGELOG.md (3 additions, 0 deletions)

@@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed signature inspection of decorated hooks ([#17507](https://github.com/Lightning-AI/lightning/pull/17507))


+- The `WandbLogger` no longer flattens dictionaries in the hyperparameters logged to the dashboard ([#17574](https://github.com/Lightning-AI/lightning/pull/17574))
+
+

 ## [2.0.2] - 2023-04-24
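For context, a minimal sketch of the kind of flattening this commit removes. This is illustrative only, not the exact `_flatten_dict` implementation from `lightning.fabric.utilities.logger`; the delimiter and edge-case handling here are assumptions:

# Hypothetical re-implementation of the removed flattening step, for illustration.
def flatten(params, delimiter="/", parent_key=""):
    """Recursively flatten nested dicts: {"nested": {"a": 1}} -> {"nested/a": 1}."""
    items = {}
    for key, value in params.items():
        new_key = f"{parent_key}{delimiter}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            items.update(flatten(value, delimiter, new_key))
        else:
            items[new_key] = value
    return items

hparams = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}
print(flatten(hparams))  # old behavior: {'test': None, 'nested/a': 1, 'b': [2, 3, 4]}
# After this commit, hparams reaches wandb unchanged, so the dashboard shows
# "nested" as a single nested entry rather than a flattened "nested/a" key.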
src/lightning/pytorch/loggers/wandb.py (1 addition, 2 deletions)

@@ -24,7 +24,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor

-from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
+from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
 from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
@@ -418,7 +418,6 @@ def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, l
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = _convert_params(params)
-        params = _flatten_dict(params)
         params = _sanitize_callable_params(params)
         self.experiment.config.update(params, allow_val_change=True)

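With `_flatten_dict` removed, nested hyperparameters reach `experiment.config` unflattened. A minimal usage sketch, assuming `wandb` is installed and you are logged in; the project name and hyperparameter values are placeholders:

from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(project="demo")  # "demo" is a placeholder project name
# The nested "optimizer" dict is now stored as-is in wandb's config, so the
# dashboard groups it under one key instead of separate "optimizer/lr" keys.
logger.log_hyperparams({"optimizer": {"name": "adam", "lr": 1e-3}, "seed": 42})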
tests/tests_pytorch/loggers/test_wandb.py (3 additions, 4 deletions)

@@ -103,10 +103,9 @@ def test_wandb_logger_init(wandb, monkeypatch):
     wandb.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})

     # log hyper parameters
-    logger.log_hyperparams({"test": None, "nested": {"a": 1}, "b": [2, 3, 4]})
-    wandb.init().config.update.assert_called_once_with(
-        {"test": None, "nested/a": 1, "b": [2, 3, 4]}, allow_val_change=True
-    )
+    hparams = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}
+    logger.log_hyperparams(hparams)
+    wandb.init().config.update.assert_called_once_with(hparams, allow_val_change=True)

     # watch a model
     logger.watch("model", "log", 10, False)
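The updated test asserts that the mocked `wandb.init().config.update` receives the exact dict passed in, nested dict included. A simplified, standalone sketch of that mock pattern; the real test patches the `wandb` module through a fixture, so the setup here is an assumption about its mechanics:

from unittest import mock

wandb = mock.MagicMock()  # stands in for the patched wandb module
hparams = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}

# Under the old flattening behavior this call would have received
# {"nested/a": 1, ...}, and the assertion against the unflattened
# dict below would fail.
wandb.init().config.update(hparams, allow_val_change=True)
wandb.init().config.update.assert_called_once_with(hparams, allow_val_change=True)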
