Skip to content

Commit

Permalink
Tensorboard path generalisation (#804)
Browse files Browse the repository at this point in the history
* Allow experiment versions to be overridden by passing a string value.
Allow experiment names to be empty, in which case no per-experiment subdirectory will be created and checkpoints will be saved in the directory given by the save_dir parameter.

* Document tensorboard api changes

* Review comment fixes plus fixed test failure for minimum requirements build

* More format fixes from review
  • Loading branch information
bobkemp committed Feb 10, 2020
1 parent fc0ad03 commit 8fa802e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
42 changes: 33 additions & 9 deletions pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ class TensorBoardLogger(LightningLoggerBase):
Args:
save_dir (str): Save directory
name (str): Experiment name. Defaults to "default".
version (int): Experiment version. If version is not specified the logger inspects the save
directory for existing versions, then automatically assigns the next available version.
name (str): Experiment name. Defaults to "default". If it is the empty string then no per-experiment
subdirectory is used.
version (int|str): Experiment version. If version is not specified the logger inspects the save
directory for existing versions, then automatically assigns the next available version.
If it is a string then it is used as the run-specific subdirectory name,
otherwise version_${version} is used.
\**kwargs (dict): Other arguments are passed directly to the :class:`SummaryWriter` constructor.
"""
Expand All @@ -47,6 +50,30 @@ def __init__(self, save_dir, name="default", version=None, **kwargs):
self.tags = {}
self.kwargs = kwargs

@property
def root_dir(self):
"""
Parent directory for all tensorboard checkpoint subdirectories.
If the experiment name parameter is None or the empty string, no experiment subdirectory is used
and checkpoint will be saved in save_dir/version_dir
"""
if self.name is None or len(self.name) == 0:
return self.save_dir
else:
return os.path.join(self.save_dir, self.name)

@property
def log_dir(self):
"""
The directory for this run's tensorboard checkpoint. By default, it is named 'version_${self.version}'
but it can be overridden by passing a string value for the constructor's version parameter
instead of None or an int
"""
# create a pseudo standard path ala test-tube
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
log_dir = os.path.join(self.root_dir, version)
return log_dir

@property
def experiment(self):
r"""
Expand All @@ -61,10 +88,8 @@ def experiment(self):
if self._experiment is not None:
return self._experiment

root_dir = os.path.join(self.save_dir, self.name)
os.makedirs(root_dir, exist_ok=True)
log_dir = os.path.join(root_dir, "version_" + str(self.version))
self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs)
os.makedirs(self.root_dir, exist_ok=True)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self.kwargs)
return self._experiment

@rank_zero_only
Expand Down Expand Up @@ -108,8 +133,7 @@ def save(self):
# you are using PT version (<v1.2) which does not have implemented flush
self.experiment._get_file_writer().flush()

# create a preudo standard path ala test-tube
dir_path = os.path.join(self.save_dir, self.name, 'version_%s' % self.version)
dir_path = self.log_dir
if not os.path.isdir(dir_path):
dir_path = self.save_dir

Expand Down
13 changes: 13 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,19 @@ def test_tensorboard_manual_versioning(tmpdir):
assert logger.version == 1


def test_tensorboard_named_version(tmpdir):
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402' """

tmpdir.mkdir("tb_versioning")
expected_version = "2020-02-05-162402"

logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=expected_version)
logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written

assert logger.version == expected_version
# Could also test existence of the directory but this fails in the "minimum requirements" test setup


@pytest.mark.parametrize("step_idx", [10, None])
def test_tensorboard_log_metrics(tmpdir, step_idx):
logger = TensorBoardLogger(tmpdir)
Expand Down

0 comments on commit 8fa802e

Please sign in to comment.