Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
devashishshankar committed May 30, 2020
1 parent 98aa87b commit 84175f6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/models/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_horovod_cpu(tmpdir):
train_percent_check=0.4,
val_percent_check=0.2,
distributed_backend='horovod',
deterministic=True
deterministic=True,
)
_run_horovod(trainer_options)

Expand All @@ -80,7 +80,7 @@ def test_horovod_cpu_implicit(tmpdir):
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
deterministic=True
deterministic=True,
)
_run_horovod(trainer_options)

Expand Down
9 changes: 7 additions & 2 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,18 @@ class CustomDummyObj:
assert isinstance(result, CustomDataLoader)
assert hasattr(result, 'dummy_kwarg')

# Shuffled DataLoader should also work
result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
assert isinstance(result, torch.utils.data.DataLoader)
assert isinstance(result, CustomDataLoader)
assert hasattr(result, 'dummy_kwarg')

class CustomSampler(torch.utils.data.Sampler):
pass

# Should raise an error if existing sampler is being replaced
with pytest.raises(MisconfigurationException, match='DistributedSampler'):
trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), sampler=CustomSampler()),
trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))),
train=True)


Expand Down

0 comments on commit 84175f6

Please sign in to comment.