avoid unnecessary workers with sequential CombinedLoader (#17639)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

(cherry picked from commit c3ad756)
mukhery authored and lantiga committed Jun 2, 2023
1 parent a81a956 commit 64d84cc
Showing 4 changed files with 51 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- `CombinedLoader` only starts DataLoader workers when necessary when operating in sequential mode ([#17639](https://github.com/Lightning-AI/lightning/pull/17639))
+
+
 - Fixed a potential bug with uploading model checkpoints to Neptune.ai by uploading files from stream ([#17430](https://github.com/Lightning-AI/lightning/pull/17430))
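For orientation: in `"sequential"` mode, `CombinedLoader` exhausts each iterable in turn and yields `(batch, batch_idx, dataloader_idx)` tuples, and calling `iter()` on a `DataLoader` is the step that spawns its workers. A minimal usage sketch of the behavior this changelog entry describes (the loader sizes and `num_workers` values are illustrative):

```python
from torch.utils.data import DataLoader
from lightning.pytorch.utilities import CombinedLoader

# Illustrative loaders; with num_workers > 0, iter() spawns worker processes.
iterables = {
    "a": DataLoader(range(6), batch_size=2, num_workers=2),
    "b": DataLoader(range(8), batch_size=2, num_workers=2),
}
combined = CombinedLoader(iterables, mode="sequential")

# After this fix, the workers for "b" start only once "a" is exhausted.
for batch, batch_idx, dataloader_idx in combined:
    print(batch, batch_idx, dataloader_idx)
```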
15 changes: 12 additions & 3 deletions src/lightning/pytorch/utilities/combined_loader.py
@@ -108,7 +108,7 @@ def limits(self, limits: Optional[List[Union[int, float]]]) -> None:
         self._limits = limits

     def __next__(self) -> Tuple[Any, int, int]:
-        n = len(self.iterators)
+        n = len(self.iterables)
         if n == 0 or self._iterator_idx >= n:
             raise StopIteration

@@ -120,7 +120,7 @@ def __next__(self) -> Tuple[Any, int, int]:
             raise StopIteration

         try:
-            out = next(self.iterators[self._iterator_idx])
+            out = next(self.iterators[0])
             index = self._idx
             self._idx += 1
             # batch, batch_idx, dataloader_idx
@@ -131,19 +131,28 @@ def __next__(self) -> Tuple[Any, int, int]:
             return self.__next__()

     def __iter__(self) -> Self:
-        super().__iter__()
         self._iterator_idx = 0
         self._idx = 0
+        self._load_current_iterator()
         return self

     def reset(self) -> None:
         super().reset()
         self._iterator_idx = 0
         self._idx = 0

+    def _load_current_iterator(self) -> None:
+        # Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily
+        if self._iterator_idx < len(self.iterables):
+            self.iterators = [iter(self.iterables[self._iterator_idx])]
+        else:
+            # No more iterables to step through, return an empty list
+            self.iterators = []
+
     def _use_next_iterator(self) -> None:
         self._iterator_idx += 1
         self._idx = 0
+        self._load_current_iterator()


 class _MaxSize(_ModeIterator[List]):
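The essence of the patch: building a DataLoader's iterator is what spins up its worker processes, so `_Sequential` now creates one iterator at a time instead of one per iterable up front. A simplified, standalone sketch of the same lazy pattern (the names here are illustrative, not Lightning's):

```python
from typing import Any, Iterable, Iterator, List, Tuple


class LazySequential:
    """Yield (batch, batch_idx, iterable_idx) tuples, creating each
    iterator only when its turn comes, like _load_current_iterator."""

    def __init__(self, iterables: List[Iterable]) -> None:
        self.iterables = iterables

    def __iter__(self) -> Iterator[Tuple[Any, int, int]]:
        for iterable_idx, iterable in enumerate(self.iterables):
            # iter() is deferred until this point; for a DataLoader with
            # num_workers > 0, this is where workers would be spawned.
            for batch_idx, batch in enumerate(iter(iterable)):
                yield batch, batch_idx, iterable_idx


for item in LazySequential([range(3), range(2)]):
    print(item)  # (0, 0, 0), (1, 1, 0), (2, 2, 0), (0, 0, 1), (1, 1, 1)
```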
6 changes: 2 additions & 4 deletions tests/tests_pytorch/loops/test_loops.py
@@ -844,8 +844,7 @@ def _get_iterator(self):
            # iterable check
            0,
            # epoch ends
            1,
            # teardown
            0,
            1,
        ]
    else:
@@ -855,9 +854,8 @@ def _get_iterator(self):
            # iterable check
            0,
            # epoch ends
            0,
            1,
            2,
            # teardown
            3,
        ]
    assert val_dataloader.shutdown_workers_epochs == expected
34 changes: 34 additions & 0 deletions tests/tests_pytorch/utilities/test_combined_loader.py
@@ -306,6 +306,40 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders
     assert idx == expected - 1


+@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
+def test_combined_loader_simultaneous_workers(mode):
+    """Test `CombinedLoader` to check how it initializes dataloader workers."""
+
+    class TestDataLoader(DataLoader):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.workers_active = False
+
+        def _get_iterator(self):
+            self.workers_active = True
+            return super()._get_iterator()
+
+        def _shutdown_workers(self):
+            self.workers_active = False
+            super()._shutdown_workers()
+
+    loaders = [
+        TestDataLoader(range(10), batch_size=2, num_workers=0),
+        TestDataLoader(range(20), batch_size=2, num_workers=0),
+    ]
+    combined_loader = CombinedLoader(loaders, mode)
+    # Start the dataloader
+    _ = iter(combined_loader)
+
+    workers_active = []
+    for loader in loaders:
+        workers_active.append(loader.workers_active)
+
+    # Sequential only starts the first dataloader, other modes start both
+    expected = [True, False] if mode == "sequential" else [True, True]
+    assert workers_active == expected
+
+
 @pytest.mark.parametrize(
     ("limits", "expected"),
     [
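A note on the hooks the new test overrides: `_get_iterator` is the private factory that `DataLoader.__iter__` calls to construct its single- or multi-process iterator, so it fires exactly when workers would start, and `_shutdown_workers` is its teardown counterpart. A minimal sketch of the same observation trick in isolation (the `SpyLoader` name is hypothetical):

```python
from torch.utils.data import DataLoader


class SpyLoader(DataLoader):
    """Record when DataLoader.__iter__ requests a real iterator."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.iterator_requests = 0

    def _get_iterator(self):
        # Called from DataLoader.__iter__; with num_workers > 0 this is
        # the point where worker processes would be spawned.
        self.iterator_requests += 1
        return super()._get_iterator()


loader = SpyLoader(range(10), batch_size=5)
assert loader.iterator_requests == 0  # nothing has started yet
_ = list(loader)                      # iteration triggers _get_iterator
assert loader.iterator_requests == 1
```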
