diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 48459cd889e33..e22a5c21a6a5f 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -29,9 +29,12 @@ class TensorBoardLogger(LightningLoggerBase): Args: save_dir (str): Save directory - name (str): Experiment name. Defaults to "default". - version (int): Experiment version. If version is not specified the logger inspects the save - directory for existing versions, then automatically assigns the next available version. + name (str): Experiment name. Defaults to "default". If it is the empty string then no per-experiment + subdirectory is used. + version (int|str): Experiment version. If version is not specified the logger inspects the save + directory for existing versions, then automatically assigns the next available version. + If it is a string then it is used as the run-specific subdirectory name, + otherwise version_${version} is used. \**kwargs (dict): Other arguments are passed directly to the :class:`SummaryWriter` constructor. """ @@ -47,6 +50,30 @@ def __init__(self, save_dir, name="default", version=None, **kwargs): self.tags = {} self.kwargs = kwargs + @property + def root_dir(self): + """ + Parent directory for all tensorboard checkpoint subdirectories. + If the experiment name parameter is None or the empty string, no experiment subdirectory is used + and checkpoint will be saved in save_dir/version_dir + """ + if self.name is None or len(self.name) == 0: + return self.save_dir + else: + return os.path.join(self.save_dir, self.name) + + @property + def log_dir(self): + """ + The directory for this run's tensorboard checkpoint. By default, it is named 'version_${self.version}' + but it can be overridden by passing a string value for the constructor's version parameter + instead of None or an int + """ + # create a pseudo standard path ala test-tube + version = self.version if isinstance(self.version, str) else f"version_{self.version}" + log_dir = os.path.join(self.root_dir, version) + return log_dir + @property def experiment(self): r""" @@ -61,10 +88,8 @@ def experiment(self): if self._experiment is not None: return self._experiment - root_dir = os.path.join(self.save_dir, self.name) - os.makedirs(root_dir, exist_ok=True) - log_dir = os.path.join(root_dir, "version_" + str(self.version)) - self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs) + os.makedirs(self.root_dir, exist_ok=True) + self._experiment = SummaryWriter(log_dir=self.log_dir, **self.kwargs) return self._experiment @rank_zero_only @@ -108,8 +133,7 @@ def save(self): # you are using PT version (