Skip to content

Commit

Permalink
Pass tpu_id through to transfer_batch_to_tpu so xm.xla_device targets the selected TPU core
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed May 18, 2020
1 parent f62660d commit c7e4493
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,8 @@ def copy_trainer_model_properties(self, model):
m.tpu_local_core_rank = self.tpu_local_core_rank
m.tpu_global_core_rank = self.tpu_global_core_rank

def transfer_batch_to_tpu(self, batch: Any):
device = xm.xla_device() if XLA_AVAILABLE else torch.device('cpu')
def transfer_batch_to_tpu(self, batch: Any, tpu_id: int = None):
    """Transfer a batch to the TPU device (or CPU when XLA is unavailable).

    Args:
        batch: The data to move; delegated to ``__transfer_data_to_device``
            for the actual per-tensor transfer.
        tpu_id: Index of the TPU core to target, forwarded to
            ``xm.xla_device``. Defaults to ``None`` — presumably letting
            ``xm.xla_device`` select the default core; TODO confirm.
            NOTE(review): annotation is ``int`` but the default is ``None``;
            ``Optional[int]`` would be more accurate.

    Returns:
        The batch as returned by ``__transfer_data_to_device`` for the
        chosen device.
    """
    # When torch_xla is not available, fall back to a CPU device so the
    # transfer still succeeds in non-TPU environments.
    device = xm.xla_device(tpu_id) if XLA_AVAILABLE else torch.device('cpu')
    return self.__transfer_data_to_device(batch, device)

def transfer_batch_to_gpu(self, batch: Any, gpu_id: int):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:

# TPU data transfer
if self.use_tpu:
batch = self.transfer_batch_to_tpu(batch)
batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
args[0] = batch

# CPU, TPU or gpu step
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):

# TPU support
elif self.use_tpu:
batch = self.transfer_batch_to_tpu(batch)
batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
args[0] = batch
output = self.model.training_step(*args)

Expand Down

0 comments on commit c7e4493

Please sign in to comment.