Add Support for Non-primitive types in TensorboardLogger #1130

Merged · 11 commits · Mar 14, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))


### Changed

28 changes: 27 additions & 1 deletion pytorch_lightning/loggers/base.py
@@ -4,6 +4,8 @@
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List

import torch


def rank_zero_only(fn: Callable):
"""Decorate a logger method to run it only on the process with rank 0.
@@ -42,7 +44,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""
pass

def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
@staticmethod
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
# in case converting from namespace
if isinstance(params, Namespace):
params = vars(params)
@@ -52,6 +55,29 @@ def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str,

return params

@staticmethod
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Returns params with non-primitvies converted to strings for logging

>>> params = {"float": 0.3,
... "int": 1,
... "string": "abc",
... "bool": True,
... "list": [1, 2, 3],
... "namespace": Namespace(foo=3),
... "layer": torch.nn.BatchNorm1d}
>>> import pprint
>>> pprint.pprint(LightningLoggerBase._sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
{'bool': True,
'float': 0.3,
'int': 1,
'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
'list': '[1, 2, 3]',
'namespace': 'Namespace(foo=3)',
'string': 'abc'}
"""
return {k: v if type(v) in [bool, int, float, str, torch.Tensor] else str(v) for k, v in params.items()}

@abstractmethod
def log_hyperparams(self, params: argparse.Namespace):
"""Record hyperparameters.
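For reference, a minimal sketch of how the two helpers compose: `_convert_params` turns a `Namespace` into a dict, and the new `_sanitize_params` stringifies any value that is not a bool, int, float, str, or tensor. Both are static methods after this PR, so they can be called on the class; the example values below are illustrative, not taken from the PR:

```python
from argparse import Namespace

import torch

from pytorch_lightning.loggers.base import LightningLoggerBase

hparams = Namespace(lr=0.02, norm_layer=torch.nn.BatchNorm1d, hidden_sizes=[64, 128])

# Namespace -> dict, then non-primitive values -> their string representation
params = LightningLoggerBase._convert_params(hparams)
params = LightningLoggerBase._sanitize_params(params)
print(params)
# {'lr': 0.02,
#  'norm_layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
#  'hidden_sizes': '[64, 128]'}
```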
6 changes: 4 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
@@ -101,6 +101,7 @@ def experiment(self) -> SummaryWriter:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
sanitized_params = self._sanitize_params(params)

if parse_version(torch.__version__) < parse_version("1.3.0"):
warn(
@@ -110,13 +111,14 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
)
else:
from torch.utils.tensorboard.summary import hparams
exp, ssi, sei = hparams(params, {})
exp, ssi, sei = hparams(sanitized_params, {})
writer = self.experiment._get_file_writer()
writer.add_summary(exp)
writer.add_summary(ssi)
writer.add_summary(sei)

# some alternative should be added
self.tags.update(params)
self.tags.update(sanitized_params)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
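End to end, the change means callers no longer have to pre-stringify complex hyperparameter values before `TensorBoardLogger.log_hyperparams` builds TensorBoard's `hparams()` summary. A hedged usage sketch (the save directory and experiment name are arbitrary):

```python
from argparse import Namespace

import torch

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="lightning_logs", name="sanitize_demo")

# Non-primitive values (Namespace, list, layer class) are stringified by
# _sanitize_params before being handed to torch's hparams() summary writer.
logger.log_hyperparams({
    "lr": 0.02,
    "optimizer": Namespace(name="adam", betas=[0.9, 0.999]),
    "hidden_sizes": [64, 128],
    "norm_layer": torch.nn.BatchNorm1d,
})
logger.save()
```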
6 changes: 5 additions & 1 deletion tests/loggers/test_tensorboard.py
@@ -1,4 +1,5 @@
import pickle
from argparse import Namespace

import pytest
import torch
@@ -108,6 +109,9 @@ def test_tensorboard_log_hyperparams(tmpdir):
"float": 0.3,
"int": 1,
"string": "abc",
"bool": True
"bool": True,
"list": [1, 2, 3],
"namespace": Namespace(foo=3),
"layer": torch.nn.BatchNorm1d
}
logger.log_hyperparams(hparams)
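A follow-up check one could imagine after this call (hypothetical, not part of the PR's test): since `log_hyperparams` now updates `logger.tags` with the sanitized dict, the non-primitive entries should be stored as their string representations:

```python
# hypothetical assertions, not in the PR
assert logger.tags["layer"] == str(torch.nn.BatchNorm1d)
assert logger.tags["namespace"] == "Namespace(foo=3)"
assert logger.tags["list"] == "[1, 2, 3]"
```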