From 2a849c3f5b9ee7bd2d4e7115cd9268179d7316de Mon Sep 17 00:00:00 2001 From: Devashish Shankar Date: Sun, 31 May 2020 02:12:36 +0530 Subject: [PATCH] Fix tests --- tests/models/test_horovod.py | 4 ++-- tests/trainer/test_dataloaders.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 9dadefd116a48..4e5fe0ef81552 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -64,7 +64,7 @@ def test_horovod_cpu(tmpdir): train_percent_check=0.4, val_percent_check=0.2, distributed_backend='horovod', - deterministic=True + deterministic=True, ) _run_horovod(trainer_options) @@ -80,7 +80,7 @@ def test_horovod_cpu_implicit(tmpdir): max_epochs=1, train_percent_check=0.4, val_percent_check=0.2, - deterministic=True + deterministic=True, ) _run_horovod(trainer_options) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d4e2058fcd717..e8b8008879289 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -393,14 +393,20 @@ class CustomDummyObj: assert isinstance(result, CustomDataLoader) assert hasattr(result, 'dummy_kwarg') + # Shuffled DataLoader should also work + result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True) + assert isinstance(result, torch.utils.data.DataLoader) + assert isinstance(result, CustomDataLoader) + assert hasattr(result, 'dummy_kwarg') + class CustomSampler(torch.utils.data.Sampler): pass # Should raise an error if existing sampler is being replaced with pytest.raises(MisconfigurationException, match='DistributedSampler'): - trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True) - trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), sampler=CustomSampler()), - train=True) + trainer.auto_add_sampler( + CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), + train=True) @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')