Skip to content

Commit

Permalink
device
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed May 20, 2020
1 parent b022f0d commit 405fcf1
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/core/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ def forward(self, x, y):
assert model.training


def test_rnn_summary_shapes():
model = ParityRNN()
@pytest.mark.parametrize('device', [
torch.device('cpu'),
torch.device('cuda', 0)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_rnn_summary_shapes(device):
model = ParityRNN().to(device)

b = 3
t = 5
Expand Down

0 comments on commit 405fcf1

Please sign in to comment.