device property
Borda committed May 12, 2020
1 parent 7b60d49 commit 57a9b75
Showing 4 changed files with 16 additions and 8 deletions.
6 changes: 5 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -73,7 +73,11 @@ def __init__(self, *args, **kwargs):
         self.hparams = None

         #: device reference
-        self.device = None
+        self._device = None
+
+    @property
+    def device(self) -> Union[None, str, object]:
+        return self._device

     def print(self, *args, **kwargs) -> None:
         r"""
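For orientation, here is a minimal, self-contained sketch of the pattern this hunk introduces: the public `device` attribute becomes a read-only property backed by a private `_device` field that only framework internals assign. The `DeviceAware` class below is an illustrative stand-in, not the actual LightningModule code.

import torch


class DeviceAware:
    """Illustrative stand-in for the read-only ``device`` pattern above."""

    def __init__(self):
        #: device reference, written only by framework internals
        self._device = None

    @property
    def device(self):
        """Read-only view of the device this object currently targets."""
        return self._device


obj = DeviceAware()
obj._device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
print(obj.device)      # reads the private backing field through the property
# obj.device = 'cpu'   # would raise AttributeError: can't set attribute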
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -344,7 +344,7 @@ def ddp_train(self, process_idx, model):
         # copy model to each gpu
         if self.on_gpu:
             self.root_gpu = process_idx
-            self.device = torch.device('cuda', self.root_gpu)
+            self._device = torch.device('cuda', self.root_gpu)
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)

10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -432,7 +432,7 @@ def copy_trainer_model_properties(self, model):
             m.use_tpu = self.use_tpu
             m.tpu_local_core_rank = self.tpu_local_core_rank
             m.tpu_global_core_rank = self.tpu_global_core_rank
-            m.device = self.device
+            m._device = self._device

     def transfer_batch_to_tpu(self, batch):
         return self.__transfer_data_to_device(batch, device='tpu')
@@ -484,7 +484,7 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):

     def single_gpu_train(self, model):
         model.cuda(self.root_gpu)
-        self.device = torch.device('cuda', self.root_gpu)
+        self._device = torch.device('cuda', self.root_gpu)

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -501,7 +501,7 @@ def single_gpu_train(self, model):
     def tpu_train(self, tpu_core_idx, model):
         # put model on tpu
         model.to(xm.xla_device())
-        self.device = xm.xla_device()
+        self._device = xm.xla_device()

         # get the appropriate tpu ranks
         self.tpu_local_core_rank = xm.get_local_ordinal()
@@ -539,7 +539,7 @@ def dp_train(self, model):
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         model.cuda(self.root_gpu)
-        self.device = torch.device('cuda', self.root_gpu)
+        self._device = torch.device('cuda', self.root_gpu)

         # hack forward to do autocast for the user
         model_autocast_original_forward = model.forward
@@ -579,7 +579,7 @@ def horovod_train(self, model):
             assert self.root_gpu == hvd.local_rank()
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)
-            self.device = torch.device('cuda', self.root_gpu)
+            self._device = torch.device('cuda', self.root_gpu)

         # avoid duplicating progress bar
         if hvd.rank() != 0 and self.progress_bar_callback is not None:
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -463,7 +463,7 @@ def __init__(
         # distributed backend choice
         self.distributed_backend = distributed_backend
         self.set_distributed_mode(distributed_backend)
-        self.device = torch.device('cpu')
+        self._device = torch.device('cpu')

         # override dist backend when using tpus
         if self.on_tpu:
@@ -519,6 +519,10 @@ def __init__(
         # Callback system
         self.on_init_end()

+    @property
+    def device(self) -> Union[None, str, object]:
+        return self._device
+
     @property
     def slurm_job_id(self) -> int:
         try:
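Taken together, the Trainer hunks follow the same scheme: accelerator-specific setup paths write `self._device`, and callers read it through the new `device` property. The sketch below illustrates that flow with a hypothetical `MyTrainer` class and a made-up `_single_gpu_setup` helper; it is not Lightning's actual Trainer API.

import torch


class MyTrainer:
    """Hypothetical stand-in for the Trainer changes above, not the real class."""

    def __init__(self):
        self._device = torch.device('cpu')  # CPU default, as in the __init__ hunk

    @property
    def device(self):
        return self._device

    def _single_gpu_setup(self, root_gpu):
        # setup code writes the private field directly, as the hunks above do
        self._device = torch.device('cuda', root_gpu)


trainer = MyTrainer()
if torch.cuda.is_available():
    trainer._single_gpu_setup(0)
print(trainer.device)  # callers read the device only through the property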
