add and update tests
rohitgr7 committed Aug 2, 2020
1 parent 730bc4e commit acbe573
Showing 1 changed file with 67 additions and 8 deletions.
75 changes: 67 additions & 8 deletions tests/trainer/test_dataloaders.py
@@ -255,6 +255,65 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


@pytest.mark.parametrize(
    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
    [
        pytest.param(0.0, 0.0, 0.0),
        pytest.param(1.0, 1.0, 1.0),
    ]
)
def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__infinite
    model.val_dataloader = model.val_dataloader__infinite
    model.test_dataloader = model.test_dataloader__infinite

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )

    results = trainer.fit(model)
    assert results == 1
    assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf'))
    assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf'))

    trainer.test(ckpt_path=None)
    assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf'))

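Side note, not part of the diff: the `*__infinite` helpers on EvalModelTemplate are defined elsewhere in the test suite. Below is a minimal sketch of what such a loader could look like, assuming a plain `torch.utils.data.IterableDataset` without `__len__`, which is why the batch counts above come back as `float('inf')`; the names here are illustrative, not the actual template code.

import torch
from torch.utils.data import DataLoader, IterableDataset


class InfiniteDataset(IterableDataset):
    """Yields random samples forever, so the resulting DataLoader has no length."""

    def __iter__(self):
        while True:
            yield torch.rand(32)


def train_dataloader__infinite():
    # The dataset defines no __len__, so the number of batches is unknown and
    # Lightning reports it as float('inf'), which is what the test asserts above.
    return DataLoader(InfiniteDataset(), batch_size=16)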

@pytest.mark.parametrize(
    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
    [pytest.param(10, 10, 10)]
)
def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__infinite
    model.val_dataloader = model.val_dataloader__infinite
    model.test_dataloader = model.test_dataloader__infinite

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )

    results = trainer.fit(model)
    assert results
    assert trainer.num_training_batches == limit_train_batches
    assert trainer.num_val_batches[0] == limit_val_batches

    trainer.test(ckpt_path=None)
    assert trainer.num_test_batches[0] == limit_test_batches

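A rough illustration of the integer-limit behaviour asserted above (only a sketch of the idea, not the Trainer's actual loop): capping an otherwise endless loader at, say, `limit_train_batches=10` amounts to slicing its iterator.

import itertools


def run_capped_epoch(dataloader, limit_batches):
    """Consume at most `limit_batches` batches from a possibly infinite loader."""
    seen = 0
    for _batch in itertools.islice(dataloader, limit_batches):
        seen += 1  # a real training loop would run training_step() here
    return seen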

@pytest.mark.parametrize(
    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
    [
@@ -265,7 +324,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
    ]
)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
    """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -306,7 +365,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    ]
)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Verify num_batches for val & test dataloaders passed with batch limit as number"""
    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
    os.environ['PL_DEV_DEBUG'] = '1'

    model = EvalModelTemplate()
@@ -435,7 +494,7 @@ def test_train_inf_dataloader_error(tmpdir):

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.fit(model)


@@ -446,7 +505,7 @@ def test_val_inf_dataloader_error(tmpdir):

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.fit(model)


@@ -457,7 +516,7 @@ def test_test_inf_dataloader_error(tmpdir):

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.test(model)

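The three error tests above (and the `*_not_implemented_error_failed` tests further down) now match the reworded message, which mentions 'using an IterableDataset' rather than 'infinite DataLoader'. As a hedged sketch of the kind of guard that raises here, assuming the check sits on the Trainer side and that the real wording differs from this approximation: a fractional limit or `val_check_interval` cannot be applied to a loader whose length is unknown.

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def check_batch_limit(has_len, limit_batches, name='limit_val_batches'):
    # Illustrative guard only: percentages require a sized loader. An
    # IterableDataset without __len__ can only be limited to 0.0, 1.0 or an
    # explicit number of batches. (Wording below is approximate; the tests
    # only match the substring 'using an IterableDataset'.)
    if not has_len and isinstance(limit_batches, float) and limit_batches not in (0.0, 1.0):
        raise MisconfigurationException(
            f'When using an IterableDataset, `{name}` must be 0.0, 1.0 or an int.'
        )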

@@ -737,7 +796,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.fit(model)


@@ -748,7 +807,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.fit(model)


@@ -759,5 +818,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.test(model)
