From 56d8f8a694c64675647c3ce37aec66664417a861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 15 May 2020 00:39:20 +0200 Subject: [PATCH 01/18] refactor and added hook variant a variant b add test revert rename add changelog docs --- CHANGELOG.md | 2 + pytorch_lightning/core/hooks.py | 38 ++++++++++ pytorch_lightning/trainer/distrib_parts.py | 85 +++++++++++++++------- tests/models/test_hooks.py | 31 ++++++++ 4 files changed, 128 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f0d53d1a73369..ec1a872d80936 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,6 +88,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added using `store_true` for bool args ([#1822](https://github.com/PyTorchLightning/pytorch-lightning/pull/1822), [#1842](https://github.com/PyTorchLightning/pytorch-lightning/pull/1842)) - Added dummy logger for internally disabling logging for some features ([#1836](https://github.com/PyTorchLightning/pytorch-lightning/pull/1836)) +- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)). + ### Changed - Enable `non-blocking` for device transfers to GPU ([#1843](https://github.com/PyTorchLightning/pytorch-lightning/pull/1843)) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1a3f05be11c50..6d35b2f0fb974 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -153,3 +153,41 @@ def backward(self, use_amp, loss, optimizer): scaled_loss.backward() else: loss.backward() + + def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: + """ + Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors + wrapped in a custom data structure. + Lightning only calls the hook if it does not recognize the data type of your batch as one of + + - :class:`torch.Tensor` + - :class:`list` + - :class:`dict` + - :class:`tuple` + - ``torchtext.data.Batch`` (COMING SOON) + + These data types (and any arbitrary nesting of them) are supported out of the box. + For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). + + Example:: + + def transfer_batch_to_device(self, batch, device) + if isinstance(batch, CustomBatch): + # move all tensors in your custom data structure to the device + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + return batch + + Args: + batch: A batch of data that needs to be transferred to a new device. + device: The target device as defined in PyTorch. + + Returns: + A reference to the data on the new device. + + Note: + This hook should only transfer the data and not modify it, nor should it move the data to + any other device than the one passed in as argument. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the + batch and determines the target devices. 
+ """ diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index c2103d40e6cd2..7c4db72f055b7 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -99,35 +99,64 @@ def copy_trainer_model_properties(self, model): m.tpu_local_core_rank = self.tpu_local_core_rank m.tpu_global_core_rank = self.tpu_global_core_rank - def transfer_batch_to_tpu(self, batch): - return self.__transfer_data_to_device(batch, device='tpu') - - def transfer_batch_to_gpu(self, batch, gpu_id): - return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id) - - def __transfer_data_to_device(self, batch, device, gpu_id=None): - if device == 'tpu' and XLA_AVAILABLE: - # base case: object can be directly moved using `to` - if callable(getattr(batch, 'to', None)): - xla_device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device() - return batch.to(xla_device) - - if device == 'gpu': - # base case: object can be directly moved using `cuda` or `to` - if callable(getattr(batch, 'cuda', None)): - # non_blocking will be ignored if tensor is not pinned. - # so we can always set it to True - return batch.cuda(gpu_id, non_blocking=True) - - if callable(getattr(batch, 'to', None)): - # non_blocking will be ignored if tensor is not pinned. - # so we can always set it to True - return batch.to(torch.device('cuda', gpu_id), non_blocking=True) + def transfer_batch_to_tpu(self, batch: Any): + device = xm.xla_device() if XLA_AVAILABLE else torch.device('cpu') + return self.__transfer_data_to_device(batch, device) + + def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): + device = torch.device('cuda', gpu_id) + return self.__transfer_data_to_device(batch, device) + + def __transfer_data_to_device(self, batch: Any, device: torch.device): + if callable(getattr(batch, 'to', None)): + return batch.to(device) + + # when list + if isinstance(batch, list): + for i, x in enumerate(batch): + batch[i] = self.__transfer_data_to_device(x, device) + return batch + + # when tuple + if isinstance(batch, tuple): + # when namedtuple + if hasattr(batch, '_fields'): + elem_type = type(batch) + return elem_type(*(self.__transfer_data_to_device(x, device) for x in batch)) + else: + batch = list(batch) + for i, x in enumerate(batch): + batch[i] = self.__transfer_data_to_device(x, device) + return tuple(batch) + + # when dict + if isinstance(batch, dict): + for k, v in batch.items(): + batch[k] = self.__transfer_data_to_device(v, device) + + return batch + + # check if the model hook can move the data + model = self.get_model() + if model is not None and self.is_overridden('transfer_batch_to_device', model): + batch = model.transfer_batch_to_device(batch, device) + + # nothing matches, return the value as is without transform + return batch + + def __transfer_data_to_device(self, batch: Any, device: torch.device): + + if self.is_overriden('transfer_batch_to_device'): + return self.get_model().transfer_batch_to_device(batch, device) + + # base case: object can be directly moved using `to` + if callable(getattr(batch, 'to', None)): + return batch.to(device, non_blocking=True) # when list if isinstance(batch, list): for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device, gpu_id) + batch[i] = self.__transfer_data_to_device(x, device) return batch # when tuple @@ -135,17 +164,17 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None): # when namedtuple if hasattr(batch, '_fields'): 
elem_type = type(batch) - return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch)) + return elem_type(*(self.__transfer_data_to_device(x, device) for x in batch)) else: batch = list(batch) for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device, gpu_id) + batch[i] = self.__transfer_data_to_device(x, device) return tuple(batch) # when dict if isinstance(batch, dict): for k, v in batch.items(): - batch[k] = self.__transfer_data_to_device(v, device, gpu_id) + batch[k] = self.__transfer_data_to_device(v, device) return batch diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 568a8eae437c2..95fd3c17edace 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import pytest import torch @@ -68,3 +70,32 @@ def training_epoch_end(self, outputs): # metrics are kept after each epoch for i in range(num_epochs): assert metrics[f'epoch_metric_{i}'] == i + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_transfer_batch_hook(): + + class CustomBatch: + + def __init__(self, data): + self.samples = data[0] + self.targets = data[1] + + class CurrentTestModel(EvalModelTemplate): + + def transfer_batch_to_device(self, batch, device): + if isinstance(batch, CustomBatch): + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + return batch + + model = CurrentTestModel(tutils.get_default_hparams()) + batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long))) + + trainer = Trainer() + # running .fit() would require us to implement custom data loaders, we mock the model reference instead + trainer.get_model = MagicMock(return_value=model) + + batch_gpu = trainer.transfer_batch_to_gpu(batch, 0) + device = torch.device('cuda', 0) + assert batch_gpu.samples.device == batch_gpu.targets.device == device From c64057b38a6c6501dca447e0b8c20c3cc27d93f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 15 May 2020 19:14:18 +0200 Subject: [PATCH 02/18] resolve merge duplication --- pytorch_lightning/trainer/distrib_parts.py | 37 ---------------------- 1 file changed, 37 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 7c4db72f055b7..34e184466448d 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -107,43 +107,6 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): device = torch.device('cuda', gpu_id) return self.__transfer_data_to_device(batch, device) - def __transfer_data_to_device(self, batch: Any, device: torch.device): - if callable(getattr(batch, 'to', None)): - return batch.to(device) - - # when list - if isinstance(batch, list): - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device) - return batch - - # when tuple - if isinstance(batch, tuple): - # when namedtuple - if hasattr(batch, '_fields'): - elem_type = type(batch) - return elem_type(*(self.__transfer_data_to_device(x, device) for x in batch)) - else: - batch = list(batch) - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device) - return tuple(batch) - - # when dict - if isinstance(batch, dict): - for k, v in batch.items(): - batch[k] = self.__transfer_data_to_device(v, device) - - return batch - - # check if the model hook can move the data - model = self.get_model() - if model is not None 
and self.is_overridden('transfer_batch_to_device', model): - batch = model.transfer_batch_to_device(batch, device) - - # nothing matches, return the value as is without transform - return batch - def __transfer_data_to_device(self, batch: Any, device: torch.device): if self.is_overriden('transfer_batch_to_device'): From 0b4fc00b01a7615c507a75c01580b0fec4085dee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 15 May 2020 19:21:50 +0200 Subject: [PATCH 03/18] overridden typo --- pytorch_lightning/trainer/distrib_parts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 34e184466448d..de53a8cf510cd 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -109,7 +109,7 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): def __transfer_data_to_device(self, batch: Any, device: torch.device): - if self.is_overriden('transfer_batch_to_device'): + if self.is_overridden('transfer_batch_to_device'): return self.get_model().transfer_batch_to_device(batch, device) # base case: object can be directly moved using `to` From e3675606bee72af3e62ba259b871c7d0936c1456 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 15 May 2020 20:39:48 +0200 Subject: [PATCH 04/18] fix test --- tests/models/test_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 95fd3c17edace..ddf72ff183c19 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -89,7 +89,7 @@ def transfer_batch_to_device(self, batch, device): batch.targets = batch.targets.to(device) return batch - model = CurrentTestModel(tutils.get_default_hparams()) + model = CurrentTestModel() batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long))) trainer = Trainer() From 365c2da82c328abb6ec573951461d274355731cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 May 2020 05:20:01 +0200 Subject: [PATCH 05/18] tpu id --- pytorch_lightning/trainer/distrib_parts.py | 4 ++-- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index de53a8cf510cd..15179ebe291ed 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -99,8 +99,8 @@ def copy_trainer_model_properties(self, model): m.tpu_local_core_rank = self.tpu_local_core_rank m.tpu_global_core_rank = self.tpu_global_core_rank - def transfer_batch_to_tpu(self, batch: Any): - device = xm.xla_device() if XLA_AVAILABLE else torch.device('cpu') + def transfer_batch_to_tpu(self, batch: Any, tpu_id: int = None): + device = xm.xla_device(tpu_id) if XLA_AVAILABLE else torch.device('cpu') return self.__transfer_data_to_device(batch, device) def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index a8c866f9901a1..0676e28bcdaaf 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -434,7 +434,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: # TPU data transfer if self.use_tpu: - batch = 
self.transfer_batch_to_tpu(batch) + batch = self.transfer_batch_to_tpu(batch, self.tpu_id) args[0] = batch # CPU, TPU or gpu step diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4961e58093979..b7286797e23c2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -753,7 +753,7 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens): # TPU support elif self.use_tpu: - batch = self.transfer_batch_to_tpu(batch) + batch = self.transfer_batch_to_tpu(batch, self.tpu_id) args[0] = batch output = self.model.training_step(*args) From 1b8de2990fd9de0f643207b5a3cc01d0d807c07f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 May 2020 00:11:15 +0200 Subject: [PATCH 06/18] raise if TPU not available --- pytorch_lightning/trainer/distrib_parts.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 15179ebe291ed..e81a8bca00222 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -100,7 +100,12 @@ def copy_trainer_model_properties(self, model): m.tpu_global_core_rank = self.tpu_global_core_rank def transfer_batch_to_tpu(self, batch: Any, tpu_id: int = None): - device = xm.xla_device(tpu_id) if XLA_AVAILABLE else torch.device('cpu') + if not XLA_AVAILABLE: + raise MisconfigurationException( + 'Requested to transfer batch to TPU but XLA is not available.' + ' Are you sure this machine has TPUs?' + ) + device = xm.xla_device(tpu_id) return self.__transfer_data_to_device(batch, device) def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): From e3de6e77c7f2a71d003114a1add9c344534ecfc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 May 2020 00:41:02 +0200 Subject: [PATCH 07/18] re-use apply_to_collection function for parsing collections --- pytorch_lightning/trainer/distrib_parts.py | 35 +++------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index e81a8bca00222..bbf41f0e0e711 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -18,6 +18,7 @@ LightningDistributedDataParallel, LightningDataParallel, ) +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only @@ -113,41 +114,13 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): return self.__transfer_data_to_device(batch, device) def __transfer_data_to_device(self, batch: Any, device: torch.device): - if self.is_overridden('transfer_batch_to_device'): return self.get_model().transfer_batch_to_device(batch, device) - # base case: object can be directly moved using `to` - if callable(getattr(batch, 'to', None)): - return batch.to(device, non_blocking=True) - - # when list - if isinstance(batch, list): - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device) - return batch - - # when tuple - if isinstance(batch, tuple): - # when namedtuple - if hasattr(batch, '_fields'): - elem_type = type(batch) - return elem_type(*(self.__transfer_data_to_device(x, device) for x in batch)) - else: - batch = list(batch) - for i, x in enumerate(batch): - batch[i] = 
self.__transfer_data_to_device(x, device) - return tuple(batch) - - # when dict - if isinstance(batch, dict): - for k, v in batch.items(): - batch[k] = self.__transfer_data_to_device(v, device) - - return batch + def to(tensor): + return tensor.to(device, non_blocking=True) - # nothing matches, return the value as is without transform - return batch + return apply_to_collection(batch, dtype=torch.Tensor, function=to) def single_gpu_train(self, model): model.cuda(self.root_gpu) From 1c1b155b4b65b686026a1b55dbeca3cd2b36894b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 May 2020 00:42:57 +0200 Subject: [PATCH 08/18] comment --- pytorch_lightning/trainer/distrib_parts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index bbf41f0e0e711..4684f503d6225 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -106,7 +106,7 @@ def transfer_batch_to_tpu(self, batch: Any, tpu_id: int = None): 'Requested to transfer batch to TPU but XLA is not available.' ' Are you sure this machine has TPUs?' ) - device = xm.xla_device(tpu_id) + device = xm.xla_device(tpu_id) # None will use all available devices return self.__transfer_data_to_device(batch, device) def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): From 7f31a7898aa20ffe5a2cd742e5f0c88ca2f0e7b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 May 2020 01:31:53 +0200 Subject: [PATCH 09/18] make utility function available to user --- pytorch_lightning/__init__.py | 4 ++-- pytorch_lightning/core/hooks.py | 7 ++++++- pytorch_lightning/trainer/distrib_parts.py | 13 +++++-------- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/apply_func.py | 8 ++++++++ 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 3a6ffc1d7f527..74876ca471d8c 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -60,8 +60,8 @@ 'Trainer', 'LightningModule', 'Callback', - 'data_loader' - 'seed_everything' + 'data_loader', + 'seed_everything', ] # necessary for regular bolts imports. Skip exception since bolts is not always installed diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 6d35b2f0fb974..317676dafbc2a 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -166,7 +166,8 @@ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: - :class:`tuple` - ``torchtext.data.Batch`` (COMING SOON) - These data types (and any arbitrary nesting of them) are supported out of the box. + These data types (and any arbitrary nesting of them) are supported out of the box + (see :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device`). For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). Example:: @@ -190,4 +191,8 @@ def transfer_batch_to_device(self, batch, device) any other device than the one passed in as argument. The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the batch and determines the target devices. 
+ + See Also: + - :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device` + - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` """ diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 4684f503d6225..cf2b6a01ee2f7 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -18,7 +18,7 @@ LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities import transfer_batch_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only @@ -111,16 +111,13 @@ def transfer_batch_to_tpu(self, batch: Any, tpu_id: int = None): def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): device = torch.device('cuda', gpu_id) - return self.__transfer_data_to_device(batch, device) + return self.__transfer_batch_to_device(batch, device) - def __transfer_data_to_device(self, batch: Any, device: torch.device): + def __transfer_batch_to_device(self, batch: Any, device: torch.device): if self.is_overridden('transfer_batch_to_device'): + # user-override for custom batch types return self.get_model().transfer_batch_to_device(batch, device) - - def to(tensor): - return tensor.to(device, non_blocking=True) - - return apply_to_collection(batch, dtype=torch.Tensor, function=to) + return transfer_batch_to_device(batch, device) def single_gpu_train(self, model): model.cuda(self.root_gpu) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index c8bc28052398b..53e454981c4a0 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -1,3 +1,4 @@ """General utilities""" from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.apply_func import transfer_batch_to_device diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 724715c3d8607..d4281658c4939 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -1,6 +1,8 @@ from collections import Mapping, Sequence from typing import Any, Callable, Union +import torch + def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: """ @@ -34,3 +36,9 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable # data is neither of dtype, nor a collection return data + + +def transfer_batch_to_device(batch: Any, device: torch.device): + def to(tensor): + return tensor.to(device, non_blocking=True) + return apply_to_collection(batch, dtype=torch.Tensor, function=to) From 2d348a9deb997bc5eadc0f43afb5b265e0dcafd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 May 2020 12:15:35 +0200 Subject: [PATCH 10/18] documentation --- pytorch_lightning/trainer/distrib_parts.py | 32 ++++++++++++++++++++-- pytorch_lightning/utilities/apply_func.py | 15 ++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index cf2b6a01ee2f7..feb91a29eb2da 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -100,16 +100,42 @@ def copy_trainer_model_properties(self, model): 
m.tpu_local_core_rank = self.tpu_local_core_rank m.tpu_global_core_rank = self.tpu_global_core_rank - def transfer_batch_to_tpu(self, batch: Any, tpu_id: int = None): + def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None): + """ + Transfers the data to the TPU. + + Args: + batch: A tensor or collection of tensors. + tpu_id: The id of the TPU core. If omitted, the first available core is chosen. + + Returns: + the tensor on the TPU device. + + See Also: + - :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device` + """ if not XLA_AVAILABLE: raise MisconfigurationException( 'Requested to transfer batch to TPU but XLA is not available.' ' Are you sure this machine has TPUs?' ) - device = xm.xla_device(tpu_id) # None will use all available devices + device = xm.xla_device(tpu_id) return self.__transfer_data_to_device(batch, device) - def transfer_batch_to_gpu(self, batch: Any, gpu_id: int): + def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): + """ + Transfers the data to the GPU. + + Args: + batch: A tensor or collection of tensors. + gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen. + + Returns: + the tensor on the GPU device. + + See Also: + - :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device` + """ device = torch.device('cuda', gpu_id) return self.__transfer_batch_to_device(batch, device) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index d4281658c4939..6034967a99d2a 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -39,6 +39,21 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable def transfer_batch_to_device(batch: Any, device: torch.device): + """ + Transfers a collection of tensors to the given device. + + Args: + batch: A tensor or collection of tensors. See :func:`apply_to_collection` + for a list of supported collection types. + device: The device to which tensors should be moved + + Returns: + the same collection but with all contained tensors residing on the new device. + + See Also: + - :meth:`torch.Tensor.to` + - :class:`torch.device` + """ def to(tensor): return tensor.to(device, non_blocking=True) return apply_to_collection(batch, dtype=torch.Tensor, function=to) From 39eca5ea94d8f75b28c81c44eb435ee749210ec1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 May 2020 20:50:48 +0200 Subject: [PATCH 11/18] move changelog entry to top --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec1a872d80936..64e475739eec2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033)) +- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)). + ### Changed - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729)) @@ -88,8 +90,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added using `store_true` for bool args ([#1822](https://github.com/PyTorchLightning/pytorch-lightning/pull/1822), [#1842](https://github.com/PyTorchLightning/pytorch-lightning/pull/1842)) - Added dummy logger for internally disabling logging for some features ([#1836](https://github.com/PyTorchLightning/pytorch-lightning/pull/1836)) -- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)). - ### Changed - Enable `non-blocking` for device transfers to GPU ([#1843](https://github.com/PyTorchLightning/pytorch-lightning/pull/1843)) From 97ad5c7e22c3a29a71c475f646ec99cc6a458d5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 26 May 2020 22:09:28 +0200 Subject: [PATCH 12/18] fix tpu transfer call --- pytorch_lightning/trainer/distrib_parts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index feb91a29eb2da..2567397c05b26 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -120,7 +120,7 @@ def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None): ' Are you sure this machine has TPUs?' ) device = xm.xla_device(tpu_id) - return self.__transfer_data_to_device(batch, device) + return transfer_batch_to_device(batch, device) def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): """ From e0f06ba51baa87b09c125d7363f2b41aa7015aed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 May 2020 12:36:11 +0200 Subject: [PATCH 13/18] fix call --- pytorch_lightning/trainer/distrib_parts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 2567397c05b26..33eb17dcb5633 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -120,7 +120,7 @@ def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None): ' Are you sure this machine has TPUs?' 
) device = xm.xla_device(tpu_id) - return transfer_batch_to_device(batch, device) + return self.__transfer_batch_to_device(batch, device) def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): """ From 225f4a04439d8c598300e4836a843d20e65327b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 May 2020 12:40:05 +0200 Subject: [PATCH 14/18] remove hardcoded string --- pytorch_lightning/trainer/distrib_parts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 33eb17dcb5633..850dcc1bf9d8f 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -13,6 +13,7 @@ from typing import Union, Callable, Any, List, Optional from pytorch_lightning import _logger as log +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, @@ -140,7 +141,7 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): return self.__transfer_batch_to_device(batch, device) def __transfer_batch_to_device(self, batch: Any, device: torch.device): - if self.is_overridden('transfer_batch_to_device'): + if self.is_overridden(LightningModule.transfer_batch_to_device.__name__): # user-override for custom batch types return self.get_model().transfer_batch_to_device(batch, device) return transfer_batch_to_device(batch, device) From 0a6c7fc549038c98defe434510c64764abf51c48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 May 2020 12:40:18 +0200 Subject: [PATCH 15/18] improve test --- tests/models/test_hooks.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index ddf72ff183c19..0967a205cbbd0 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -83,7 +83,10 @@ def __init__(self, data): class CurrentTestModel(EvalModelTemplate): + hook_called = False + def transfer_batch_to_device(self, batch, device): + self.hook_called = True if isinstance(batch, CustomBatch): batch.samples = batch.samples.to(device) batch.targets = batch.targets.to(device) @@ -95,7 +98,7 @@ def transfer_batch_to_device(self, batch, device): trainer = Trainer() # running .fit() would require us to implement custom data loaders, we mock the model reference instead trainer.get_model = MagicMock(return_value=model) - batch_gpu = trainer.transfer_batch_to_gpu(batch, 0) - device = torch.device('cuda', 0) - assert batch_gpu.samples.device == batch_gpu.targets.device == device + expected = torch.device('cuda', 0) + assert model.hook_called + assert batch_gpu.samples.device == batch_gpu.targets.device == expected From 832da9a9c7d8cc5a2a7b07ad4b1b0c9219d63c81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 31 May 2020 02:48:23 +0200 Subject: [PATCH 16/18] call model hook by default --- pytorch_lightning/core/hooks.py | 12 ++++++++---- pytorch_lightning/trainer/distrib_parts.py | 7 +++---- tests/models/test_hooks.py | 12 +++++++----- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 317676dafbc2a..4b2119ea822f1 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -3,6 +3,8 @@ import torch from torch import Tensor from 
torch.optim.optimizer import Optimizer +from pytorch_lightning.utilities import transfer_batch_to_device + try: from apex import amp @@ -158,7 +160,8 @@ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data structure. - Lightning only calls the hook if it does not recognize the data type of your batch as one of + + The data types listed below (and any arbitrary nesting of them) are supported out of the box: - :class:`torch.Tensor` - :class:`list` @@ -166,8 +169,6 @@ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: - :class:`tuple` - ``torchtext.data.Batch`` (COMING SOON) - These data types (and any arbitrary nesting of them) are supported out of the box - (see :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device`). For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). Example:: @@ -177,6 +178,8 @@ def transfer_batch_to_device(self, batch, device) # move all tensors in your custom data structure to the device batch.samples = batch.samples.to(device) batch.targets = batch.targets.to(device) + else: + batch = super().transfer_batch_to_device(data, device) return batch Args: @@ -188,7 +191,7 @@ def transfer_batch_to_device(self, batch, device) Note: This hook should only transfer the data and not modify it, nor should it move the data to - any other device than the one passed in as argument. + any other device than the one passed in as argument (unless you know what you are doing). The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the batch and determines the target devices. @@ -196,3 +199,4 @@ def transfer_batch_to_device(self, batch, device) - :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device` - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` """ + return transfer_batch_to_device(batch, device) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 850dcc1bf9d8f..80bc8e3c3f6c3 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -13,7 +13,6 @@ from typing import Union, Callable, Any, List, Optional from pytorch_lightning import _logger as log -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, @@ -141,9 +140,9 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): return self.__transfer_batch_to_device(batch, device) def __transfer_batch_to_device(self, batch: Any, device: torch.device): - if self.is_overridden(LightningModule.transfer_batch_to_device.__name__): - # user-override for custom batch types - return self.get_model().transfer_batch_to_device(batch, device) + model = self.get_model() + if model is not None: + return model.transfer_batch_to_device(batch, device) return transfer_batch_to_device(batch, device) def single_gpu_train(self, model): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 0967a205cbbd0..47b73eb9e715b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -85,12 +85,14 @@ class CurrentTestModel(EvalModelTemplate): hook_called = False - def transfer_batch_to_device(self, batch, device): + def transfer_batch_to_device(self, data, 
device): self.hook_called = True - if isinstance(batch, CustomBatch): - batch.samples = batch.samples.to(device) - batch.targets = batch.targets.to(device) - return batch + if isinstance(data, CustomBatch): + data.samples = data.samples.to(device) + data.targets = data.targets.to(device) + else: + data = super().transfer_batch_to_device(data, device) + return data model = CurrentTestModel() batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long))) From 47b4693b5ec3dc44ab6560915e186542d23c36a4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 1 Jun 2020 23:10:39 +0200 Subject: [PATCH 17/18] Apply suggestions from code review --- pytorch_lightning/trainer/distrib_parts.py | 4 ++-- pytorch_lightning/utilities/apply_func.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 80bc8e3c3f6c3..1b40dd76af5a0 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -108,7 +108,7 @@ def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None): batch: A tensor or collection of tensors. tpu_id: The id of the TPU core. If omitted, the first available core is chosen. - Returns: + Return: the tensor on the TPU device. See Also: @@ -130,7 +130,7 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): batch: A tensor or collection of tensors. gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen. - Returns: + Return: the tensor on the GPU device. See Also: diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 6034967a99d2a..2c1191ad7e05d 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -47,7 +47,7 @@ def transfer_batch_to_device(batch: Any, device: torch.device): for a list of supported collection types. device: The device to which tensors should be moved - Returns: + Return: the same collection but with all contained tensors residing on the new device. See Also: From ea9fab37388b4b8d0f5df0675fcfad453ac4cf5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 3 Jun 2020 01:09:39 +0200 Subject: [PATCH 18/18] rename utility function --- pytorch_lightning/core/hooks.py | 6 +++--- pytorch_lightning/trainer/distrib_parts.py | 8 ++++---- pytorch_lightning/utilities/__init__.py | 2 +- pytorch_lightning/utilities/apply_func.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 4b2119ea822f1..960c7124383b0 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -3,7 +3,7 @@ import torch from torch import Tensor from torch.optim.optimizer import Optimizer -from pytorch_lightning.utilities import transfer_batch_to_device +from pytorch_lightning.utilities import move_data_to_device try: @@ -196,7 +196,7 @@ def transfer_batch_to_device(self, batch, device) batch and determines the target devices. 
See Also: - - :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device` + - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` """ - return transfer_batch_to_device(batch, device) + return move_data_to_device(batch, device) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 1b40dd76af5a0..9ea54a0e00346 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -18,7 +18,7 @@ LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.utilities import transfer_batch_to_device +from pytorch_lightning.utilities import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only @@ -112,7 +112,7 @@ def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None): the tensor on the TPU device. See Also: - - :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device` + - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` """ if not XLA_AVAILABLE: raise MisconfigurationException( @@ -134,7 +134,7 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): the tensor on the GPU device. See Also: - - :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device` + - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` """ device = torch.device('cuda', gpu_id) return self.__transfer_batch_to_device(batch, device) @@ -143,7 +143,7 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device): model = self.get_model() if model is not None: return model.transfer_batch_to_device(batch, device) - return transfer_batch_to_device(batch, device) + return move_data_to_device(batch, device) def single_gpu_train(self, model): model.cuda(self.root_gpu) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 53e454981c4a0..51eb3b283d43a 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -1,4 +1,4 @@ """General utilities""" from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn -from pytorch_lightning.utilities.apply_func import transfer_batch_to_device +from pytorch_lightning.utilities.apply_func import move_data_to_device diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 2c1191ad7e05d..bb32f79df9030 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -38,7 +38,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable return data -def transfer_batch_to_device(batch: Any, device: torch.device): +def move_data_to_device(batch: Any, device: torch.device): """ Transfers a collection of tensors to the given device.
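
To illustrate how the pieces introduced in this series fit together, here is a minimal usage sketch assuming all 18 patches are applied: a LightningModule overrides the `transfer_batch_to_device` hook for a custom batch type and defers everything else to the default implementation, while the standalone `move_data_to_device` utility is applied directly to an ordinary nested collection. `CustomBatch` and `LitModel` are illustrative names only and are not part of the patches; the imports reflect the modules touched by this series.

    import torch

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities import move_data_to_device


    class CustomBatch:
        """A wrapper that Lightning cannot move to a device on its own."""

        def __init__(self, samples: torch.Tensor, targets: torch.Tensor):
            self.samples = samples
            self.targets = targets


    class LitModel(LightningModule):

        def transfer_batch_to_device(self, batch, device):
            if isinstance(batch, CustomBatch):
                # move the tensors wrapped in the custom structure explicitly
                batch.samples = batch.samples.to(device)
                batch.targets = batch.targets.to(device)
                return batch
            # tensors, lists, dicts and tuples fall back to the default hook,
            # whose body is `move_data_to_device(batch, device)` after this series
            return super().transfer_batch_to_device(batch, device)


    # the utility also handles arbitrarily nested collections of tensors on its own
    device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
    nested = {'x': torch.zeros(2, 3), 'y': (torch.ones(4), [torch.ones(1)])}
    nested = move_data_to_device(nested, device)

The fallback through `super().transfer_batch_to_device` mirrors the behaviour wired up in the "call model hook by default" commit: the Trainer always delegates to the model hook, and the hook's default body moves standard collections via `move_data_to_device`.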