Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Non-primitive types in TensorboardLogger #1130

Merged
merged 11 commits into from
Mar 14, 2020
18 changes: 16 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def experiment(self) -> SummaryWriter:
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
sanitized_params = self._sanitize_params(params)

if parse_version(torch.__version__) < parse_version("1.3.0"):
warn(
Expand All @@ -110,13 +111,14 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
)
else:
from torch.utils.tensorboard.summary import hparams
exp, ssi, sei = hparams(params, {})
exp, ssi, sei = hparams(sanitized_params, {})
writer = self.experiment._get_file_writer()
writer.add_summary(exp)
writer.add_summary(ssi)
writer.add_summary(sei)

# some alternative should be added
self.tags.update(params)
self.tags.update(sanitized_params)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
Expand Down Expand Up @@ -173,3 +175,15 @@ def _get_next_version(self):
return 0

return max(existing_versions) + 1

def _sanitize_params(self, params):
Borda marked this conversation as resolved.
Show resolved Hide resolved
native_types = [int, bool, float, str, torch.Tensor]
Borda marked this conversation as resolved.
Show resolved Hide resolved
out_dict = {}

for k, v in params.items():
if type(v) not in native_types:
out_dict[k] = repr(v)
else:
out_dict[k] = v
monney marked this conversation as resolved.
Show resolved Hide resolved

return out_dict
6 changes: 5 additions & 1 deletion tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
TensorBoardLogger
)
from tests.models import LightningTestModel
from argparse import Namespace


def test_tensorboard_logger(tmpdir):
Expand Down Expand Up @@ -108,6 +109,9 @@ def test_tensorboard_log_hyperparams(tmpdir):
"float": 0.3,
"int": 1,
"string": "abc",
"bool": True
"bool": True,
"list": [1, 2, 3],
"namespace": Namespace(foo=3),
"layer": torch.nn.BatchNorm1d
}
logger.log_hyperparams(hparams)