Skip to content

Commit

Permalink
added more dm tests
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Jul 30, 2020
1 parent bcc8720 commit 20ec0af
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,34 @@ def test_full_loop_ddp_spawn(tmpdir):
result = trainer.test(datamodule=dm)
result = result[0]
assert result['test_acc'] > 0.8


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_full_loop_ddp_spawn_non_picklable(tmpdir):
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

reset_seed()

dm = TrialMNISTDataModule(tmpdir)
dm.non_pickle_thing = lambda x: x**2

model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
weights_summary=None,
distributed_backend='ddp_spawn',
gpus=[0, 1]
)
trainer.fit(model, dm)

# fit model
result = trainer.fit(model)
assert result == 1

# test
result = trainer.test(datamodule=dm)
result = result[0]
assert result['test_acc'] > 0.8

0 comments on commit 20ec0af

Please sign in to comment.