From 0c48ba3d3370923ac7fbf228ffbdff80c731c5f2 Mon Sep 17 00:00:00 2001 From: monney Date: Sat, 14 Mar 2020 13:02:05 -0400 Subject: [PATCH] Add Support for Non-primitive types in TensorboardLogger (#1130) * Added support for non-primitive types to tensorboard logger * added EOF newline * PEP8 * Updated CHANGELOG for PR #1130. Moved _sanitize_params to base logger. Cleaned up _sanitize_params * Updated CHANGELOG for PR #1130. Moved _sanitize_params to base logger. Cleaned up _sanitize_params * changed convert_params to static method * PEP8 * Cleanup Doctest for _sanitize_params Co-Authored-By: Jirka Borovec * Removed OrderedDict import * Updated import order to conventions Co-authored-by: Manbir Gulati Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 ++ pytorch_lightning/loggers/base.py | 28 +++++++++++++++++++++++- pytorch_lightning/loggers/tensorboard.py | 6 +++-- tests/loggers/test_tensorboard.py | 6 ++++- 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e17728f9424da..1d6cbc3d087812 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946)) - Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104)) +- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130)) + ### Changed diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 5295bee02fe8ff..8c3daa29cb96e5 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -4,6 +4,8 @@ from functools import wraps from typing import Union, Optional, Dict, Iterable, Any, Callable, List +import torch + def rank_zero_only(fn: Callable): """Decorate a logger method to run it only on the process with rank 0. @@ -42,7 +44,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """ pass - def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: + @staticmethod + def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: # in case converting from namespace if isinstance(params, Namespace): params = vars(params) @@ -52,6 +55,29 @@ def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, return params + @staticmethod + def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: + """Returns params with non-primitvies converted to strings for logging + + >>> params = {"float": 0.3, + ... "int": 1, + ... "string": "abc", + ... "bool": True, + ... "list": [1, 2, 3], + ... "namespace": Namespace(foo=3), + ... "layer": torch.nn.BatchNorm1d} + >>> import pprint + >>> pprint.pprint(LightningLoggerBase._sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE + {'bool': True, + 'float': 0.3, + 'int': 1, + 'layer': "", + 'list': '[1, 2, 3]', + 'namespace': 'Namespace(foo=3)', + 'string': 'abc'} + """ + return {k: v if type(v) in [bool, int, float, str, torch.Tensor] else str(v) for k, v in params.items()} + @abstractmethod def log_hyperparams(self, params: argparse.Namespace): """Record hyperparameters. diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 9be1d82b7669a0..662ecdf4af4959 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -101,6 +101,7 @@ def experiment(self) -> SummaryWriter: @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = self._convert_params(params) + sanitized_params = self._sanitize_params(params) if parse_version(torch.__version__) < parse_version("1.3.0"): warn( @@ -110,13 +111,14 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: ) else: from torch.utils.tensorboard.summary import hparams - exp, ssi, sei = hparams(params, {}) + exp, ssi, sei = hparams(sanitized_params, {}) writer = self.experiment._get_file_writer() writer.add_summary(exp) writer.add_summary(ssi) writer.add_summary(sei) + # some alternative should be added - self.tags.update(params) + self.tags.update(sanitized_params) @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index b3f3d19242c8c1..e815384011f982 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -1,4 +1,5 @@ import pickle +from argparse import Namespace import pytest import torch @@ -108,6 +109,9 @@ def test_tensorboard_log_hyperparams(tmpdir): "float": 0.3, "int": 1, "string": "abc", - "bool": True + "bool": True, + "list": [1, 2, 3], + "namespace": Namespace(foo=3), + "layer": torch.nn.BatchNorm1d } logger.log_hyperparams(hparams)