diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index c251983c6aac1..d9e2500707fc8 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -277,11 +277,11 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, results = trainer.fit(model) assert results == 1 - assert trainer.num_training_batches == 0 if limit_train_batches == 0.0 else float('inf') - assert trainer.num_val_batches[0] == 0 if limit_val_batches == 0.0 else float('inf') + assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf')) + assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf')) trainer.test(ckpt_path=None) - assert trainer.num_test_batches[0] == 0 if limit_test_batches == 0.0 else float('inf') + assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf')) @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [