From 0da220c345e1f5b60357007fa9293d2fe635825d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 23 Jun 2020 20:31:33 -0400 Subject: [PATCH] fixes slurm weights saving --- tests/trainer/test_trainer_steps.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index 8fd14e281df01..88ff4f8c7a16e 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -24,8 +24,8 @@ def test_trainingstep_dict(tmpdir): out = trainer.run_training_batch(batch, batch_idx) assert out.signal == 0 - assert out.all_log_metrics['log_acc1'] == 12.0 - assert out.all_log_metrics['log_acc2'] == 7.0 + assert out.batch_log_metrics['log_acc1'] == 12.0 + assert out.batch_log_metrics['log_acc2'] == 7.0 pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 @@ -55,8 +55,8 @@ def training_step_with_step_end(tmpdir): out = trainer.run_training_batch(batch, batch_idx) assert out.signal == 0 - assert out.all_log_metrics['log_acc1'] == 12.0 - assert out.all_log_metrics['log_acc2'] == 7.0 + assert out.batch_log_metrics['log_acc1'] == 12.0 + assert out.batch_log_metrics['log_acc2'] == 7.0 pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 @@ -91,8 +91,8 @@ def test_full_training_loop_dict(tmpdir): out = trainer.run_training_batch(batch, batch_idx) assert out.signal == 0 - assert out.all_log_metrics['log_acc1'] == 12.0 - assert out.all_log_metrics['log_acc2'] == 7.0 + assert out.batch_log_metrics['log_acc1'] == 12.0 + assert out.batch_log_metrics['log_acc2'] == 7.0 pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0 @@ -127,8 +127,8 @@ def test_train_step_epoch_end(tmpdir): out = trainer.run_training_batch(batch, batch_idx) assert out.signal == 0 - assert out.all_log_metrics['log_acc1'] == 12.0 - assert out.all_log_metrics['log_acc2'] == 7.0 + assert out.batch_log_metrics['log_acc1'] == 12.0 + assert out.batch_log_metrics['log_acc2'] == 7.0 pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end'] assert pbar_metrics['pbar_acc1'] == 17.0