From e2422906096917378ad88677d505eda8ffff1219 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 5 Aug 2020 21:39:28 +0200 Subject: [PATCH] tests --- pytorch_lightning/core/saving.py | 2 +- tests/loggers/test_csv.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 5e3ef1d97236d7..37c63de1804ab1 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -313,7 +313,7 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: return {} with open(config_yaml) as fp: - tags = yaml.load(fp, Loader=yaml.SafeLoader) + tags = yaml.load(fp) return tags diff --git a/tests/loggers/test_csv.py b/tests/loggers/test_csv.py index f2e8ab85fb9e56..3bc8330075e6a1 100644 --- a/tests/loggers/test_csv.py +++ b/tests/loggers/test_csv.py @@ -70,9 +70,11 @@ def test_file_logger_log_metrics(tmpdir, step_idx): logger.log_metrics(metrics, step_idx) logger.save() - path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE) - params = load_hparams_from_yaml(path_yaml) - assert all([n in params for n in metrics]) + path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE) + with open(path_csv, 'r') as fp: + lines = fp.readlines() + assert len(lines) == 2 + assert all([n in lines[0] for n in metrics]) def test_file_logger_log_hyperparams(tmpdir): @@ -89,3 +91,7 @@ def test_file_logger_log_hyperparams(tmpdir): } logger.log_hyperparams(hparams) logger.save() + + path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE) + params = load_hparams_from_yaml(path_yaml) + assert all([n in params for n in hparams])