Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Non-primitive types in TensorboardLogger #1130

Merged
merged 11 commits into from
Mar 14, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))


### Changed

Expand Down
27 changes: 26 additions & 1 deletion pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import torch
monney marked this conversation as resolved.
Show resolved Hide resolved
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
Expand Down Expand Up @@ -42,7 +43,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""
pass

def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
@staticmethod
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
# in case converting from namespace
if isinstance(params, Namespace):
params = vars(params)
Expand All @@ -52,6 +54,29 @@ def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str,

return params

@staticmethod
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Returns params with non-primitvies converted to strings for logging

>>> params = {"float": 0.3,
... "int": 1,
... "string": "abc",
... "bool": True,
... "list": [1, 2, 3],
... "namespace": Namespace(foo=3),
... "layer": torch.nn.BatchNorm1d}
>>> import pprint
>>> pprint.pprint(LightningLoggerBase._sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
{'bool': True,
'float': 0.3,
'int': 1,
'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
'list': '[1, 2, 3]',
'namespace': 'Namespace(foo=3)',
'string': 'abc'}
"""
return {k: v if type(v) in [bool, int, float, str, torch.Tensor] else str(v) for k, v in params.items()}

@abstractmethod
def log_hyperparams(self, params: argparse.Namespace):
"""Record hyperparameters.
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def experiment(self) -> SummaryWriter:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
sanitized_params = self._sanitize_params(params)

if parse_version(torch.__version__) < parse_version("1.3.0"):
warn(
Expand All @@ -110,13 +111,14 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
)
else:
from torch.utils.tensorboard.summary import hparams
exp, ssi, sei = hparams(params, {})
exp, ssi, sei = hparams(sanitized_params, {})
writer = self.experiment._get_file_writer()
writer.add_summary(exp)
writer.add_summary(ssi)
writer.add_summary(sei)

# some alternative should be added
self.tags.update(params)
self.tags.update(sanitized_params)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
Expand Down
6 changes: 5 additions & 1 deletion tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
TensorBoardLogger
)
from tests.models import LightningTestModel
from argparse import Namespace


def test_tensorboard_logger(tmpdir):
Expand Down Expand Up @@ -108,6 +109,9 @@ def test_tensorboard_log_hyperparams(tmpdir):
"float": 0.3,
"int": 1,
"string": "abc",
"bool": True
"bool": True,
"list": [1, 2, 3],
"namespace": Namespace(foo=3),
"layer": torch.nn.BatchNorm1d
}
logger.log_hyperparams(hparams)