fix importing torchtext batch #6365

Merged: 6 commits, Mar 5, 2021 (diff shown from 1 commit)
31 changes: 10 additions & 21 deletions pytorch_lightning/utilities/apply_func.py
@@ -22,15 +22,7 @@
 import torch
 
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _module_available, _TORCHTEXT_AVAILABLE
-
-if _TORCHTEXT_AVAILABLE:
-    if _module_available("torchtext.legacy.data"):
-        from torchtext.legacy.data import Batch
-    else:
-        from torchtext.data import Batch
-else:
-    Batch = type(None)
+from pytorch_lightning.utilities.torchtext_batch import Batch
 
 
 def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
@@ -142,22 +134,19 @@ def move_data_to_device(batch: Any, device: torch.device):
     """
 
     def batch_to(data):
-        # try to move torchtext data first
-        if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
-
-            # Shallow copy because each Batch has a reference to Dataset which contains all examples
-            device_data = copy(data)
-            for field, field_value in data.dataset.fields.items():
-                if field_value is None:
-                    continue
-                device_field = move_data_to_device(getattr(data, field), device)
-                setattr(device_data, field, device_field)
-            return device_data
+        # Shallow copy because each Batch has a reference to Dataset which contains all examples
+        device_data = copy(data)
+        for field, field_value in data.dataset.fields.items():
+            if field_value is None:
+                continue
+            device_field = move_data_to_device(getattr(data, field), device)
+            setattr(device_data, field, device_field)
+        return device_data
 
         kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
         return data.to(device, **kwargs)
 
-    dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
+    dtype = (TransferableDataType, Batch)
     return apply_to_collection(batch, dtype=dtype, function=batch_to)
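For orientation, a minimal usage sketch (not part of the diff) of the contract `move_data_to_device` provides; it assumes a CPU-only run and describes the helper's merged behavior rather than this intermediate commit:

import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device

device = torch.device("cpu")

# Nested collections are traversed by apply_to_collection; anything matching
# the dtype tuple (TransferableDataType, Batch) is handed to batch_to, so
# plain tensors end up moved via `.to(device, non_blocking=True)`.
nested = {"inputs": torch.ones(2, 3), "targets": [torch.zeros(2)]}
moved = move_data_to_device(nested, device)
assert moved["inputs"].device == device
assert moved["targets"][0].device == device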
1 change: 0 additions & 1 deletion pytorch_lightning/utilities/imports.py
@@ -68,6 +68,5 @@ def _compare_version(package: str, op, version) -> bool:
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
-_TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available('torchvision')
 _XLA_AVAILABLE = _module_available("torch_xla")
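For context, `_module_available` is the probe these flags rely on; a hedged sketch of what such a helper typically does (the real implementation in pytorch_lightning/utilities/imports.py may differ in details):

from importlib.util import find_spec


def module_available_sketch(module_path: str) -> bool:
    """Return True if `module_path` (e.g. "torchtext.legacy.data") can be imported."""
    try:
        return find_spec(module_path) is not None
    except ModuleNotFoundError:
        # a parent package of the dotted path is not installed at all
        return False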
101 changes: 101 additions & 0 deletions pytorch_lightning/utilities/torchtext_batch.py
@@ -0,0 +1,101 @@
# THIS IS A PURE COPY OF THE TORCHTEXT BATCH CLASS
# THIS STRUCTURE SEEMS TO BE DEPRECATED IN TORCHTEXT

import torch


class Batch(object):
    """Defines a batch of examples along with its Fields.

    Attributes:
        batch_size: Number of examples in the batch.
        dataset: A reference to the dataset object the examples come from
            (which itself contains the dataset's Field objects).
        train: Deprecated: this attribute is left for backwards compatibility,
            however it is UNUSED as of the merger with pytorch 0.4.
        input_fields: The names of the fields that are used as input for the model
        target_fields: The names of the fields that are used as targets during model training

    Also stores the Variable for each column in the batch as an attribute.
    """

    def __init__(self, data=None, dataset=None, device=None):
        """Create a Batch from a list of examples."""
        if data is not None:
            self.batch_size = len(data)
            self.dataset = dataset
            self.fields = dataset.fields.keys()  # copy field names
            self.input_fields = [k for k, v in dataset.fields.items() if v is not None and not v.is_target]
            self.target_fields = [k for k, v in dataset.fields.items() if v is not None and v.is_target]

            for (name, field) in dataset.fields.items():
                if field is not None:
                    batch = [getattr(x, name) for x in data]
                    setattr(self, name, field.process(batch, device=device))

    @classmethod
    def fromvars(cls, dataset, batch_size, train=None, **kwargs):
        """Create a Batch directly from a number of Variables."""
        batch = cls()
        batch.batch_size = batch_size
        batch.dataset = dataset
        batch.fields = dataset.fields.keys()
        for k, v in kwargs.items():
            setattr(batch, k, v)
        return batch

    def __repr__(self):
        return str(self)

    def __str__(self):
        if not self.__dict__:
            return 'Empty {} instance'.format(torch.typename(self))

        fields_to_index = filter(lambda field: field is not None, self.fields)
        var_strs = '\n'.join(['\t[.' + name + ']' + ":" + _short_str(getattr(self, name))
                              for name in fields_to_index if hasattr(self, name)])

        data_str = (' from {}'.format(self.dataset.name.upper())
                    if hasattr(self.dataset, 'name')
                    and isinstance(self.dataset.name, str) else '')

        strt = '[{} of size {}{}]\n{}'.format(torch.typename(self),
                                              self.batch_size, data_str, var_strs)
        return '\n' + strt

    def __len__(self):
        return self.batch_size

    def _get_field_values(self, fields):
        if len(fields) == 0:
            return None
        elif len(fields) == 1:
            return getattr(self, fields[0])
        else:
            return tuple(getattr(self, f) for f in fields)

    def __iter__(self):
        yield self._get_field_values(self.input_fields)
        yield self._get_field_values(self.target_fields)


def _short_str(tensor):
    # unwrap variable to tensor
    if not torch.is_tensor(tensor):
        # (1) unpack variable
        if hasattr(tensor, 'data'):
            tensor = getattr(tensor, 'data')
        # (2) handle include_lengths
        elif isinstance(tensor, tuple):
            return str(tuple(_short_str(t) for t in tensor))
        # (3) fallback to default str
        else:
            return str(tensor)

    # copied from torch _tensor_str
    size_str = 'x'.join(str(size) for size in tensor.size())
    device_str = '' if not tensor.is_cuda else \
        ' (GPU {})'.format(tensor.get_device())
    strt = '[{} of size {}{}]'.format(torch.typename(tensor),
                                      size_str, device_str)
    return strt
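As a quick illustration of the vendored class, a hypothetical snippet (not part of the PR) that builds a Batch from already-processed tensors via `fromvars`; the SimpleNamespace is a stand-in for a torchtext Dataset, which only needs to expose a `fields` mapping here:

from types import SimpleNamespace

import torch

from pytorch_lightning.utilities.torchtext_batch import Batch

# minimal stand-in for a torchtext Dataset: only `.fields` is touched by fromvars
dataset = SimpleNamespace(fields={"text": object(), "label": object()})

batch = Batch.fromvars(
    dataset,
    batch_size=3,
    text=torch.zeros(3, 7, dtype=torch.long),  # token ids, shape [batch, seq_len]
    label=torch.tensor([0, 1, 0]),
)

print(len(batch))  # 3, via __len__
print(batch)       # pretty-printed through __str__/_short_str, e.g. [... of size 3x7]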
1 change: 0 additions & 1 deletion requirements/extra.txt
@@ -3,7 +3,6 @@
 matplotlib>3.1
 horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed
 omegaconf>=2.0.1
-torchtext>=0.5
 onnx>=1.7.0
 onnxruntime>=1.3.0
 hydra-core>=1.0
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -14,4 +14,5 @@ pre-commit>=1.0
 
 cloudpickle>=1.3
 nltk>=3.3
+torchtext>=0.5
 pandas # needed in benchmarks
8 changes: 8 additions & 0 deletions tests/helpers/imports.py
@@ -0,0 +1,8 @@
import operator

from pytorch_lightning.utilities.imports import _compare_version

if _compare_version("torchtext", operator.ge, "0.9.0"):
    from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField  # noqa: F401
else:
    from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField  # noqa: F401
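For reference, a rough sketch of the comparison `_compare_version` performs (the real helper in pytorch_lightning/utilities/imports.py also guards against packages that are missing or lack a usable version; this is only an approximation). The gate above checks the torchtext version because the `torchtext.legacy` namespace first appeared in torchtext 0.9.0:

import operator

from packaging.version import Version


def compare_version_sketch(package_version: str, op, required: str) -> bool:
    # e.g. compare_version_sketch("0.9.0", operator.ge, "0.9.0") -> True
    return op(Version(package_version), Version(required))


assert compare_version_sketch("0.9.0", operator.ge, "0.9.0")
assert not compare_version_sketch("0.8.1", operator.ge, "0.9.0")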
2 changes: 1 addition & 1 deletion tests/models/test_gpu.py
@@ -16,7 +16,6 @@
 
 import pytest
 import torch
-from torchtext.data import Batch, Dataset, Example, Field, LabelField
 
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
@@ -25,6 +24,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
+from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
 
9 changes: 4 additions & 5 deletions tests/utilities/test_apply_func_torchtext.py
@@ -13,15 +13,14 @@
 # limitations under the License.
 import pytest
 import torch
-import torchtext
-from torchtext.data.example import Example
 
 from pytorch_lightning.utilities.apply_func import move_data_to_device
+from tests.helpers.imports import Batch, Dataset, Example, Field, Iterator
 from tests.helpers.runif import RunIf
 
 
 def _get_torchtext_data_iterator(include_lengths=False):
-    text_field = torchtext.data.Field(
+    text_field = Field(
         sequential=True,
         pad_first=False,  # nosec
         init_token="<s>",
@@ -33,13 +32,13 @@ def _get_torchtext_data_iterator(include_lengths=False):
     example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
     example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
 
-    dataset = torchtext.data.Dataset(
+    dataset = Dataset(
         [example1, example2, example3],
         {"text": text_field},
     )
     text_field.build_vocab(dataset)
 
-    iterator = torchtext.data.Iterator(
+    iterator = Iterator(
         dataset,
         batch_size=3,
         sort_key=None,
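To close the loop, a condensed sketch (adapted, not verbatim) of how this iterator feeds move_data_to_device in the tests of this module; exact test names and markers may differ:

def test_batch_moves_to_device_sketch():
    device = torch.device("cpu")
    data_iterator = _get_torchtext_data_iterator(include_lengths=True)
    batch = next(iter(data_iterator))

    moved = move_data_to_device(batch, device)

    # with include_lengths=True the text field is a (tokens, lengths) tuple;
    # both tensors should land on the target device
    tokens, lengths = moved.text
    assert tokens.device == device
    assert lengths.device == device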