Skip to content

Commit

Permalink
rename utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jun 2, 2020
1 parent 47b4693 commit ea9fab3
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import transfer_batch_to_device
from pytorch_lightning.utilities import move_data_to_device


try:
Expand Down Expand Up @@ -196,7 +196,7 @@ def transfer_batch_to_device(self, batch, device)
batch and determines the target devices.
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""
return transfer_batch_to_device(batch, device)
return move_data_to_device(batch, device)
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import transfer_batch_to_device
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only

Expand Down Expand Up @@ -112,7 +112,7 @@ def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None):
the tensor on the TPU device.
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
"""
if not XLA_AVAILABLE:
raise MisconfigurationException(
Expand All @@ -134,7 +134,7 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
the tensor on the GPU device.
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
"""
device = torch.device('cuda', gpu_id)
return self.__transfer_batch_to_device(batch, device)
Expand All @@ -143,7 +143,7 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
model = self.get_model()
if model is not None:
return model.transfer_batch_to_device(batch, device)
return transfer_batch_to_device(batch, device)
return move_data_to_device(batch, device)

def single_gpu_train(self, model):
model.cuda(self.root_gpu)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""General utilities"""

from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.apply_func import transfer_batch_to_device
from pytorch_lightning.utilities.apply_func import move_data_to_device
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
return data


def transfer_batch_to_device(batch: Any, device: torch.device):
def move_data_to_device(batch: Any, device: torch.device):
"""
Transfers a collection of tensors to the given device.
Expand Down

0 comments on commit ea9fab3

Please sign in to comment.