diff --git a/docs/source/guides/data.rst b/docs/source/guides/data.rst index dbfba8598a36f..1eda7ac6291ae 100644 --- a/docs/source/guides/data.rst +++ b/docs/source/guides/data.rst @@ -393,6 +393,9 @@ option when using sequential data. to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception. Here ``mode`` can be train/val/test/predict. +When iterable datasets are used, Lightning will pre-fetch 1 batch (in addition to the current batch) so it can detect +when the training will stop and run validation if necessary. + .. testcode:: # IterableDataset diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index cb0e79ae89448..8f35b39d60fdd 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -81,6 +81,13 @@ def dataloaders(self) -> Sequence[DataLoader]: raise RuntimeError("Dataloaders should be available.") return dataloaders + @property + def prefetch_batches(self) -> int: + batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches + is_unsized = batches[self.current_dataloader_idx] == float("inf") + inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1" + return 1 if is_unsized or inter_batch_parallelism else 0 + def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override] """Connect the evaluation epoch loop with this loop.""" self.epoch_loop = epoch_loop @@ -121,7 +128,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: void(*args, **kwargs) data_fetcher_cls = _select_data_fetcher_type(self.trainer) - self._data_fetcher = data_fetcher_cls() + self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches) # hook self._on_evaluation_model_eval() diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 8174b34a3cfa5..10388615fcc94 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -85,6 +85,8 @@ def on_run_start( # type: ignore[override] self._reload_dataloader_state_dict(data_fetcher) # creates the iterator inside the fetcher but returns `self` self._data_fetcher = cast(AbstractDataFetcher, iter(data_fetcher)) + # add the previous `fetched` value to properly track `is_last_batch` with no prefetching + data_fetcher.fetched += self.batch_progress.current.ready def advance( # type: ignore[override] self, diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index bcb818bf72bff..15a561e99569d 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -142,7 +142,9 @@ def reset(self) -> None: def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override] self._reload_dataloader_state_dict(data_fetcher) - iter(data_fetcher) # creates the iterator inside the fetcher + _ = iter(data_fetcher) # creates the iterator inside the fetcher + # add the previous `fetched` value to properly track `is_last_batch` with no prefetching + data_fetcher.fetched += self.batch_progress.current.ready def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override] """Runs a single training batch. diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a942f3bf75a99..0c9c68b24f4a0 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -149,6 +149,12 @@ def restarting(self, restarting: bool) -> None: restarting &= finished_before_on_train_end Loop.restarting.fset(self, restarting) # call the parent setter + @property + def prefetch_batches(self) -> int: + is_unsized = self.trainer.num_training_batches == float("inf") + inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1" + return 1 if is_unsized or inter_batch_parallelism else 0 + @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" @@ -213,8 +219,9 @@ def on_run_start(self) -> None: # type: ignore[override] """Calls the ``on_train_start`` hook.""" # reset train dataloader and val dataloader self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module) + data_fetcher_cls = _select_data_fetcher(self.trainer) - self._data_fetcher = data_fetcher_cls() + self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches) self._is_fresh_start_epoch = True self._results.to(device=self.trainer.lightning_module.device) diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 10e6a52135cfe..3f59a8f017cc7 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -89,17 +89,13 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool: def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or - infinite dataloader. - - Raises: - ValueError: - If the length of Dataloader is 0, as it requires at least one batch - """ - + infinite dataloader.""" try: # try getting the length if len(dataloader) == 0: - raise ValueError("`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch") + rank_zero_warn( + f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention." + ) has_len = True except TypeError: has_len = False @@ -122,30 +118,27 @@ def has_len_all_ranks( model: Union["pl.LightningModule", "pl.LightningDataModule"], ) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or - infinite dataloader. - - Raises: - ValueError: - If the length of Dataloader is 0, as it requires at least one batch - """ + infinite dataloader.""" try: - total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum") local_length = len(dataloader) + total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum") if total_length == 0: - raise MisconfigurationException( - "Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch." + rank_zero_warn( + f"Total length of `{dataloader.__class__.__name__}` across ranks is zero." + " Please make sure this was your intention." ) if total_length > 0 and local_length == 0: if model.allow_zero_length_dataloader_with_multiple_devices: rank_zero_warn( - "Total length of `Dataloader` across ranks is zero, but local rank has zero length." - " Please be cautious of uneven batch length." + f"Total length of `{dataloader.__class__.__name__}` across ranks is zero, but local rank has zero" + " length. Please be cautious of uneven batch length." ) has_len = False else: raise MisconfigurationException( - "`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch." + f"`{dataloader.__class__.__name__}` within local rank has zero length." + " Please make sure that it returns at least 1 batch." ) else: has_len = True diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 2d4aa533b2895..ea039fcb23e19 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator from copy import deepcopy -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Sized, Tuple import torch from torch.utils.data.dataloader import DataLoader @@ -30,6 +30,7 @@ MergedIteratorState, patch_dataloader_iterator, ) +from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -79,6 +80,8 @@ def __init__(self, prefetch_batches: int = 0) -> None: def setup(self, dataloader: Iterable, **kwargs: Any) -> None: self._add_capture_metadata_collate(dataloader) self._dataloader = dataloader + _patch_dataloader_get_iterators() + self._attach_data_fetcher() @property def dataloader(self) -> Iterable: @@ -172,8 +175,6 @@ def _attach_data_fetcher_fn(loader: DataLoader) -> None: def __iter__(self) -> "AbstractDataFetcher": self.reset() - self._attach_data_fetcher() - _patch_dataloader_get_iterators() self.dataloader_iter = iter(self.dataloader) self._apply_patch() self.prefetching() @@ -205,7 +206,7 @@ class DataFetcher(AbstractDataFetcher): Args: prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track - whether a batch is the last one (available with :attr:`self.done`). + whether a batch is the last one (available with :attr:`self.done`) under any training setup. store_on_device: Whether to store the pre-fetched batches on device. """ @@ -214,11 +215,13 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> N self.store_on_device = store_on_device self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device self.batches: List[Any] = [] + self._has_len = False def setup( # type: ignore[override] self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None ) -> None: super().setup(dataloader) + self._has_len = has_len(dataloader) if batch_to_device is not None: self.batch_to_device = batch_to_device @@ -233,6 +236,9 @@ def prefetching(self) -> None: try: self._fetch_next_batch(iterator) except StopIteration: + # this would only happen when prefetch_batches > the number of batches available and makes + # `fetching_function` jump directly to the empty iterator case without trying to fetch again + self.done = True break def fetching_function(self) -> Any: @@ -266,6 +272,11 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: start_output = self.on_fetch_start() batch = next(iterator) self.fetched += 1 + if not self.prefetch_batches and self._has_len: + # when we don't prefetch but the dataloader is sized, we use the length for `done` + dataloader = self.dataloader + assert isinstance(dataloader, Sized) # `_has_len` is True + self.done = self.fetched >= len(dataloader) self.on_fetch_end(batch, start_output) def move_to_device(self, batch: Any) -> Any: @@ -360,7 +371,8 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: ... """ - def __init__(self) -> None: + def __init__(self, prefetch_batches: int = 0) -> None: + # prefetch batches is not used for this class super().__init__() self.store_on_device = False diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index d578cecdab01e..cfc347293484c 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -648,16 +648,12 @@ def train_dataloader(self): "ready": n_epochs, "started": n_epochs, "processed": n_epochs, - # TODO: the following "-1" offset will be fixed by - # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578 "completed": n_epochs - 1, }, "current": { "ready": n_epochs, "started": n_epochs, "processed": n_epochs, - # TODO: the following "-1" offset will be fixed by - # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578 "completed": n_epochs - 1, }, }, @@ -956,8 +952,6 @@ def val_dataloader(self): # totals are increased by 1 (the failed batch which never completed) expected = state_dict.copy() - # TODO: `is_last_batch` is not correct on reload, the next line should not be necessary - expected["epoch_loop.batch_progress"]["is_last_batch"] = val_check_interval == 1.0 assert state_dict_after_restart["epoch_loop.batch_progress"] == expected["epoch_loop.batch_progress"] val_dl_progress = "epoch_loop.val_loop.dataloader_progress" diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d8e4989bb34d6..08d54e05bfffe 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -516,20 +516,16 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): assert len(trainer.test_dataloaders) == 1 -def test_error_on_zero_len_dataloader(tmpdir): - """Test that error is raised if a zero-length dataloader is defined.""" - - class CustomBoringModel(BoringModel): - def train_dataloader(self): - return DataLoader(RandomDataset(32, 0)) - - model = CustomBoringModel() +def test_warning_on_zero_len_dataloader(tmpdir): + """Test that a warning is raised if a zero-length dataloader is defined.""" + model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, fast_dev_run=1, ) - with pytest.raises(ValueError, match="returned 0 length. .* at least 1 batch"): - trainer.fit(model) + dataloader = DataLoader(RandomDataset(32, 0)) + with pytest.warns(UserWarning, match="returned 0 length"): + trainer.fit(model, dataloader) @RunIf(skip_windows=True) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 630d384660527..0258fac2823db 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1452,7 +1452,7 @@ def load_state_dict(self, state_dict): class RandomFaultTolerantSampler(RandomSampler): - def __init__(self, *args, seed: int = 0, generator=None, **kwargs): + def __init__(self, *args, seed: int = 0, **kwargs): generator = torch.Generator().manual_seed(seed) super().__init__(*args, generator=generator, **kwargs) self.counter = 0 @@ -1558,7 +1558,7 @@ def configure_optimizers(self): seed_everything(42) model = TestModel(should_fail=True) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval) - with suppress(CustomException): + with pytest.raises(CustomException): trainer.fit(model) trainer.train_dataloader = None failed_batches = model.batches diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 547b151c838c7..f64ba026b0a01 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -93,7 +93,7 @@ def __iter__(self): def test_has_len(): assert has_len(DataLoader(RandomDataset(1, 1))) - with pytest.raises(ValueError, match="`Dataloader` returned 0 length."): + with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."): assert has_len(DataLoader(RandomDataset(0, 0))) assert not has_len(DataLoader(RandomIterableDataset(1, 1))) @@ -112,8 +112,8 @@ def test_has_len_all_rank(): trainer = Trainer(fast_dev_run=True) model = BoringModel() - with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."): - assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model) + with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero."): + assert has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model) assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy, model) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 4f76e2ab917fe..11ed3213c70b1 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -18,7 +18,6 @@ import pytest import torch -from torch import tensor from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning import Callback, LightningDataModule, Trainer @@ -30,57 +29,74 @@ from tests.helpers.runif import RunIf +class IterDataset(IterableDataset): + def __iter__(self): + yield 1 + yield 2 + yield 3 + + +class SizedDataset(Dataset): + def __len__(self): + return 3 + + def __getitem__(self, idx): + return idx + 1 + + @pytest.mark.parametrize("use_combined_loader", [False, True]) -def test_prefetch_iterator(use_combined_loader): - """Test the DataFetcher with PyTorch IterableDataset.""" - - class IterDataset(IterableDataset): - def __iter__(self): - yield 1 - yield 2 - yield 3 - - for prefetch_batches in range(5): - iterator = DataFetcher(prefetch_batches=prefetch_batches) - assert iterator.prefetch_batches == prefetch_batches - - if use_combined_loader: - loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())]) - else: - loader = DataLoader(IterDataset()) - iterator.setup(loader) - - def generate(): - generated = [ - (iterator.fetched, data, iterator.done) for i, data in enumerate(iterator, prefetch_batches + 1) - ] - assert iterator.fetched == 3 - assert iterator.done - return generated - - is_last_batch = [False, False, prefetch_batches > 0] - fetched = list(range(prefetch_batches + 1, 4)) - fetched += [3] * (3 - len(fetched)) - if use_combined_loader: - batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]] - else: - batches = [1, 2, 3] - expected = list(zip(fetched, batches, is_last_batch)) - assert len(expected) == 3 - - assert generate() == expected - # validate reset works properly. - assert generate() == expected - assert iterator.fetched == 3 - - class EmptyIterDataset(IterableDataset): - def __iter__(self): - return iter([]) - - loader = DataLoader(EmptyIterDataset()) - iterator = DataFetcher() - iterator.setup(loader) - assert not list(iterator) +@pytest.mark.parametrize("dataset_cls", [IterDataset, SizedDataset]) +@pytest.mark.parametrize("prefetch_batches", list(range(5))) +def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches): + fetcher = DataFetcher(prefetch_batches=prefetch_batches) + assert fetcher.prefetch_batches == prefetch_batches + + if use_combined_loader: + loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())]) + else: + loader = DataLoader(dataset_cls()) + fetcher.setup(loader) + + def generate(): + generated = [(fetcher.fetched, data, fetcher.done) for data in fetcher] + assert fetcher.fetched == 3 + assert fetcher.done + return generated + + # we can only know the last batch with sized iterables or when we prefetch + is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset] + fetched = list(range(prefetch_batches + 1, 4)) + fetched += [3] * (3 - len(fetched)) + batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3] + expected = list(zip(fetched, batches, is_last_batch)) + assert len(expected) == 3 + + assert generate() == expected + # validate reset works properly. + assert generate() == expected + assert fetcher.fetched == 3 + + +class EmptyIterDataset(IterableDataset): + def __iter__(self): + return iter([]) + + +class EmptySizedDataset(Dataset): + def __len__(self): + return 0 + + +@pytest.mark.parametrize("dataset_cls", [EmptyIterDataset, EmptySizedDataset]) +@pytest.mark.parametrize("prefetch_batches", list(range(2))) +def test_empty_prefetch_iterator(dataset_cls, prefetch_batches): + loader = DataLoader(dataset_cls()) + fetcher = DataFetcher(prefetch_batches=prefetch_batches) + fetcher.setup(loader) + + assert not fetcher.done + assert not list(fetcher) + assert fetcher.done def test_misconfiguration_error(): @@ -188,7 +204,7 @@ def __init__(self, check_inter_batch): def on_train_epoch_end(self, trainer, lightning_module): fetcher = trainer.fit_loop._data_fetcher assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher) - assert fetcher.prefetch_batches == 1 + assert fetcher.prefetch_batches == int(self._check_inter_batch) trainer_kwargs = dict( default_root_dir=tmpdir, @@ -269,14 +285,19 @@ def training_epoch_end(self, *_): @RunIf(min_torch="1.8.0") def test_fetching_dataloader_iter_running_stages(fn, tmpdir): class TestModel(BoringModel): - def validation_step(self, dataloader_iter, batch_idx): - assert isinstance(self.trainer.validate_loop._data_fetcher, DataLoaderIterDataFetcher) + def fetch(self, data_fetcher, dataloader_iter, batch_idx): + assert isinstance(data_fetcher, DataLoaderIterDataFetcher) + assert data_fetcher.fetched == batch_idx batch = next(dataloader_iter) + assert data_fetcher.fetched == batch_idx + 1 + return batch + + def validation_step(self, dataloader_iter, batch_idx): + batch = self.fetch(self.trainer.validate_loop._data_fetcher, dataloader_iter, batch_idx) return super().validation_step(batch, batch_idx) def test_step(self, dataloader_iter, batch_idx): - assert isinstance(self.trainer.test_loop._data_fetcher, DataLoaderIterDataFetcher) - batch = next(dataloader_iter) + batch = self.fetch(self.trainer.test_loop._data_fetcher, dataloader_iter, batch_idx) return super().test_step(batch, batch_idx) model = TestModel()