From 84175f6f77536ed6dab5bfe3817c03cebf2afad9 Mon Sep 17 00:00:00 2001 From: Devashish Shankar Date: Sun, 31 May 2020 02:12:36 +0530 Subject: [PATCH] Fix tests --- tests/models/test_horovod.py | 4 ++-- tests/trainer/test_dataloaders.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 9dadefd116a48a..4e5fe0ef815529 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -64,7 +64,7 @@ def test_horovod_cpu(tmpdir): train_percent_check=0.4, val_percent_check=0.2, distributed_backend='horovod', - deterministic=True + deterministic=True, ) _run_horovod(trainer_options) @@ -80,7 +80,7 @@ def test_horovod_cpu_implicit(tmpdir): max_epochs=1, train_percent_check=0.4, val_percent_check=0.2, - deterministic=True + deterministic=True, ) _run_horovod(trainer_options) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d4e2058fcd717d..e897e959b3ee50 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -393,13 +393,18 @@ class CustomDummyObj: assert isinstance(result, CustomDataLoader) assert hasattr(result, 'dummy_kwarg') + # Shuffled DataLoader should also work + result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True) + assert isinstance(result, torch.utils.data.DataLoader) + assert isinstance(result, CustomDataLoader) + assert hasattr(result, 'dummy_kwarg') + class CustomSampler(torch.utils.data.Sampler): pass # Should raise an error if existing sampler is being replaced with pytest.raises(MisconfigurationException, match='DistributedSampler'): - trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True) - trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), sampler=CustomSampler()), + trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)