Skip to content

Commit

Permalink
test device wip
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed May 19, 2020
1 parent c2eca3b commit b022f0d
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions tests/core/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import pytest

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.memory import ModelSummary
Expand All @@ -11,7 +12,13 @@
# Device (CPU, GPU, amp)
# Different input shapes (tensor, nested lists, nested tuples, unknowns)

def test_linear_model_summary_shapes():

@pytest.mark.parametrize('device', [
torch.device('cpu'),
torch.device('cuda', 0)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_linear_model_summary_shapes(device):
""" Test that the model summary correctly computes the input- and output shapes. """

class CurrentModel(LightningModule):
Expand All @@ -34,7 +41,8 @@ def forward(self, x, y):
out = self.combine(out)
return out

model = CurrentModel()
model = CurrentModel().to(device)
model.train()
summary = ModelSummary(model)
assert summary.in_sizes == [
[2, 10], # layer 2
Expand All @@ -50,6 +58,7 @@ def forward(self, x, y):
[2, 7], # relu
'unknown'
]
assert model.training


def test_rnn_summary_shapes():
Expand Down

0 comments on commit b022f0d

Please sign in to comment.