diff --git a/pytorch_lightning/loggers/csv.py b/pytorch_lightning/loggers/csv.py
index a3c1e46348d140..e9aa13c127cc1b
--- a/pytorch_lightning/loggers/csv.py
+++ b/pytorch_lightning/loggers/csv.py
@@ -34,7 +34,7 @@ class ExperimentWriter(object):
     NAME_HPARAMS_FILE = 'hparams.yaml'
     NAME_METRICS_FILE = 'metrics.csv'
 
-    def __init__(self, log_dir):
+    def __init__(self, log_dir: str) -> None:
         self.hparams = {}
         self.metrics = []
         self.metrics_keys = ["step"]
@@ -49,11 +49,11 @@ def __init__(self, log_dir):
 
         self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
 
-    def log_hparams(self, params):
+    def log_hparams(self, params: Dict[str, Any]) -> None:
         """Record hparams"""
         self.hparams.update(params)
 
-    def log_metrics(self, metrics_dict, step=None):
+    def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None:
         """Record metrics"""
         def _handle_value(value):
             if isinstance(value, torch.Tensor):
@@ -71,7 +71,7 @@ def _handle_value(value):
             new_row[k] = _handle_value(v)
         self.metrics.append(new_row)
 
-    def save(self):
+    def save(self) -> None:
         """Save recorded hparams and metrics into files"""
         hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
         save_hparams_to_yaml(hparams_file, self.hparams)
@@ -135,6 +135,10 @@ def log_dir(self) -> str:
         log_dir = os.path.join(self.root_dir, version)
         return log_dir
 
+    @property
+    def save_dir(self) -> Optional[str]:
+        return self._save_dir
+
     @property
     def experiment(self) -> ExperimentWriter:
         r"""