Skip to content

Commit

Permalink
add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Aug 19, 2020
1 parent 6e084d1 commit 4bee960
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,9 +789,42 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
pytest.param(0.0), # this should run no sanity checks
pytest.param(1),
pytest.param(1.0),
pytest.param(0.3),
pytest.param(0.5),
pytest.param(5),
])
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
"""
Test that num_sanity_val_steps!=-1 runs through all validation data once.
Makes sure the number of sanity check batches is clipped to limit_val_batches.
"""
model = EvalModelTemplate()
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
num_sanity_val_steps = 4

trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=num_sanity_val_steps,
limit_val_batches=limit_val_batches, # should have no influence
max_steps=1,
)
assert trainer.num_sanity_val_steps == num_sanity_val_steps
val_dataloaders = model.val_dataloader__multiple_mixed_length()

with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
trainer.fit(model, val_dataloaders=val_dataloaders)
assert mocked.call_count == sum(
min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
)


@pytest.mark.parametrize(['limit_val_batches'], [
pytest.param(0.0), # this should run no sanity checks
pytest.param(1),
pytest.param(1.0),
pytest.param(0.3),
])
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
"""
Test that num_sanity_val_steps=-1 runs through all validation data once.
Makes sure the number of sanity check batches is clipped to limit_val_batches.
Expand All @@ -810,10 +843,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):

with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
trainer.fit(model, val_dataloaders=val_dataloaders)
if isinstance(limit_val_batches, float):
assert mocked.call_count == sum(len(dl) * limit_val_batches for dl in val_dataloaders)
if isinstance(limit_val_batches, int):
assert mocked.call_count == sum(limit_val_batches for dl in val_dataloaders)
assert mocked.call_count == sum(trainer.num_val_batches)


@pytest.mark.parametrize("trainer_kwargs,expected", [
Expand Down

0 comments on commit 4bee960

Please sign in to comment.