Skip to content

Commit

Permalink
fix non blocking + docs
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Aug 11, 2020
1 parent 35a3fd2 commit 3047bfe
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
9 changes: 7 additions & 2 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device)
Note:
This hook should only transfer the data and not modify it, nor should it move the data to
any other device than the one passed in as argument (unless you know what you are doing).
The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
batch and determines the target devices.
Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import inspect
from abc import ABC
from collections.abc import Mapping, Sequence
from copy import copy
Expand Down Expand Up @@ -104,6 +105,11 @@ def batch_to(data):
setattr(device_data, field, device_field)
return device_data

return data.to(device, non_blocking=True)
kwargs = dict()
signature = inspect.signature(data.to)
if "non_blocking" in signature.parameters:
kwargs.update(non_blocking=True)

return data.to(device, **kwargs)

return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)

0 comments on commit 3047bfe

Please sign in to comment.