From 610321d92696afae00e71039156c6789d163ada9 Mon Sep 17 00:00:00 2001 From: dfhkjdfhjdf Date: Thu, 12 Mar 2020 02:02:41 +0200 Subject: [PATCH 1/2] Add TrainsLogger test --- tests/loggers/test_trains.py | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/loggers/test_trains.py diff --git a/tests/loggers/test_trains.py b/tests/loggers/test_trains.py new file mode 100644 index 00000000000000..1c8ca4167462a4 --- /dev/null +++ b/tests/loggers/test_trains.py @@ -0,0 +1,48 @@ +import pickle + +import tests.models.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import TrainsLogger +from tests.models import LightningTestModel + + +def test_trains_logger(tmpdir): + """Verify that basic functionality of TRAINS logger works.""" + tutils.reset_seed() + + hparams = tutils.get_hparams() + model = LightningTestModel(hparams) + logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test") + + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + train_percent_check=0.05, + logger=logger + ) + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + print('result finished') + assert result == 1, "Training failed" + + +def test_trains_pickle(tmpdir): + """Verify that pickling trainer with TRAINS logger works.""" + tutils.reset_seed() + + # hparams = tutils.get_hparams() + # model = LightningTestModel(hparams) + + logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test") + + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + logger=logger + ) + + trainer = Trainer(**trainer_options) + pkl_bytes = pickle.dumps(trainer) + trainer2 = pickle.loads(pkl_bytes) + trainer2.logger.log_metrics({"acc": 1.0}) From bb178a12c3991c7e8c2a1849de4fff73c03803e3 Mon Sep 17 00:00:00 2001 From: dfhkjdfhjdf Date: Thu, 12 Mar 2020 02:03:28 +0200 Subject: [PATCH 2/2] Add TrainsLogger PR in CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 985accb595374a..b352c2bccf04f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122)) ### Changed