Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed May 25, 2020
1 parent e2b97a8 commit aeee807
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/core/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ def test_auto_move_data(src_device, dest_device):
""" Test that the decorator moves the data to the device the model is on. """

class CurrentModel(EvalModelTemplate):
pass

# @auto_move_data
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
# apply the decorator
CurrentModel.forward = auto_move_data(CurrentModel.forward)

model = CurrentModel().to(dest_device)
# setattr(model, 'forward', auto_move_data(model.forward))
model.forward = auto_move_data(model.forward) # apply the decorator
model = CurrentModel()
model = model.to(dest_device)
model.prepare_data()
loader = model.train_dataloader()

x, y, = next(iter(loader))
x = x.flatten(1)

# test that data on source device gets moved to destination device
x = x.to(src_device)
assert model(x).device == dest_device, "Automoving data to same device as model failed"

0 comments on commit aeee807

Please sign in to comment.