From c7e44934ac8219b89a3214798c7a584886a6a8d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 May 2020 05:20:01 +0200 Subject: [PATCH] Pass tpu_id to xm.xla_device when transferring batches to TPU --- pytorch_lightning/trainer/distrib_parts.py | 4 ++-- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 77e8139258ba8..335eabeb87359 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -432,8 +432,8 @@ def copy_trainer_model_properties(self, model): m.tpu_local_core_rank = self.tpu_local_core_rank m.tpu_global_core_rank = self.tpu_global_core_rank - def transfer_batch_to_tpu(self, batch: Any): - device = xm.xla_device() if XLA_AVAILABLE else torch.device('cpu') + def transfer_batch_to_tpu(self, batch: Any, tpu_id: int = None): + device = xm.xla_device(tpu_id) if XLA_AVAILABLE else torch.device('cpu') return self.__transfer_data_to_device(batch, device) def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 58d16632b3c57..22527ef8e1d64 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -434,7 +434,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: # TPU data transfer if self.use_tpu: - batch = self.transfer_batch_to_tpu(batch) + batch = self.transfer_batch_to_tpu(batch, self.tpu_id) args[0] = batch # CPU, TPU or gpu step diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index cbe1186480c28..47e2045d36bdb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -729,7 +729,7 @@ def training_forward(self, batch, batch_idx, opt_idx, 
hiddens): # TPU support elif self.use_tpu: - batch = self.transfer_batch_to_tpu(batch) + batch = self.transfer_batch_to_tpu(batch, self.tpu_id) args[0] = batch output = self.model.training_step(*args)