From 35c235d153d38514c3805d7d3a85fc7cf4063679 Mon Sep 17 00:00:00 2001
From: xmotli02
Date: Sun, 26 Jul 2020 22:42:14 +0200
Subject: [PATCH] Added basic file logger #1803

---
 docs/source/loggers.rst                  |   6 +
 pytorch_lightning/loggers/__init__.py    |   2 +
 pytorch_lightning/loggers/file_logger.py | 180 +++++++++++++++++++++++
 tests/loggers/test_file_logger.py        |  85 +++++++++++
 4 files changed, 273 insertions(+)
 create mode 100644 pytorch_lightning/loggers/file_logger.py
 create mode 100644 tests/loggers/test_file_logger.py

diff --git a/docs/source/loggers.rst b/docs/source/loggers.rst
index eb94b52c21f172..02d54da1a92074 100644
--- a/docs/source/loggers.rst
+++ b/docs/source/loggers.rst
@@ -138,4 +138,10 @@ Test-tube
 ^^^^^^^^^
 
 .. autoclass:: pytorch_lightning.loggers.test_tube.TestTubeLogger
+    :noindex:
+
+FileLogger
+^^^^^^^^^^
+
+.. autoclass:: pytorch_lightning.loggers.file_logger.FileLogger
     :noindex:
\ No newline at end of file
diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py
index daa2b99bb80c6f..d2e713ba2c0120 100644
--- a/pytorch_lightning/loggers/__init__.py
+++ b/pytorch_lightning/loggers/__init__.py
@@ -2,6 +2,8 @@
 from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 
+from pytorch_lightning.loggers.file_logger import FileLogger
+
 __all__ = [
     'LightningLoggerBase',
diff --git a/pytorch_lightning/loggers/file_logger.py b/pytorch_lightning/loggers/file_logger.py
new file mode 100644
index 00000000000000..adb91054668673
--- /dev/null
+++ b/pytorch_lightning/loggers/file_logger.py
@@ -0,0 +1,180 @@
+"""
+File logger
+-----------
+"""
+import csv
+import io
+import os
+from argparse import Namespace
+from typing import Optional, Dict, Any, Union
+
+import torch
+
+from pytorch_lightning import _logger as log
+from pytorch_lightning.core.saving import save_hparams_to_yaml
+from pytorch_lightning.loggers.base import LightningLoggerBase
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+
+class ExperimentWriter(object):
+    r"""
+    Experiment writer for FileLogger.
+
+    Records hyperparameters in YAML format and metrics in CSV format.
+
+    Args:
+        log_dir: Directory for the experiment logs
+    """
+
+    NAME_HPARAMS_FILE = 'hparams.yaml'
+    NAME_METRICS_FILE = 'metrics.csv'
+
+    def __init__(self, log_dir: str) -> None:
+        self.hparams = {}
+        self.metrics = []
+        self.metrics_keys = ['step']
+
+        self.log_dir = log_dir
+        os.makedirs(self.log_dir, exist_ok=True)
+
+        self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
+
+    def log_hparams(self, params: Dict[str, Any]) -> None:
+        """Record hparams"""
+        self.hparams.update(params)
+
+    def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None:
+        """Record metrics"""
+        def _handle_value(value):
+            if isinstance(value, torch.Tensor):
+                return value.item()
+            return value
+
+        if step is None:
+            step = len(self.metrics)
+
+        # one row per call; track every key ever seen so the CSV header
+        # covers the union of all metric names
+        new_row = dict.fromkeys(self.metrics_keys)
+        new_row['step'] = step
+        for k, v in metrics_dict.items():
+            if k not in self.metrics_keys:
+                self.metrics_keys.append(k)
+            new_row[k] = _handle_value(v)
+        self.metrics.append(new_row)
+
+    def save(self) -> None:
+        """Save recorded hparams and metrics to files"""
+        hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
+        save_hparams_to_yaml(hparams_file, self.hparams)
+
+        if self.metrics:
+            with io.open(self.metrics_file_path, 'w', newline='') as f:
+                writer = csv.DictWriter(f, fieldnames=self.metrics_keys)
+                writer.writeheader()
+                writer.writerows(self.metrics)
+
+
+class FileLogger(LightningLoggerBase):
+    r"""
+    Log to the local file system in YAML and CSV format. Hyperparameters are
+    saved to ``hparams.yaml`` and metrics to ``metrics.csv``. Logs are saved to
+    ``os.path.join(save_dir, name, version)``.
+
+    Example:
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.loggers import FileLogger
+        >>> logger = FileLogger("logs", name="my_exp_name")
+        >>> trainer = Trainer(logger=logger)
+
+    Args:
+        save_dir: Save directory
+        name: Experiment name. Defaults to ``'default'``.
+        version: Experiment version. If version is not specified, the logger inspects the save
+            directory for existing versions, then automatically assigns the next available version.
+    """
+
+    def __init__(self,
+                 save_dir: str,
+                 name: Optional[str] = "default",
+                 version: Optional[Union[int, str]] = None):
+        super().__init__()
+        self._save_dir = save_dir
+        self._name = name or ''
+        self._version = version
+        self._experiment = None
+
+    @property
+    def root_dir(self) -> str:
+        """
+        Parent directory for all checkpoint subdirectories.
+        If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory
+        is used and the checkpoint will be saved in ``save_dir/version_dir``.
+        """
+        if not self.name:
+            return self._save_dir
+        return os.path.join(self._save_dir, self.name)
+
+    @property
+    def log_dir(self) -> str:
+        """
+        The log directory for this run. By default, it is named
+        ``'version_${self.version}'`` but it can be overridden by passing a string value
+        for the constructor's version parameter instead of ``None`` or an int.
+        """
+        # create a pseudo standard path ala test-tube
+        version = self.version if isinstance(self.version, str) else f"version_{self.version}"
+        log_dir = os.path.join(self.root_dir, version)
+        return log_dir
+
+    @property
+    def experiment(self) -> ExperimentWriter:
+        r"""
+        Actual ExperimentWriter object. To use ExperimentWriter features in your
+        :class:`~pytorch_lightning.core.lightning.LightningModule`, do the following.
+
+        Example::
+
+            self.logger.experiment.some_experiment_writer_function()
+
+        """
+        if self._experiment is not None:
+            return self._experiment
+
+        os.makedirs(self.root_dir, exist_ok=True)
+        self._experiment = ExperimentWriter(log_dir=self.log_dir)
+        return self._experiment
+
+    @rank_zero_only
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+        params = self._convert_params(params)
+        self.experiment.log_hparams(params)
+
+    @rank_zero_only
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+        self.experiment.log_metrics(metrics, step)
+
+    @rank_zero_only
+    def save(self) -> None:
+        super().save()
+        self.experiment.save()
+
+    @rank_zero_only
+    def finalize(self, status: str) -> None:
+        self.save()
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def version(self) -> Union[int, str]:
+        if self._version is None:
+            self._version = self._get_next_version()
+        return self._version
+
+    def _get_next_version(self):
+        root_dir = os.path.join(self._save_dir, self.name)
+
+        if not os.path.isdir(root_dir):
+            log.warning('Missing logger folder: %s', root_dir)
+            return 0
+
+        existing_versions = []
+        for d in os.listdir(root_dir):
+            # only directories named 'version_<int>' count as existing versions
+            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+                existing_versions.append(int(d.split("_")[1]))
+
+        if len(existing_versions) == 0:
+            return 0
+
+        return max(existing_versions) + 1
diff --git a/tests/loggers/test_file_logger.py b/tests/loggers/test_file_logger.py
new file mode 100644
index 00000000000000..517d1238f844a7
--- /dev/null
+++ b/tests/loggers/test_file_logger.py
@@ -0,0 +1,85 @@
+import os
+from argparse import Namespace
+
+import pytest
+import torch
+
+from pytorch_lightning.loggers import FileLogger
+
+
+def test_file_logger_automatic_versioning(tmpdir):
+    """Verify that automatic versioning works."""
+    root_dir = tmpdir.mkdir("exp")
+    root_dir.mkdir("version_0")
+    root_dir.mkdir("version_1")
+
+    logger = FileLogger(save_dir=tmpdir, name="exp")
+
+    assert logger.version == 2
+
+
+def test_file_logger_manual_versioning(tmpdir):
+    """Verify that manual versioning works."""
+    root_dir = tmpdir.mkdir("exp")
+    root_dir.mkdir("version_0")
+    root_dir.mkdir("version_1")
+    root_dir.mkdir("version_2")
+
+    logger = FileLogger(save_dir=tmpdir, name="exp", version=1)
+
+    assert logger.version == 1
+
+
+def test_file_logger_named_version(tmpdir):
+    """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
+    exp_name = "exp"
+    tmpdir.mkdir(exp_name)
+    expected_version = "2020-02-05-162402"
+
+    logger = FileLogger(save_dir=tmpdir, name=exp_name, version=expected_version)
+    logger.log_hyperparams({"a": 1, "b": 2})
+    logger.save()
+
+    assert logger.version == expected_version
+    assert os.listdir(tmpdir / exp_name) == [expected_version]
+    assert os.listdir(tmpdir / exp_name / expected_version)
+
+
+@pytest.mark.parametrize("name", ['', None])
+def test_file_logger_no_name(tmpdir, name):
+    """Verify that None or empty name works."""
+    logger = FileLogger(save_dir=tmpdir, name=name)
+    logger.save()
+    assert logger.root_dir == tmpdir
+    assert os.listdir(tmpdir / 'version_0')
+
+
+@pytest.mark.parametrize("step_idx", [10, None])
+def test_file_logger_log_metrics(tmpdir, step_idx):
+    logger = FileLogger(tmpdir)
+    metrics = {
+        "float": 0.3,
+        "int": 1,
+        "FloatTensor": torch.tensor(0.1),
+        "IntTensor": torch.tensor(1)
+    }
+    logger.log_metrics(metrics, step_idx)
+    logger.save()
+
+
+def test_file_logger_log_hyperparams(tmpdir):
+    logger = FileLogger(tmpdir)
+    hparams = {
+        "float": 0.3,
+        "int": 1,
+        "string": "abc",
+        "bool": True,
+        "dict": {'a': {'b': 'c'}},
+        "list": [1, 2, 3],
+        "namespace": Namespace(foo=Namespace(bar='buzz')),
+        "layer": torch.nn.BatchNorm1d
+    }
+    logger.log_hyperparams(hparams)
+    logger.save()
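
A minimal usage sketch (not part of the patch itself): it assumes the diff above is applied and that ``logs/my_exp_name`` does not exist yet, so the logger starts at ``version_0``. Reading ``metrics.csv`` back with the stdlib ``csv`` module is shown for illustration only and is not an API of ``FileLogger``; the file name comes from ``ExperimentWriter.NAME_METRICS_FILE`` above::

    import csv
    import os

    from pytorch_lightning.loggers import FileLogger

    # Log hyperparameters and a couple of metric rows, then flush to disk.
    logger = FileLogger("logs", name="my_exp_name")
    logger.log_hyperparams({"lr": 0.01, "batch_size": 32})
    logger.log_metrics({"train_loss": 0.5}, step=0)
    logger.log_metrics({"train_loss": 0.3, "val_loss": 0.4}, step=1)
    logger.save()

    # ExperimentWriter emits one CSV row per log_metrics() call; the columns
    # are a leading 'step' plus the union of all metric keys, and keys missing
    # from a row are written as empty strings by csv.DictWriter.
    with open(os.path.join(logger.log_dir, "metrics.csv")) as f:
        for row in csv.DictReader(f):
            print(row["step"], row["train_loss"], row["val_loss"])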