From 924e4c12a45c0d5352b5dfb83ab7f9188c76a8f3 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Fri, 24 Jul 2020 14:48:48 -0400
Subject: [PATCH 01/40] Test using torchtext.data.Field with
 include_lengths=True/False

---
 tests/utilities/test_apply_func_torchtext.py | 170 +++++++++++++++++++
 1 file changed, 170 insertions(+)
 create mode 100644 tests/utilities/test_apply_func_torchtext.py

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
new file mode 100644
index 0000000000000..c6f595e3d6579
--- /dev/null
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -0,0 +1,170 @@
+import torch
+import torchtext
+
+import pytorch_lightning as pl
+
+
+def get_torchtext_data_iterator(include_lengths=False):
+    text_field = torchtext.data.Field(sequential=True, pad_first=False,
+                                      init_token="<sos>", eos_token="<eos>", include_lengths=include_lengths)
+
+    example1 = torchtext.data.example.Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
+    example2 = torchtext.data.example.Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
+    example3 = torchtext.data.example.Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
+
+    dataset = torchtext.data.Dataset([example1, example2, example3], {"text": text_field})
+    text_field.build_vocab(dataset)
+
+    iterator = torchtext.data.Iterator(dataset, batch_size=3,
+                                       sort_key=None, device=None, batch_size_fn=None,
+                                       train=True, repeat=False, shuffle=None, sort=None, sort_within_batch=None)
+    return iterator, text_field
+
+
+def test_move_data_to_device_torchtext_include_length_true():
+    """Test if batches created by torchtext with include_lengths=True raise an exception."""
+
+    class DebugModel(pl.LightningModule):
+
+        def __init__(self):
+            super(DebugModel, self).__init__()
+
+            # setup data loader generating batches with fields consisting of tuples of tensors
+            self.debug_data_loader, self.text_field = get_torchtext_data_iterator(include_lengths=True)
+
+            self.learning_rate = 0.001
+
+            pad_idx = self.text_field.vocab.stoi['<pad>']
+            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
+
+            self.INPUT_DIM = len(self.text_field.vocab)
+            self.ENC_EMB_DIM = 4  # keep it small for debugging
+            self.embedding = torch.nn.Embedding(self.INPUT_DIM, self.ENC_EMB_DIM)
+
+            self.hid_dim = 4
+            self.rnn = torch.nn.GRU(self.ENC_EMB_DIM, self.hid_dim, 1, bidirectional=False)
+            self.out = torch.nn.Linear(self.hid_dim, self.embedding.num_embeddings)
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+
+        def forward(self, input_seq, length):
+            embedded = self.embedding(input_seq)
+            packed_embedded = torch.nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=False,
+                                                                      enforce_sorted=False)
+            packed_outputs, hidden = self.rnn(packed_embedded)
+            outputs, length = torch.nn.utils.rnn.pad_packed_sequence(packed_outputs)
+
+            output = outputs.squeeze(0)
+            prediction = self.out(output)
+
+            return prediction
+
+        @staticmethod
+        def _parse_batch(batch):
+            source = batch.text[0]
+            source_length = batch.text[1]
+
+            return source, source_length
+
+        def training_step(self, batch, batch_nb):
+            """ Needed for testing data transfer.
""" + x = self._parse_batch(batch) + target, target_length = x + + output = self.forward(target, target_length) + loss = self.criterion(output[:-1].view(-1, output.shape[2]), target[1:].view(-1)) + + prefix = 'train' + tensorboard_logs = {f'{prefix}_loss': loss.item()} + + result = {'loss': loss, 'log': tensorboard_logs} + return result + + def train_dataloader(self): + return self.debug_data_loader + + model = DebugModel() + + cuda_device_cnt = torch.cuda.device_count() + if cuda_device_cnt > 0: + use_num_cuda_devices = 1 + else: + use_num_cuda_devices = None + + trainer = pl.Trainer(fast_dev_run=True, max_steps=None, + gradient_clip_val=10, + weights_summary=None, gpus=use_num_cuda_devices, + show_progress_bar=True) + + result = trainer.fit(model) + # verify training completed + assert result == 1 + + +def test_move_data_to_device_torchtext_include_length_false(): + """Test if batches created by torchtext with include_lengths=False raise an exception.""" + + class DebugModel(pl.LightningModule): + + def __init__(self): + super(DebugModel, self).__init__() + + # setup data loader generating batches with fields consisting of tensors + self.debug_data_loader, self.text_field = get_torchtext_data_iterator(include_lengths=False) + + self.learning_rate = 0.001 + + pad_idx = self.text_field.vocab.stoi[''] + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx) + + self.INPUT_DIM = len(self.text_field.vocab) + self.ENC_EMB_DIM = 4 # keep it small for debugging + self.embedding = torch.nn.Embedding(self.INPUT_DIM, self.ENC_EMB_DIM) + + self.hid_dim = 4 + self.rnn = torch.nn.GRU(self.ENC_EMB_DIM, self.hid_dim, 1, bidirectional=False) + self.out = torch.nn.Linear(self.hid_dim, self.embedding.num_embeddings) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.learning_rate) + + def forward(self, input_seq): + embedded = self.embedding(input_seq) + outputs, hidden = self.rnn(embedded) + output = outputs.squeeze(0) + prediction = self.out(output) + return prediction + + def training_step(self, batch, batch_nb): + """ Needed for testing data transfer. 
""" + + target = batch.text + output = self.forward(target) + loss = self.criterion(output[:-1].view(-1, output.shape[2]), target[1:].view(-1)) + + prefix = 'train' + tensorboard_logs = {f'{prefix}_loss': loss.item()} + + result = {'loss': loss, 'log': tensorboard_logs} + return result + + def train_dataloader(self): + return self.debug_data_loader + + model = DebugModel() + + cuda_device_cnt = torch.cuda.device_count() + if cuda_device_cnt > 0: + use_num_cuda_devices = 1 + else: + use_num_cuda_devices = None + + trainer = pl.Trainer(fast_dev_run=True, max_steps=None, + gradient_clip_val=10, + weights_summary=None, gpus=use_num_cuda_devices, + show_progress_bar=True) + + result = trainer.fit(model) + # verify training completed + assert result == 1 From 1fcbe3631473dea5c223e652862cbcd8a3508bf0 Mon Sep 17 00:00:00 2001 From: Thomas Schaaf Date: Fri, 24 Jul 2020 14:54:51 -0400 Subject: [PATCH 02/40] Fix issue that Tensors in a Batch generated by torchtext with torchtext.data.Field configured as include_lengths=True --- pytorch_lightning/utilities/apply_func.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 58462146364f2..816d9e48ad303 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -99,9 +99,25 @@ def batch_to(data): # Shallow copy because each Batch has a reference to Dataset which contains all examples device_data = copy(data) for field in data.fields: - # Batch contains output of Field.process(...) which is tensor hence .to(...) exists - device_field = getattr(data, field).to(device, non_blocking=True) - setattr(device_data, field, device_field) + # Batch contains output of Field.process(...) + if isinstance(getattr(data, field), torch.Tensor): + # standard case: usually a tensor hence .to(...) exists + device_field = getattr(data, field).to(device, non_blocking=True) + setattr(device_data, field, device_field) + elif isinstance(getattr(data, field), tuple): + # Case of include_lengths=True then torchtext produces a tuple of two tensors + # Use of generator expression to send Tensors to device (alternative could be list comprehension) + device_field = tuple(elem.to(device, non_blocking=True) for elem in getattr(data, field)) + setattr(device_data, field, device_field) + elif isinstance(getattr(data, field), list): + # Case for completeness + device_field = list(elem.to(device, non_blocking=True) for elem in getattr(data, field)) + setattr(device_data, field, device_field) + else: + # Catch all assuming the class has a .to if not it will fail; and more cases are needed + device_field = getattr(data, field).to(device, non_blocking=True) + setattr(device_data, field, device_field) + return device_data else: return data.to(device, non_blocking=True) From 59e97b2a98302ae9468c44fe5ebd8a3847f8e2c6 Mon Sep 17 00:00:00 2001 From: Thomas Schaaf Date: Fri, 24 Jul 2020 15:43:58 -0400 Subject: [PATCH 03/40] Add description for fix of issue #2688 --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ed0abbcfa801..9042d370f8270 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
From 59e97b2a98302ae9468c44fe5ebd8a3847f8e2c6 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Fri, 24 Jul 2020 15:43:58 -0400
Subject: [PATCH 03/40] Add description for fix of issue #2688

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5ed0abbcfa801..9042d370f8270 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
 
+- Fixed data transfer to device when using torchtext.data.Field and include_lengths is True ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added

From 3e3fbbe821925bac00cbd6ce393b8d9df8cbe12b Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Fri, 24 Jul 2020 16:43:58 -0400
Subject: [PATCH 04/40] changes to accommodate CodeFactor issues

---
 pytorch_lightning/utilities/apply_func.py    | 6 ++++--
 tests/utilities/test_apply_func_torchtext.py | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 816d9e48ad303..09698fccdc579 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -6,6 +6,7 @@
 import torch
 import importlib
+
 TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
 if TORCHTEXT_AVAILABLE:
     from torchtext.data import Batch
@@ -92,6 +93,7 @@ def move_data_to_device(batch: Any, device: torch.device):
         - :meth:`torch.Tensor.to`
         - :class:`torch.device`
     """
+
     def batch_to(data):
         # try to move torchtext data first
         if TORCHTEXT_AVAILABLE and isinstance(data, Batch):
@@ -119,7 +121,7 @@ def batch_to(data):
                     setattr(device_data, field, device_field)
 
             return device_data
-        else:
-            return data.to(device, non_blocking=True)
+
+        return data.to(device, non_blocking=True)
 
     return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index c6f595e3d6579..2f697c0a5fb9d 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -6,7 +6,7 @@
 
 def get_torchtext_data_iterator(include_lengths=False):
     text_field = torchtext.data.Field(sequential=True, pad_first=False,
-                                      init_token="<sos>", eos_token="<eos>", include_lengths=include_lengths)
+                                      init_token="<sos>", eos_token="<eos>", include_lengths=include_lengths)  # nosec
 

From fe9816d8bcab42632979f78460a13226293cbabb Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Fri, 24 Jul 2020 17:07:05 -0400
Subject: [PATCH 05/40] Another attempt to make last CodeFactor issue pass
 (it's a false alarm)

---
 tests/utilities/test_apply_func_torchtext.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 2f697c0a5fb9d..02ff7200cb399 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -5,7 +5,7 @@
 
 
 def get_torchtext_data_iterator(include_lengths=False):
-    text_field = torchtext.data.Field(sequential=True, pad_first=False,
+    text_field = torchtext.data.Field(sequential=True, pad_first=False,  # nosec
                                       init_token="<sos>", eos_token="<eos>", include_lengths=include_lengths)  # nosec
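Note: # nosec is the per-line suppression marker for Bandit, presumably the security linter behind the CodeFactor finding being silenced in the two patches above; the commit message calls the finding a false alarm. A generic illustration of the marker, unrelated to this repository:

    import subprocess

    # Bandit would normally flag shell=True here (finding B602);
    # appending '# nosec' tells it to skip this line.
    subprocess.run("echo ok", shell=True)  # nosec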
From 957ee898472b23226989e59018d6b70813e0dec5 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Fri, 24 Jul 2020 17:57:11 -0400
Subject: [PATCH 06/40] temporarily disable test of test_grad_tracking to
 check if testing will pass

---
 tests/models/test_grad_norm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index d7978965a3cfe..bd8aa426d7eaa 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -41,7 +41,7 @@ def on_after_backward(self):
         out[prefix + 'total'] = round(norm, 3)
 
         self.stored_grad_norms.append(out)
 
-
+@pytest.mark.skip(reason="temporarily deactivated for testing if testing would pass")
 @pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
 def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
     os.environ['PL_DEV_DEBUG'] = '1'

From 7971e7d086a86773334826e860894e215bccfb25 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Fri, 24 Jul 2020 20:41:29 -0400
Subject: [PATCH 07/40] reenable test in test_grad_norm

---
 tests/models/test_grad_norm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py
index bd8aa426d7eaa..d7978965a3cfe 100644
--- a/tests/models/test_grad_norm.py
+++ b/tests/models/test_grad_norm.py
@@ -41,7 +41,7 @@ def on_after_backward(self):
         out[prefix + 'total'] = round(norm, 3)
 
         self.stored_grad_norms.append(out)
 
-@pytest.mark.skip(reason="temporarily deactivated for testing if testing would pass")
+
 @pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
 def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
     os.environ['PL_DEV_DEBUG'] = '1'

From 4d0a849940e72e399b828db7de949aa939efd254 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf <42753790+thschaaf@users.noreply.github.com>
Date: Sun, 26 Jul 2020 10:58:54 -0400
Subject: [PATCH 08/40] Update CHANGELOG.md

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9042d370f8270..3f5e8efba9f8f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
 
-- Fixed data transfer to device when using torchtext.data.Field and include_lengths is True ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09
 
 ### Added

From c994e880df3f412f389dcd1133a98ca23933b7e7 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Sun, 26 Jul 2020 11:59:29 -0400
Subject: [PATCH 09/40] Renamed get_torchtext_data_iterator to
 _get_torchtext_data_iterator as suggested by @borda

---
 tests/utilities/test_apply_func_torchtext.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 02ff7200cb399..e13615de32e66 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -4,7 +4,7 @@
 import pytorch_lightning as pl
 
 
-def get_torchtext_data_iterator(include_lengths=False):
+def _get_torchtext_data_iterator(include_lengths=False):
     text_field = torchtext.data.Field(sequential=True, pad_first=False,  # nosec
                                       init_token="<sos>", eos_token="<eos>", include_lengths=include_lengths)  # nosec
@@ -30,7 +30,7 @@ def __init__(self):
             super(DebugModel, self).__init__()
 
             # setup data loader generating batches with fields consisting of tuples of tensors
-            self.debug_data_loader, self.text_field = get_torchtext_data_iterator(include_lengths=True)
+            self.debug_data_loader, self.text_field = _get_torchtext_data_iterator(include_lengths=True)
 
             self.learning_rate = 0.001
@@ -111,7 +111,7 @@ def __init__(self):
             super(DebugModel, self).__init__()
 
             # setup data loader generating batches with fields consisting of tensors
-            self.debug_data_loader, self.text_field = get_torchtext_data_iterator(include_lengths=False)
+            self.debug_data_loader, self.text_field = _get_torchtext_data_iterator(include_lengths=False)
 
             self.learning_rate = 0.001

From a99fc7d202b98b2f96544b69dc229b924fca9b30 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf <42753790+thschaaf@users.noreply.github.com>
Date: Sun, 26 Jul 2020 15:22:49 -0400
Subject: [PATCH 10/40] Update pytorch_lightning/utilities/apply_func.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/utilities/apply_func.py | 21 ++-------------------
 1 file changed, 2 insertions(+), 19 deletions(-)

diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 09698fccdc579..810b9a9677f51 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -101,25 +101,8 @@ def batch_to(data):
             # Shallow copy because each Batch has a reference to Dataset which contains all examples
             device_data = copy(data)
             for field in data.fields:
-                # Batch contains output of Field.process(...)
-                if isinstance(getattr(data, field), torch.Tensor):
-                    # standard case: usually a tensor hence .to(...) exists
-                    device_field = getattr(data, field).to(device, non_blocking=True)
-                    setattr(device_data, field, device_field)
-                elif isinstance(getattr(data, field), tuple):
-                    # Case of include_lengths=True then torchtext produces a tuple of two tensors
-                    # Use of generator expression to send Tensors to device (alternative could be list comprehension)
-                    device_field = tuple(elem.to(device, non_blocking=True) for elem in getattr(data, field))
-                    setattr(device_data, field, device_field)
-                elif isinstance(getattr(data, field), list):
-                    # Case for completeness
-                    device_field = list(elem.to(device, non_blocking=True) for elem in getattr(data, field))
-                    setattr(device_data, field, device_field)
-                else:
-                    # Catch all assuming the class has a .to if not it will fail; and more cases are needed
-                    device_field = getattr(data, field).to(device, non_blocking=True)
-                    setattr(device_data, field, device_field)
-
+                device_field = move_data_to_device(getattr(data, field), device)
+                setattr(device_data, field, device_field)
             return device_data
 
         return data.to(device, non_blocking=True)
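Note: the simplification above leans on the fact that move_data_to_device already recurses through collections via apply_to_collection, so tensors, tuples (the include_lengths=True case), and lists are all handled by one recursive call per field. A rough sketch of the recursion it relies on, with made-up data:

    import torch
    from pytorch_lightning.utilities.apply_func import move_data_to_device

    device = torch.device('cpu')
    nested = {'text': (torch.zeros(7, 3), torch.tensor([7, 6, 5])),  # tuple-valued field
              'label': torch.zeros(3)}                               # plain tensor field
    moved = move_data_to_device(nested, device)
    assert moved['text'][0].device == device and moved['label'].device == device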
From c9fdf50a7c468c69522c1627a81b9472e679208d Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Sun, 26 Jul 2020 20:22:01 -0400
Subject: [PATCH 11/40] adding tests more specific to batch_move_data_to_device
 with torchtext Batch

---
 tests/utilities/test_apply_func_torchtext.py | 55 +++++++++++++++++---
 1 file changed, 47 insertions(+), 8 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index e13615de32e66..1448c2f972ce1 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -1,23 +1,30 @@
 import torch
 import torchtext
+from torchtext.data.example import Example
 
 import pytorch_lightning as pl
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 
 
 def _get_torchtext_data_iterator(include_lengths=False):
     text_field = torchtext.data.Field(sequential=True, pad_first=False,  # nosec
-                                      init_token="<sos>", eos_token="<eos>", include_lengths=include_lengths)  # nosec
+                                      init_token="<sos>", eos_token="<eos>",  # nosec
+                                      include_lengths=include_lengths)  # nosec
 
-    example1 = torchtext.data.example.Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
-    example2 = torchtext.data.example.Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
-    example3 = torchtext.data.example.Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
+    example1 = Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
+    example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
+    example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
 
-    dataset = torchtext.data.Dataset([example1, example2, example3], {"text": text_field})
+    dataset = torchtext.data.Dataset([example1, example2, example3],
+                                     {"text": text_field}
+                                     )
     text_field.build_vocab(dataset)
 
     iterator = torchtext.data.Iterator(dataset, batch_size=3,
-                                       sort_key=None, device=None, batch_size_fn=None,
-                                       train=True, repeat=False, shuffle=None, sort=None, sort_within_batch=None)
+                                       sort_key=None, device=None,
+                                       batch_size_fn=None,
+                                       train=True, repeat=False, shuffle=None,
+                                       sort=None, sort_within_batch=None)
     return iterator, text_field
@@ -50,7 +57,9 @@ def configure_optimizers(self):
 
         def forward(self, input_seq, length):
             embedded = self.embedding(input_seq)
-            packed_embedded = torch.nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=False,
+            packed_embedded = torch.nn.utils.rnn.pack_padded_sequence(embedded,
+                                                                      length,
+                                                                      batch_first=False,
                                                                       enforce_sorted=False)
             packed_outputs, hidden = self.rnn(packed_embedded)
             outputs, length = torch.nn.utils.rnn.pad_packed_sequence(packed_outputs)
@@ -168,3 +177,33 @@ def train_dataloader(self):
     result = trainer.fit(model)
     # verify training completed
     assert result == 1
+
+
+def test_torchtext_include_lengths_false_batch_move_data_to_device():
+    cuda_device_cnt = torch.cuda.device_count()
+    if cuda_device_cnt > 0:
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+
+    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=True)
+    data_iter = iter(data_iterator)
+    batch = next(data_iter)
+
+    # this call should not throw an error
+    batch_on_device = move_data_to_device(batch, device)
+
+
+def test_torchtext_include_lengths_true_batch_move_data_to_device():
+    cuda_device_cnt = torch.cuda.device_count()
+    if cuda_device_cnt > 0:
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+
+    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=False)
+    data_iter = iter(data_iterator)
+    batch = next(data_iter)
+
+    # this call should not throw an error
+    batch_on_device = move_data_to_device(batch, device)

From 5e568ea080daa234f7aa835ee7033d9b188be710 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Sun, 26 Jul 2020 23:10:38 -0400
Subject: [PATCH 12/40] added check that Tensors were moved to target device

---
 tests/utilities/test_apply_func_torchtext.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 1448c2f972ce1..552bdad2fe03e 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -28,7 +28,7 @@ def _get_torchtext_data_iterator(include_lengths=False):
     return iterator, text_field
 
 
-def test_move_data_to_device_torchtext_include_length_true():
+def test_torchtext_field_include_length_true():
     """Test if batches created by torchtext with include_lengths=True raise an exception."""
 
     class DebugModel(pl.LightningModule):
@@ -179,7 +179,7 @@ def train_dataloader(self):
     assert result == 1
 
 
-def test_torchtext_include_lengths_false_batch_move_data_to_device():
+def test_batch_move_data_to_device_torchtext_include_lengths_false():
     cuda_device_cnt = torch.cuda.device_count()
     if cuda_device_cnt > 0:
         device = torch.device('cuda')
@@ -192,9 +192,13 @@ def test_batch_move_data_to_device_torchtext_include_lengths_false():
 
     # this call should not throw an error
     batch_on_device = move_data_to_device(batch, device)
+    # tensor with data
+    assert(batch_on_device.text[0].device == device)
 
+    # tensor with length of data
+    assert(batch_on_device.text[1].device == device)
 
-def test_torchtext_include_lengths_true_batch_move_data_to_device():
+def test_batch_move_data_to_device_torchtext_include_lengths_true():
     cuda_device_cnt = torch.cuda.device_count()
     if cuda_device_cnt > 0:
         device = torch.device('cuda')
@@ -207,3 +211,5 @@ def test_batch_move_data_to_device_torchtext_include_lengths_true():
 
     # this call should not throw an error
     batch_on_device = move_data_to_device(batch, device)
+    # tensor with data
+    assert(batch_on_device.text[0].device == device)
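Note on the assertions above: torch.device equality is exact, so a tensor moved to an index-less torch.device('cuda') reports its device as cuda:0 and compares unequal to torch.device('cuda'). That is presumably why later revisions in this series parameterize with torch.device('cuda', 0) instead. A quick illustration:

    import torch

    assert torch.device('cuda') != torch.device('cuda', 0)  # index None vs. index 0
    if torch.cuda.is_available():
        t = torch.zeros(1).to(torch.device('cuda'))
        assert t.device == torch.device('cuda', 0)  # the resolved index
        assert t.device != torch.device('cuda')     # the comparison used above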
From a6b96b07ca0f9d98309d9d6e2ea00b467b8bfb79 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Mon, 27 Jul 2020 11:05:39 -0400
Subject: [PATCH 13/40] removed tests using RNN models to be moved into a
 separate PR

---
 tests/utilities/test_apply_func_torchtext.py | 158 +-------------------
 1 file changed, 3 insertions(+), 155 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 552bdad2fe03e..fd3fcaa70887d 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -2,7 +2,6 @@
 import torchtext
 from torchtext.data.example import Example
 
-import pytorch_lightning as pl
 from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -28,157 +27,6 @@ def _get_torchtext_data_iterator(include_lengths=False):
     return iterator, text_field
 
 
-def test_torchtext_field_include_length_true():
-    """Test if batches created by torchtext with include_lengths=True raise an exception."""
-
-    class DebugModel(pl.LightningModule):
-
-        def __init__(self):
-            super(DebugModel, self).__init__()
-
-            # setup data loader generating batches with fields consisting of tuples of tensors
-            self.debug_data_loader, self.text_field = _get_torchtext_data_iterator(include_lengths=True)
-
-            self.learning_rate = 0.001
-
-            pad_idx = self.text_field.vocab.stoi['<pad>']
-            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
-
-            self.INPUT_DIM = len(self.text_field.vocab)
-            self.ENC_EMB_DIM = 4  # keep it small for debugging
-            self.embedding = torch.nn.Embedding(self.INPUT_DIM, self.ENC_EMB_DIM)
-
-            self.hid_dim = 4
-            self.rnn = torch.nn.GRU(self.ENC_EMB_DIM, self.hid_dim, 1, bidirectional=False)
-            self.out = torch.nn.Linear(self.hid_dim, self.embedding.num_embeddings)
-
-        def configure_optimizers(self):
-            return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
-
-        def forward(self, input_seq, length):
-            embedded = self.embedding(input_seq)
-            packed_embedded = torch.nn.utils.rnn.pack_padded_sequence(embedded,
-                                                                      length,
-                                                                      batch_first=False,
-                                                                      enforce_sorted=False)
-            packed_outputs, hidden = self.rnn(packed_embedded)
-            outputs, length = torch.nn.utils.rnn.pad_packed_sequence(packed_outputs)
-
-            output = outputs.squeeze(0)
-            prediction = self.out(output)
-
-            return prediction
-
-        @staticmethod
-        def _parse_batch(batch):
-            source = batch.text[0]
-            source_length = batch.text[1]
-
-            return source, source_length
-
-        def training_step(self, batch, batch_nb):
-            """ Needed for testing data transfer.
""" - x = self._parse_batch(batch) - target, target_length = x - - output = self.forward(target, target_length) - loss = self.criterion(output[:-1].view(-1, output.shape[2]), target[1:].view(-1)) - - prefix = 'train' - tensorboard_logs = {f'{prefix}_loss': loss.item()} - - result = {'loss': loss, 'log': tensorboard_logs} - return result - - def train_dataloader(self): - return self.debug_data_loader - - model = DebugModel() - - cuda_device_cnt = torch.cuda.device_count() - if cuda_device_cnt > 0: - use_num_cuda_devices = 1 - else: - use_num_cuda_devices = None - - trainer = pl.Trainer(fast_dev_run=True, max_steps=None, - gradient_clip_val=10, - weights_summary=None, gpus=use_num_cuda_devices, - show_progress_bar=True) - - result = trainer.fit(model) - # verify training completed - assert result == 1 - - -def test_move_data_to_device_torchtext_include_length_false(): - """Test if batches created by torchtext with include_lengths=False raise an exception.""" - - class DebugModel(pl.LightningModule): - - def __init__(self): - super(DebugModel, self).__init__() - - # setup data loader generating batches with fields consisting of tensors - self.debug_data_loader, self.text_field = _get_torchtext_data_iterator(include_lengths=False) - - self.learning_rate = 0.001 - - pad_idx = self.text_field.vocab.stoi[''] - self.criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx) - - self.INPUT_DIM = len(self.text_field.vocab) - self.ENC_EMB_DIM = 4 # keep it small for debugging - self.embedding = torch.nn.Embedding(self.INPUT_DIM, self.ENC_EMB_DIM) - - self.hid_dim = 4 - self.rnn = torch.nn.GRU(self.ENC_EMB_DIM, self.hid_dim, 1, bidirectional=False) - self.out = torch.nn.Linear(self.hid_dim, self.embedding.num_embeddings) - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.learning_rate) - - def forward(self, input_seq): - embedded = self.embedding(input_seq) - outputs, hidden = self.rnn(embedded) - output = outputs.squeeze(0) - prediction = self.out(output) - return prediction - - def training_step(self, batch, batch_nb): - """ Needed for testing data transfer. 
""" - - target = batch.text - output = self.forward(target) - loss = self.criterion(output[:-1].view(-1, output.shape[2]), target[1:].view(-1)) - - prefix = 'train' - tensorboard_logs = {f'{prefix}_loss': loss.item()} - - result = {'loss': loss, 'log': tensorboard_logs} - return result - - def train_dataloader(self): - return self.debug_data_loader - - model = DebugModel() - - cuda_device_cnt = torch.cuda.device_count() - if cuda_device_cnt > 0: - use_num_cuda_devices = 1 - else: - use_num_cuda_devices = None - - trainer = pl.Trainer(fast_dev_run=True, max_steps=None, - gradient_clip_val=10, - weights_summary=None, gpus=use_num_cuda_devices, - show_progress_bar=True) - - result = trainer.fit(model) - # verify training completed - assert result == 1 - - def test_batch_move_data_to_device_torchtext_include_lengths_false(): cuda_device_cnt = torch.cuda.device_count() if cuda_device_cnt > 0: @@ -193,9 +41,9 @@ def test_batch_move_data_to_device_torchtext_include_lengths_false(): # this call should not throw an error batch_on_device = move_data_to_device(batch, device) # tensor with data - assert(batch_on_device.text[0].device == device) + assert (batch_on_device.text[0].device == device) # tensor with length of data - assert(batch_on_device.text[1].device == device) + assert (batch_on_device.text[1].device == device) def test_batch_move_data_to_device_torchtext_include_lengths_true(): @@ -212,4 +60,4 @@ def test_batch_move_data_to_device_torchtext_include_lengths_true(): # this call should not throw an error batch_on_device = move_data_to_device(batch, device) # tensor with data - assert(batch_on_device.text[0].device == device) + assert (batch_on_device.text[0].device == device) From 398ab549e28fa8c3d30355f71d108247b19655b5 Mon Sep 17 00:00:00 2001 From: Thomas Schaaf Date: Mon, 27 Jul 2020 17:16:36 -0400 Subject: [PATCH 14/40] fixing FLAKE8 errors that showed up after merge from master branch modified: tests/base/datamodules.py modified: tests/callbacks/test_model_checkpoint.py --- tests/base/datamodules.py | 2 +- tests/callbacks/test_model_checkpoint.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index 23c07f93d4697..b1ac0cfe4d7cd 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -13,7 +13,7 @@ def __init__(self, data_dir: str = './'): def prepare_data(self): TrialMNIST(self.data_dir, train=True, download=True) TrialMNIST(self.data_dir, train=False, download=True) - + def setup(self): mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True) self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64]) diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 4cb52a54610e3..f3e4f113784dc 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -28,7 +28,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): max_epochs=2, ) trainer.fit(model) - assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints' + assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints' @pytest.mark.parametrize( From a99fc7d202b98b2f96544b69dc229b924fca9b30 Mon Sep 17 00:00:00 2001 From: Thomas Schaaf Date: Mon, 27 Jul 2020 20:34:33 -0400 Subject: [PATCH 15/40] parameterized test to reduce code duplication --- tests/utilities/test_apply_func_torchtext.py | 37 +++++++------------- 1 file changed, 12 
 1 file changed, 12 insertions(+), 25 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index fd3fcaa70887d..7df591c13e0f8 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 import torchtext
 from torchtext.data.example import Example
@@ -27,37 +28,23 @@ def _get_torchtext_data_iterator(include_lengths=False):
     return iterator, text_field
 
 
-def test_batch_move_data_to_device_torchtext_include_lengths_false():
-    cuda_device_cnt = torch.cuda.device_count()
-    if cuda_device_cnt > 0:
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
+# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+# @pytest.mark.parametrize('device', [torch.device('cpu'), torch.device('cuda', 0)])
+@pytest.mark.parametrize('device', [torch.device('cpu')] if not torch.cuda.is_available() else [torch.device('cpu'), torch.device('cuda', 0)])
+@pytest.mark.parametrize('include_lengths', [False, True])
+def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
+    # cuda_device_cnt = torch.cuda.device_count()
+    # if cuda_device_cnt > 0:
+    #     device = torch.device('cuda')
+    # else:
+    #     device = torch.device('cpu')
 
-    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=True)
+    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
     data_iter = iter(data_iterator)
     batch = next(data_iter)
-
     # this call should not throw an error
     batch_on_device = move_data_to_device(batch, device)
     # tensor with data
     assert (batch_on_device.text[0].device == device)
     # tensor with length of data
     assert (batch_on_device.text[1].device == device)
-
-
-def test_batch_move_data_to_device_torchtext_include_lengths_true():
-    cuda_device_cnt = torch.cuda.device_count()
-    if cuda_device_cnt > 0:
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-
-    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=False)
-    data_iter = iter(data_iterator)
-    batch = next(data_iter)
-
-    # this call should not throw an error
-    batch_on_device = move_data_to_device(batch, device)
-    # tensor with data
-    assert (batch_on_device.text[0].device == device)
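Note: the device list in the parametrize above is built once at import time, so on a CUDA-less machine the GPU case simply never appears as a test item rather than being reported as skipped. A common alternative, sketched here with pytest's per-parameter skipif marks — this is not what the revision above does, just a contrast:

    import pytest
    import torch

    @pytest.mark.parametrize('device', [
        torch.device('cpu'),
        pytest.param(torch.device('cuda', 0),
                     marks=pytest.mark.skipif(not torch.cuda.is_available(),
                                              reason='requires CUDA')),
    ])
    def test_device_is_usable(device):
        # the CUDA case is collected but reported as skipped without a GPU
        assert torch.zeros(1, device=device).device.type == device.type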
From 61e692f08b360e9066899becbf8bca47f3ab67b1 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Mon, 27 Jul 2020 20:50:32 -0400
Subject: [PATCH 16/40] Added check only if length tensor exists. Removed
 left-over comments.

---
 tests/utilities/test_apply_func_torchtext.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 7df591c13e0f8..5101911531d5f 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -28,23 +28,19 @@ def _get_torchtext_data_iterator(include_lengths=False):
     return iterator, text_field
 
 
-# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-# @pytest.mark.parametrize('device', [torch.device('cpu'), torch.device('cuda', 0)])
-@pytest.mark.parametrize('device', [torch.device('cpu')] if not torch.cuda.is_available() else [torch.device('cpu'), torch.device('cuda', 0)])
+@pytest.mark.parametrize('device', [torch.device('cpu')] if not torch.cuda.is_available() else [torch.device('cpu'),
+                                                                                                 torch.device('cuda',
+                                                                                                              0)])
 @pytest.mark.parametrize('include_lengths', [False, True])
 def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
-    # cuda_device_cnt = torch.cuda.device_count()
-    # if cuda_device_cnt > 0:
-    #     device = torch.device('cuda')
-    # else:
-    #     device = torch.device('cpu')
-
     data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
     data_iter = iter(data_iterator)
     batch = next(data_iter)
-    # this call should not throw an error
     batch_on_device = move_data_to_device(batch, device)
+
     # tensor with data
     assert (batch_on_device.text[0].device == device)
-    # tensor with length of data
-    assert (batch_on_device.text[1].device == device)
+
+    if include_lengths:
+        # tensor with length of data
+        assert (batch_on_device.text[1].device == device)

From 0c25f43a6f0a80c48641c9bccef5639f31480e92 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Mon, 27 Jul 2020 21:41:18 -0400
Subject: [PATCH 17/40] rearranged device parameterization and added
 pytest.param

---
 tests/utilities/test_apply_func_torchtext.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 5101911531d5f..2a427c5b9ac54 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -28,9 +28,10 @@ def _get_torchtext_data_iterator(include_lengths=False):
     return iterator, text_field
 
 
-@pytest.mark.parametrize('device', [torch.device('cpu')] if not torch.cuda.is_available() else [torch.device('cpu'),
-                                                                                                 torch.device('cuda',
-                                                                                                              0)])
+@pytest.mark.parametrize(['device'],
+                         [pytest.param(torch.device('cuda', 0)),
+                          pytest.param(torch.device('cpu'))] if torch.cuda.is_available() else [
+                             pytest.param(torch.device('cpu'))])
 @pytest.mark.parametrize('include_lengths', [False, True])
 def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
     data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
@@ -38,9 +39,10 @@ def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, de
     batch = next(data_iter)
     batch_on_device = move_data_to_device(batch, device)
 
-    # tensor with data
-    assert (batch_on_device.text[0].device == device)
     if include_lengths:
+        # tensor with data
+        assert (batch_on_device.text[0].device == device)
         # tensor with length of data
         assert (batch_on_device.text[1].device == device)
+    else:
+        assert (batch_on_device.text.device == device)
From f08dd78643d57773eff785377dc462d6e49d0c7d Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Mon, 27 Jul 2020 22:16:55 -0400
Subject: [PATCH 18/40] Try to figure out why only one device is tested on
 Linux machines

---
 tests/utilities/test_apply_func_torchtext.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 2a427c5b9ac54..aae1eeb7a71d8 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -28,11 +28,13 @@ def _get_torchtext_data_iterator(include_lengths=False):
     return iterator, text_field
 
 
-@pytest.mark.parametrize(['device'],
-                         [pytest.param(torch.device('cuda', 0)),
-                          pytest.param(torch.device('cpu'))] if torch.cuda.is_available() else [
-                             pytest.param(torch.device('cpu'))])
 @pytest.mark.parametrize('include_lengths', [False, True])
+# @pytest.mark.parametrize(['device'],
+#                          [pytest.param(torch.device('cuda', 0)),
+#                           pytest.param(torch.device('cpu'))] if not torch.cuda.is_available() else [
+#                              pytest.param(torch.device('cpu'))])
+@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0)), pytest.param(torch.device('cpu'))])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
     data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
     data_iter = iter(data_iterator)

From d2c4598dc6035ccb26c6bd88fb4d02272809473f Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Mon, 27 Jul 2020 23:17:17 -0400
Subject: [PATCH 19/40] Testing on CPU and GPU devices (GPU test is skipped if
 no CUDA device is available)

---
 tests/utilities/test_apply_func_torchtext.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index aae1eeb7a71d8..f17184fbe2df1 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -29,12 +29,8 @@ def _get_torchtext_data_iterator(include_lengths=False):
 
 @pytest.mark.parametrize('include_lengths', [False, True])
-# @pytest.mark.parametrize(['device'],
-#                          [pytest.param(torch.device('cuda', 0)),
-#                           pytest.param(torch.device('cpu'))] if not torch.cuda.is_available() else [
-#                              pytest.param(torch.device('cpu'))])
-@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0)), pytest.param(torch.device('cpu'))])
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0))])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test assumes GPU machine")
 def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
     data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
     data_iter = iter(data_iterator)
@@ -48,3 +44,8 @@ def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, de
         assert (batch_on_device.text[1].device == device)
     else:
         assert (batch_on_device.text.device == device)
+
+
+@pytest.mark.parametrize('include_lengths', [False, True])
+def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
+    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu'))
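Note: after this revision the parameterized test is GPU-only and the CPU path is exercised by calling it directly from a plain wrapper test. That works because pytest marks only affect collection, not ordinary Python calls, so the skipif on the marked function is bypassed by a direct call. A stripped-down sketch of the pattern, with a hypothetical helper in place of the real test body:

    import pytest
    import torch

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='GPU only')
    def test_on_gpu():
        run_check(torch.device('cuda', 0))

    def test_on_cpu():
        run_check(torch.device('cpu'))  # plain call, no mark involved

    def run_check(device):
        assert torch.zeros(1, device=device).device.type == device.type

The patch above goes one step further and calls the marked test function itself with an explicit device argument, which is equivalent since the mark does not wrap the function object.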
From d04c2881644856031043178e7e82873817fbcdcb Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Tue, 28 Jul 2020 10:08:02 -0400
Subject: [PATCH 20/40] added test for TPU device (experimental)

---
 tests/utilities/test_apply_func_torchtext.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index f17184fbe2df1..39b7af9d56e6b 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -5,6 +5,20 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 
+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+    SERIAL_EXEC = xmp.MpSerialExecutor()
+    # TODO: The tests are aborted if the following lines are uncommented. Must be resolved with XLA team
+    # device = torch_xla.core.xla_model.xla_device()
+    # device_type = torch_xla.core.xla_model.xla_device_hw(device)
+    # TPU_AVAILABLE = device_type == 'TPU'
+except ImportError:
+    TPU_AVAILABLE = False
+else:
+    TPU_AVAILABLE = True
 
 def _get_torchtext_data_iterator(include_lengths=False):
@@ -49,3 +63,8 @@ def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
     test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu'))
+
+
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+def test_batch_move_data_to_device_torchtext_include_lengths_tpu(include_lengths):
+    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, xm.xla_device())

From d04c2881644856031043178e7e82873817fbcdcb Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Tue, 28 Jul 2020 10:09:43 -0400
Subject: [PATCH 21/40] Adding test parameterization for TPU test
 (experimental)

---
 tests/utilities/test_apply_func_torchtext.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 39b7af9d56e6b..6673185c5b363 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -65,6 +65,7 @@ def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
 
+@pytest.mark.parametrize('include_lengths', [False, True])
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 def test_batch_move_data_to_device_torchtext_include_lengths_tpu(include_lengths):
     test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, xm.xla_device())

From cca6ff39b28364c370785c5e1e38f26d6727f513 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Tue, 28 Jul 2020 10:55:21 -0400
Subject: [PATCH 22/40] change import statement to limit what is imported for
 a TPU environment

---
 tests/utilities/test_apply_func_torchtext.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 6673185c5b363..1721093075343 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -8,12 +8,12 @@
 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-    SERIAL_EXEC = xmp.MpSerialExecutor()
-    # TODO: The tests are aborted if the following lines are uncommented. Must be resolved with XLA team
-    # device = torch_xla.core.xla_model.xla_device()
-    # device_type = torch_xla.core.xla_model.xla_device_hw(device)
-    # TPU_AVAILABLE = device_type == 'TPU'
+    # import torch_xla.distributed.xla_multiprocessing as xmp
+    # SERIAL_EXEC = xmp.MpSerialExecutor()
+    # # TODO: The tests are aborted if the following lines are uncommented. Must be resolved with XLA team
+    # # device = torch_xla.core.xla_model.xla_device()
+    # # device_type = torch_xla.core.xla_model.xla_device_hw(device)
+    # # TPU_AVAILABLE = device_type == 'TPU'
 except ImportError:
     TPU_AVAILABLE = False
 else:

From 5f3680de451b9eabb0e6a7788ded1dc01f358a6c Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Tue, 28 Jul 2020 10:58:39 -0400
Subject: [PATCH 23/40] made test work with TPU

---
 tests/utilities/test_apply_func_torchtext.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 1721093075343..3a3f4c13f7a9a 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -8,12 +8,6 @@
 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
-    # import torch_xla.distributed.xla_multiprocessing as xmp
-    # SERIAL_EXEC = xmp.MpSerialExecutor()
-    # # TODO: The tests are aborted if the following lines are uncommented. Must be resolved with XLA team
-    # # device = torch_xla.core.xla_model.xla_device()
-    # # device_type = torch_xla.core.xla_model.xla_device_hw(device)
-    # # TPU_AVAILABLE = device_type == 'TPU'
 except ImportError:
     TPU_AVAILABLE = False
 else:
     TPU_AVAILABLE = True

From 08ebb6dc179ece8520d75b179625664b60340001 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Tue, 28 Jul 2020 13:47:47 -0400
Subject: [PATCH 24/40] Change to trigger CI

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc2ef9ffa41a8..ef972d7fe35c2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
 
-- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to GPU/TPU device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09

From fa6b2f914ce2de0dad71fe038e057ee74ee40d8b Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Tue, 28 Jul 2020 14:00:34 -0400
Subject: [PATCH 25/40] Change to trigger CI

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef972d7fe35c2..bc2ef9ffa41a8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
 
-- Fixed data transfer to GPU/TPU device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09

From 940c34d48c0457080dc5ec5cf6834b3b01334c4a Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Tue, 28 Jul 2020 14:23:23 -0400
Subject: [PATCH 26/40] uncommented TPU test to check CI

---
 tests/utilities/test_apply_func_torchtext.py | 24 ++++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 3a3f4c13f7a9a..b3cd3defa5119 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -5,13 +5,14 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 
-try:
-    import torch_xla
-    import torch_xla.core.xla_model as xm
-except ImportError:
-    TPU_AVAILABLE = False
-else:
-    TPU_AVAILABLE = True
+
+# try:
+#     import torch_xla
+#     import torch_xla.core.xla_model as xm
+# except ImportError:
+#     TPU_AVAILABLE = False
+# else:
+#     TPU_AVAILABLE = True
 
 
 def _get_torchtext_data_iterator(include_lengths=False):
@@ -58,8 +59,7 @@ def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, de
 def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
     test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu'))
 
-
-@pytest.mark.parametrize('include_lengths', [False, True])
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-def test_batch_move_data_to_device_torchtext_include_lengths_tpu(include_lengths):
-    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, xm.xla_device())
+# @pytest.mark.parametrize('include_lengths', [False, True])
+# @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+# def test_batch_move_data_to_device_torchtext_include_lengths_tpu(include_lengths):
+#     test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, xm.xla_device())

From 584328a7c215bd6cd86fd20fe260afa6da09cbdd Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Tue, 28 Jul 2020 20:38:27 -0400
Subject: [PATCH 27/40] reenabling TPU test

---
 tests/utilities/test_apply_func_torchtext.py | 24 ++++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index b3cd3defa5119..3a3f4c13f7a9a 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -5,14 +5,13 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 
-
-# try:
-#     import torch_xla
-#     import torch_xla.core.xla_model as xm
-# except ImportError:
-#     TPU_AVAILABLE = False
-# else:
-#     TPU_AVAILABLE = True
+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    TPU_AVAILABLE = False
+else:
+    TPU_AVAILABLE = True
 
 
 def _get_torchtext_data_iterator(include_lengths=False):
@@ -58,7 +58,8 @@ def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, de
 def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
     test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu'))
 
-# @pytest.mark.parametrize('include_lengths', [False, True])
-# @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-# def test_batch_move_data_to_device_torchtext_include_lengths_tpu(include_lengths):
-#     test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, xm.xla_device())
+
+@pytest.mark.parametrize('include_lengths', [False, True])
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+def test_batch_move_data_to_device_torchtext_include_lengths_tpu(include_lengths):
+    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, xm.xla_device())

From ae71b148b9adbacdb2ab496918ca9673f55649e8 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 08:44:22 -0400
Subject: [PATCH 28/40] small change to trigger CI build

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc2ef9ffa41a8..ef972d7fe35c2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
 
-- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to GPU/TPU device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09

From 34201bcd5872e73b61c63405c86d6a9c6f31a6f9 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 10:22:29 -0400
Subject: [PATCH 29/40] small change to trigger CI build

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef972d7fe35c2..f563a14626980 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
 
-- Fixed data transfer to GPU/TPU device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to GPU/TPU device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09

From a53a469594f398663c73b5ee9a14399e43b39156 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 10:58:30 -0400
Subject: [PATCH 30/40] small change to trigger CI build

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f563a14626980..ef972d7fe35c2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
 
-- Fixed data transfer to GPU/TPU device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to GPU/TPU device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09

From 647e44be7d54b082642590c3607db0a3c1337d8d Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 11:23:02 -0400
Subject: [PATCH 31/40] adding tests/utilities/test_apply_func_torchtext.py to
 CI TPU test

---
 dockers/tpu-tests/tpu_test_cases.jsonnet | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet
index 9294727f420b6..6e30c4d8634df 100644
--- a/dockers/tpu-tests/tpu_test_cases.jsonnet
+++ b/dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -21,7 +21,7 @@ local tputests = base.BaseTest {
   command: utils.scriptCommand(
     |||
       cd pytorch-lightning
-      coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py -v
+      coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py tests/utilities/test_apply_func_torchtext.py -v
       test_exit_code=$?
       echo "\n||| END PYTEST LOGS |||\n"
       coverage xml

From ff080dac70cf1ce4a8dc6095527948d6812d1f14 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 11:52:17 -0400
Subject: [PATCH 32/40] try to make test not skipped on CI with TPU

---
 tests/utilities/test_apply_func_torchtext.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 3a3f4c13f7a9a..3a5d76d37eb73 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -7,7 +7,6 @@
 try:
     import torch_xla
-    import torch_xla.core.xla_model as xm
 except ImportError:
     TPU_AVAILABLE = False
 else:
@@ -62,4 +61,4 @@ def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
 @pytest.mark.parametrize('include_lengths', [False, True])
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 def test_batch_move_data_to_device_torchtext_include_lengths_tpu(include_lengths):
-    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, xm.xla_device())
+    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch_xla._XLAC._xla_get_default_device())
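Note: torch_xla._XLAC._xla_get_default_device() used above is a private binding; earlier revisions in this series went through the public accessor. Roughly, with a guard for environments without torch_xla:

    try:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()  # public accessor for the default XLA device
    except ImportError:
        device = None             # no torch_xla in this environment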
 - Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
 
-- Fixed data transfer to GPU/TPU device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09

diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet
index 6e30c4d8634df..9294727f420b6 100644
--- a/dockers/tpu-tests/tpu_test_cases.jsonnet
+++ b/dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -21,7 +21,7 @@ local tputests = base.BaseTest {
   command: utils.scriptCommand(
     |||
       cd pytorch-lightning
-      coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py tests/utilities/test_apply_func_torchtext.py -v
+      coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py -v
       test_exit_code=$?
       echo "\n||| END PYTEST LOGS |||\n"
       coverage xml

diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 89d5dce279840..94d497db782d9 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -219,23 +219,23 @@ def test_early_stop_checkpoints_on_tpu(tmpdir):
     assert torch_xla._XLAC._xla_get_default_device() == 'xla:1'
 
 
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-@pl_multi_process_test
-def test_early_stop_checkpoints_on_tpu(tmpdir):
-    """Test if single TPU core training works"""
-    model = EvalModelTemplate()
-    trainer = Trainer(
-        early_stop_callback=True,
-        default_root_dir=tmpdir,
-        progress_bar_refresh_rate=0,
-        max_epochs=50,
-        limit_train_batches=10,
-        limit_val_batches=10,
-        distributed_backend='tpu',
-        tpu_cores=[5],
-    )
-    trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
+# @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+# @pl_multi_process_test
+# def test_early_stop_checkpoints_on_tpu(tmpdir):
+#     """Test if single TPU core training works"""
+#     model = EvalModelTemplate()
+#     trainer = Trainer(
+#         early_stop_callback=True,
+#         default_root_dir=tmpdir,
+#         progress_bar_refresh_rate=0,
+#         max_epochs=50,
+#         limit_train_batches=10,
+#         limit_val_batches=10,
+#         distributed_backend='tpu',
+#         tpu_cores=[5],
+#     )
+#     trainer.fit(model)
+#     assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
 
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index 3a5d76d37eb73..f17184fbe2df1 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -5,13 +5,6 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 
-try:
-    import torch_xla
-except ImportError:
-    TPU_AVAILABLE = False
-else:
-    TPU_AVAILABLE = True
-
 
 def _get_torchtext_data_iterator(include_lengths=False):
     text_field = torchtext.data.Field(sequential=True, pad_first=False,  # nosec
@@ -56,9 +49,3 @@ def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, de
 @pytest.mark.parametrize('include_lengths', [False, True])
 def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
     test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu'))
-
-
-@pytest.mark.parametrize('include_lengths', [False, True])
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-def test_batch_move_data_to_device_torchtext_include_lengths_tpu(include_lengths):
-    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch_xla._XLAC._xla_get_default_device())

From 73583c1386f4c2eea9c973d61c4194c9aaf55983 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 13:29:20 -0400
Subject: [PATCH 34/40] undo an accidental change to test_tpu.py (file should not have been touched)

---
 tests/models/test_tpu.py | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 94d497db782d9..89d5dce279840 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -219,23 +219,23 @@ def test_early_stop_checkpoints_on_tpu(tmpdir):
     assert torch_xla._XLAC._xla_get_default_device() == 'xla:1'
 
 
-# @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-# @pl_multi_process_test
-# def test_early_stop_checkpoints_on_tpu(tmpdir):
-#     """Test if single TPU core training works"""
-#     model = EvalModelTemplate()
-#     trainer = Trainer(
-#         early_stop_callback=True,
-#         default_root_dir=tmpdir,
-#         progress_bar_refresh_rate=0,
-#         max_epochs=50,
-#         limit_train_batches=10,
-#         limit_val_batches=10,
-#         distributed_backend='tpu',
-#         tpu_cores=[5],
-#     )
-#     trainer.fit(model)
-#     assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_early_stop_checkpoints_on_tpu(tmpdir):
+    """Test if single TPU core training works"""
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        early_stop_callback=True,
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=50,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        distributed_backend='tpu',
+        tpu_cores=[5],
+    )
+    trainer.fit(model)
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
 
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")

From b9297118c0bf9883b123394db63b6c215f2f8c2e Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 14:08:00 -0400
Subject: [PATCH 35/40] small change to trigger CI build

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc2ef9ffa41a8..e5dc0ea40688b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
 
-- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09

From 68e2152bf03949440fa2559c163c969281b5cfe7 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 14:23:55 -0400
Subject: [PATCH 36/40] small change to trigger CI build

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e5dc0ea40688b..bc2ef9ffa41a8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
 
-- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09

From c97cd69d5ac4f81949f12c427cee75fad6b627fc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 29 Jul 2020 21:31:02 +0200
Subject: [PATCH 37/40] Update tests/utilities/test_apply_func_torchtext.py

---
 tests/utilities/test_apply_func_torchtext.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index f17184fbe2df1..f7d074f0b299f 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -29,7 +29,7 @@ def _get_torchtext_data_iterator(include_lengths=False):
 
 
 @pytest.mark.parametrize('include_lengths', [False, True])
-@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0))])
+@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cpu')), pytest.param(torch.device('cuda', 0))])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test assumes GPU machine")
 def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
     data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)

From 16850770e0fd29775bff14e2e6cda473e7040a4c Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 15:54:50 -0400
Subject: [PATCH 38/40] Revert to previous version

---
 tests/utilities/test_apply_func_torchtext.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index f7d074f0b299f..f17184fbe2df1 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -29,7 +29,7 @@ def _get_torchtext_data_iterator(include_lengths=False):
 
 
 @pytest.mark.parametrize('include_lengths', [False, True])
-@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cpu')), pytest.param(torch.device('cuda', 0))])
+@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0))])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test assumes GPU machine")
 def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
     data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)

From 8a7d68bb97644871ee8266886f8a2e84bf6d9388 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Wed, 29 Jul 2020 23:48:50 +0200
Subject: [PATCH 39/40] Apply suggestions from code review

---
 tests/utilities/test_apply_func_torchtext.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index f17184fbe2df1..9ea29420788d7 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -15,9 +15,10 @@ def _get_torchtext_data_iterator(include_lengths=False):
     example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
     example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
 
-    dataset = torchtext.data.Dataset([example1, example2, example3],
-                                     {"text": text_field}
-                                     )
+    dataset = torchtext.data.Dataset(
+        [example1, example2, example3],
+        {"text": text_field},
+    )
     text_field.build_vocab(dataset)
 
     iterator = torchtext.data.Iterator(dataset, batch_size=3,

From 3c04090d95b29a0d85a12312846b6850baefd1c0 Mon Sep 17 00:00:00 2001
From: Thomas Schaaf
Date: Wed, 29 Jul 2020 21:53:15 -0400
Subject: [PATCH 40/40] Change to trigger CI

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 85a902c62bbcd..2f519c7de9b38 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed test metrics not being logged with `LoggerCollection` ([#2723](https://github.com/PyTorchLightning/pytorch-lightning/pull/2723))
 
-- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 ## [0.8.5] - 2020-07-09