Skip to content

Commit

Permalink
remove obsolete self._device in Trainer (#1849)
Browse files Browse the repository at this point in the history
* remove unused device attribute

* dtype

* move on_gpu to model
  • Loading branch information
awaelchli committed May 17, 2020
1 parent b84b024 commit 4cdebf9
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 13 deletions.
15 changes: 10 additions & 5 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ def __init__(self, *args, **kwargs):
self.logger = None
self.example_input_array = None

#: True if your model is currently running on GPUs.
#: Useful to set flags around the LightningModule for different CPU vs GPU behavior.
self.on_gpu = False

#: True if using dp
self.use_dp = False

Expand All @@ -72,10 +68,19 @@ def __init__(self, *args, **kwargs):
self.hparams = None

#: Current dtype
self._dtype = torch.FloatTensor
self._dtype = torch.float

#: device reference
self._device = torch.device('cpu')

@property
def on_gpu(self):
"""
True if your model is currently running on GPUs.
Useful to set flags around the LightningModule for different CPU vs GPU behavior.
"""
return self.device.type == 'cuda'

def print(self, *args, **kwargs) -> None:
r"""
Prints only from process 0. Use this in any distributed mode to log only once.
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def ddp_train(self, process_idx, model):
# copy model to each gpu
if self.on_gpu:
self.root_gpu = process_idx
self._device = torch.device('cuda', self.root_gpu)
torch.cuda.set_device(self.root_gpu)
model.cuda(self.root_gpu)

Expand Down
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,6 @@ def copy_trainer_model_properties(self, model):

for m in [model, ref_model]:
m.trainer = self
m.on_gpu = self.on_gpu
m.use_dp = self.use_dp
m.use_ddp2 = self.use_ddp2
m.use_ddp = self.use_ddp
Expand All @@ -432,7 +431,6 @@ def copy_trainer_model_properties(self, model):
m.use_tpu = self.use_tpu
m.tpu_local_core_rank = self.tpu_local_core_rank
m.tpu_global_core_rank = self.tpu_global_core_rank
m._device = self._device

def transfer_batch_to_tpu(self, batch):
return self.__transfer_data_to_device(batch, device='tpu')
Expand Down Expand Up @@ -488,7 +486,6 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):

def single_gpu_train(self, model):
model.cuda(self.root_gpu)
self._device = torch.device('cuda', self.root_gpu)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
Expand All @@ -505,7 +502,6 @@ def single_gpu_train(self, model):
def tpu_train(self, tpu_core_idx, model):
# put model on tpu
model.to(xm.xla_device())
self._device = xm.xla_device()

# get the appropriate tpu ranks
self.tpu_local_core_rank = xm.get_local_ordinal()
Expand Down Expand Up @@ -545,7 +541,6 @@ def dp_train(self, model):
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

model.cuda(self.root_gpu)
self._device = torch.device('cuda', self.root_gpu)

# hack forward to do autocast for the user
model_autocast_original_forward = model.forward
Expand Down Expand Up @@ -585,7 +580,6 @@ def horovod_train(self, model):
assert self.root_gpu == hvd.local_rank()
torch.cuda.set_device(self.root_gpu)
model.cuda(self.root_gpu)
self._device = torch.device('cuda', self.root_gpu)

# avoid duplicating progress bar
if hvd.rank() != 0 and self.progress_bar_callback is not None:
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@ def __init__(
# distributed backend choice
self.distributed_backend = distributed_backend
self.set_distributed_mode(distributed_backend)
self._device = torch.device('cpu')

# override dist backend when using tpus
if self.on_tpu:
Expand Down

0 comments on commit 4cdebf9

Please sign in to comment.