Skip to content

Commit

Permalink
re-use apply_to_collection function for parsing collections
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed May 19, 2020
1 parent be3fa7e commit c884f68
Showing 1 changed file with 4 additions and 31 deletions.
35 changes: 4 additions & 31 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only

Expand Down Expand Up @@ -446,41 +447,13 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: int):
return self.__transfer_data_to_device(batch, device)

def __transfer_data_to_device(self, batch: Any, device: torch.device):
    """Move every tensor inside ``batch`` to ``device``.

    The user can take full control of the transfer by overriding the
    ``transfer_batch_to_device`` hook on the LightningModule; in that
    case the hook's return value is used verbatim.  Otherwise the batch
    (an arbitrarily nested collection) is traversed and each
    ``torch.Tensor`` is moved with ``non_blocking=True``.

    Args:
        batch: the data, possibly a nested collection of tensors.
        device: the target device.

    Returns:
        The batch with all contained tensors residing on ``device``.
    """
    # A user-defined hook takes precedence over the default traversal.
    if self.is_overridden('transfer_batch_to_device'):
        return self.get_model().transfer_batch_to_device(batch, device)

    def _move(tensor):
        # non_blocking allows async host-to-device copies for pinned memory
        return tensor.to(device, non_blocking=True)

    return apply_to_collection(batch, dtype=torch.Tensor, function=_move)

def single_gpu_train(self, model):
model.cuda(self.root_gpu)
Expand Down

0 comments on commit c884f68

Please sign in to comment.