Skip to content

Commit

Permalink
added more dm tests
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Jul 30, 2020
1 parent 20ec0af commit 9c10e67
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 31 deletions.
3 changes: 3 additions & 0 deletions tests/base/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class TrialMNISTDataModule(LightningDataModule):
def __init__(self, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.non_picklable = None

def prepare_data(self):
TrialMNIST(self.data_dir, train=True, download=True)
Expand All @@ -25,6 +26,8 @@ def setup(self, stage: str = None):
self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)

self.non_picklable = lambda x: x**2

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)

Expand Down
31 changes: 0 additions & 31 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,34 +279,3 @@ def test_full_loop_ddp_spawn(tmpdir):
result = trainer.test(datamodule=dm)
result = result[0]
assert result['test_acc'] > 0.8


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_full_loop_ddp_spawn_non_picklable(tmpdir):
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

reset_seed()

dm = TrialMNISTDataModule(tmpdir)
dm.non_pickle_thing = lambda x: x**2

model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
weights_summary=None,
distributed_backend='ddp_spawn',
gpus=[0, 1]
)
trainer.fit(model, dm)

# fit model
result = trainer.fit(model)
assert result == 1

# test
result = trainer.test(datamodule=dm)
result = result[0]
assert result['test_acc'] > 0.8

0 comments on commit 9c10e67

Please sign in to comment.