Skip to content

Commit

Permalink
Support torchtext on a single GPU (#2379)
Browse files Browse the repository at this point in the history
* Handle torchtext.data.Batch on GPU

* Update CHANGELOG.md

* Apply code review requests

* Correct the docs

* Change requirements
  • Loading branch information
elkotito committed Jun 27, 2020
1 parent 73a78a1 commit e82d9cd
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added TorchText support for moving data to GPU ([#2379](https://github.com/PyTorchLightning/pytorch-lightning/pull/2379))

### Changed

- Changed epoch indexing from 0 instead of 1 ([#2289](https://github.com/PyTorchLightning/pytorch-lightning/pull/2289))
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
- :class:`list`
- :class:`dict`
- :class:`tuple`
- ``torchtext.data.Batch`` (COMING SOON)
- :class:`torchtext.data.batch.Batch`
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
Expand Down
16 changes: 14 additions & 2 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any, Callable, Union

import torch
from torchtext.data import Batch
from copy import copy


def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
Expand Down Expand Up @@ -84,6 +86,16 @@ def move_data_to_device(batch: Any, device: torch.device):
- :meth:`torch.Tensor.to`
- :class:`torch.device`
"""
def to(data):

def batch_to(data):
if isinstance(data, Batch):
# Shallow copy because each Batch has a reference to Dataset which contains all examples
device_data = copy(data)
for field in data.fields:
# Batch contains output of Field.process(...) which is tensor hence .to(...) exists
device_field = getattr(data, field).to(device, non_blocking=True)
setattr(device_data, field, device_field)
return device_data

return data.to(device, non_blocking=True)
return apply_to_collection(batch, dtype=TransferableDataType, function=to)
return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ tensorboard>=1.14
future>=0.17.1 # required for builtins in setup.py
# pyyaml>=3.13
PyYAML>=5.1 # OmegaConf requirement
torchtext>=0.3.1
2 changes: 1 addition & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ matplotlib>=3.1.1
horovod>=0.19.1
omegaconf>=2.0.0
# scipy>=0.13.3
scikit-learn>=0.20.0
scikit-learn>=0.20.0
30 changes: 30 additions & 0 deletions tests/models/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from torchtext.data import Batch, Dataset, Example, Field, LabelField

PRETEND_N_OF_GPUS = 16

Expand Down Expand Up @@ -301,3 +302,32 @@ def to(self, *args, **kwargs):

batch = trainer.transfer_batch_to_gpu(CustomBatchType())
assert batch.a.type() == 'torch.cuda.FloatTensor'

# torchtext.data.Batch
samples = [
{'text': 'PyTorch Lightning is awesome!', 'label': 0},
{'text': 'Please make it work with torchtext', 'label': 1}
]

text_field = Field()
label_field = LabelField()
fields = {
'text': ('text', text_field),
'label': ('label', label_field)
}

examples = [Example.fromdict(sample, fields) for sample in samples]
dataset = Dataset(
examples=examples,
fields=fields.values()
)

# Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first
text_field.build_vocab(dataset)
label_field.build_vocab(dataset)

batch = Batch(data=examples, dataset=dataset)
batch = trainer.transfer_batch_to_gpu(batch, 0)

assert batch.text.type() == 'torch.cuda.LongTensor'
assert batch.label.type() == 'torch.cuda.LongTensor'

0 comments on commit e82d9cd

Please sign in to comment.