diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8d758c6d3cb9d..0a50740142667 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -66,6 +66,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
 
+- Allow use of sweeps with WandbLogger ([#1512](https://github.com/PyTorchLightning/pytorch-lightning/pull/1512))
+
 - Fixed a bug that caused the `callbacks` Trainer argument to reference a global variable ([#1534](https://github.com/PyTorchLightning/pytorch-lightning/pull/1534)).
 
 - Fixed a bug that set all boolean CLI arguments from Trainer.add_argparse_args always to True ([#1570](https://github.com/PyTorchLightning/pytorch-lightning/issues/1570))
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index 8d4bd0aa4a355..3e8443056c722 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -114,7 +114,7 @@ def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = self._convert_params(params)
-        self.experiment.config.update(params)
+        self.experiment.config.update(params, allow_val_change=True)
 
     @rank_zero_only
     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 87240ac3edb70..d2ef6318c8c14 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -23,7 +23,7 @@ def test_wandb_logger(wandb):
    wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})
 
     logger.log_hyperparams({'test': None})
-    wandb.init().config.update.assert_called_once_with({'test': None})
+    wandb.init().config.update.assert_called_once_with({'test': None}, allow_val_change=True)
 
     logger.watch('model', 'log', 10)
     wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)
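
Context for the change (not part of the diff): when a run is launched by a wandb sweep agent, wandb pre-populates `wandb.config` with the swept hyperparameters, and a later `config.update()` on those same keys raises an error unless `allow_val_change=True` is passed, which is why `log_hyperparams` previously crashed under sweeps. Below is a minimal sketch of using `WandbLogger` inside a sweep; the `LitModel` class and the `lr` sweep parameter are hypothetical names for illustration, not part of this PR.

import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

def train():
    logger = WandbLogger()
    # Accessing .experiment triggers wandb.init(); inside a sweep run the
    # agent has already seeded the config with this trial's values.
    config = logger.experiment.config
    model = LitModel(lr=config.lr)  # hypothetical LightningModule
    trainer = Trainer(logger=logger, max_epochs=1)
    # Trainer calls logger.log_hyperparams(...) during fit(); with this patch
    # the underlying config.update(..., allow_val_change=True) no longer
    # errors out when the logged hyperparameters overlap with the sweep config.
    trainer.fit(model)

# The sweep agent would invoke the function, e.g.:
# wandb.agent(sweep_id, function=train)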