Call set_epoch on Dataloader.batch_sampler if defined (#3124)
* adding _get_dist_sampler function

* apply _get_dist_sampler across all trainer.py

* Switch condition to check if batch_sampler is filled first, since sampler is always defined

* adding docstring

* removing antipattern from _get_sampler and linting

* removed default attrs from exception in the attrs dict (#3126)

* Bump coverage[toml] from 7.4.3 to 7.4.4 (#3121)

Bumps [coverage[toml]](https://github.com/nedbat/coveragepy) from 7.4.3 to 7.4.4.
- [Release notes](https://github.com/nedbat/coveragepy/releases)
- [Changelog](https://github.com/nedbat/coveragepy/blob/master/CHANGES.rst)
- [Commits](nedbat/coveragepy@7.4.3...7.4.4)

---
updated-dependencies:
- dependency-name: coverage[toml]
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Refactor initialization (#3127)

Co-authored-by: practicinginhell <oneforall1412@gmail.com>

* Bump databricks sdk version (#3128)

* Update packaging requirement from <23.3,>=21.3.0 to >=21.3.0,<24.1 (#3122)

Updates the requirements on [packaging](https://github.com/pypa/packaging) to permit the latest version.
- [Release notes](https://github.com/pypa/packaging/releases)
- [Changelog](https://github.com/pypa/packaging/blob/main/CHANGELOG.rst)
- [Commits](pypa/packaging@21.3...24.0)

---
updated-dependencies:
- dependency-name: packaging
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

* remove rng from save weights only ckpt (#3129)

* More compression options (#3118)

* fix documentation about which file extensions perform compression

* refactor tar handling when writing checkpoint files

* add CLI compressor options to checkpoint saving

* update documentation

* restructure CLI compressor handling. Added more tests

* fix skipping test when compressor not installed

* add documentation and fix some code style issues

* use correct file extension for zstd

* add missing full stop

* fix failing tests

* remove type ignore

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

* combine imports into one

* fix capitalisation

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

---------

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

* Only broadcast distcp files (#3130)

* filter

* remove metadata skip

* Bump version to 0.21 (#3132)

* bump version to 0.21

* fix lint

* fix docstring

* remove old ones

* Adding a test for checking the update of epoch on batch_sampler

* proper formatting

* adding _get_dist_sampler function

* apply _get_dist_sampler across all trainer.py

* Switch condition to check if batch_sampler is filled first, since sampler is always defined

* adding docstring

* removing antipattern from _get_sampler and linting

* Adding a test for checking the update of epoch on batch_sampler

* proper formatting

* Change _get_sampler to _get_distributed_sampler to bypass torch's default batch_sampler

* Move batch_sampler instantiation for test

* fix batch_sampler setting in eval_loop

* trainer

* fix

* lint

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Jane Zhang <jane.zhang@databricks.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Practicinginhell <147235329+Practicinginhell@users.noreply.github.com>
Co-authored-by: practicinginhell <oneforall1412@gmail.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Evan Racah <evan@mosaicml.com>
Co-authored-by: Matt Broadway <mattdbway@gmail.com>
Co-authored-by: Alexandre Ghelfi <alexandre.ghelfi@helsing.ai>
10 people authored Apr 2, 2024
1 parent 3a0e2b2 commit a9b9791
Showing 2 changed files with 131 additions and 21 deletions.
33 changes: 22 additions & 11 deletions composer/trainer/trainer.py
@@ -465,6 +465,15 @@ def _generate_run_name() -> str:
     return generated_run_name
 
 
+def _get_distributed_sampler(dataloader: DataLoader) -> Optional[DistributedSampler]:
+    """Fetch a distributed sampler from a `dataloader` if it exists."""
+    if isinstance(dataloader.batch_sampler, DistributedSampler):
+        return dataloader.batch_sampler
+    if isinstance(dataloader.sampler, DistributedSampler):
+        return dataloader.sampler
+    return None
+
+
 class Trainer:
     """Train models with Composer algorithms.
@@ -2267,24 +2276,26 @@ def _spin_dataloaders_to_cur_epoch(self):
         """
         log.debug('Spinning the dataloaders')
 
-        # spin the evaluator dataloaders once to initialize its sampler deterministically
+        # Spin the evaluator dataloaders once to initialize its sampler deterministically
         # so it does not affect any other RNG reads
         eval_state = self.state.dataset_resumption.get('eval', {})
         for evaluator in self.state.evaluators:
             dataloader = evaluator.dataloader.dataloader
-            if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
-                dataloader.sampler.set_epoch(0)
+            sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
+            if isinstance(sampler, DistributedSampler):
+                sampler.set_epoch(0)
             if evaluator.label not in eval_state:
                 for _ in dataloader:
                     break
 
-        # spin the train dataloader's sampler to get to the state of the desired epoch
+        # Spin the train dataloader's sampler to get to the state of the desired epoch
         dataloader = self.state.dataloader
         assert dataloader is not None, 'train dataloader is set on state after FIT_START'
         if 'train' not in self.state.dataset_resumption:
+            sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
             for epoch in range(int(self.state.timestamp.epoch)):
-                if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
-                    dataloader.sampler.set_epoch(epoch)
+                if isinstance(sampler, DistributedSampler):
+                    sampler.set_epoch(epoch)
                 for _ in dataloader:
                     break
@@ -2366,8 +2377,9 @@ def _train_loop(self) -> None:
             self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
 
             dataloader = self.state.dataloader
-            if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
-                dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))
+            sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
+            if isinstance(sampler, DistributedSampler):
+                sampler.set_epoch(int(self.state.timestamp.epoch))
 
             for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
                 # Spin dataloader forward unless dataloader handles internally with dataset_resumption
@@ -3221,16 +3233,15 @@ def _eval_loop(
             metric.reset()
 
         dataloader = self.state.dataloader
-        dist_sampler = None
         drop_last = None
         dataset_len = None
         last_batch = False
-        if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
+        dist_sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
+        if isinstance(dist_sampler, DistributedSampler) and isinstance(dataloader, DataLoader):
             # The distributed sampler uses `set_epoch` to set the random seed
             # Because evaluation can run on each batch, we use the batch to seed the sampler
             # so each evaluation will get a proper shuffle.
             # The epoch provided to `set_epoch` need not be sequential, so this is fine.
-            dist_sampler = dataloader.sampler
             dist_sampler.set_epoch(int(self.state.timestamp.batch))
             drop_last = dataloader.drop_last
             # Only compute the dataset length if drop_last is False, as otherwise we don't need
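The comments in this hunk carry the key subtlety: `set_epoch` only reseeds the shuffle, so the value need not be a real epoch, and the trainer passes the current batch count so each mid-epoch evaluation gets a fresh but reproducible order. A small single-replica sketch of that behavior:

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(6).float())
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=0)

sampler.set_epoch(0)
order_a = list(sampler)
sampler.set_epoch(500)  # non-sequential value, e.g. int(state.timestamp.batch)
order_b = list(sampler)  # a fresh shuffle, independent of order_a

sampler.set_epoch(0)
assert list(sampler) == order_a  # the "epoch" value fully determines the shuffle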
119 changes: 109 additions & 10 deletions tests/trainer/test_checkpoint.py
@@ -20,7 +20,7 @@
 import torch.distributed
 from packaging import version
 from pytest import MonkeyPatch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
 
 from composer.algorithms import NoOpModel
 from composer.callbacks import CheckpointSaver
@@ -1246,6 +1246,8 @@ def get_trainer(
         precision='fp32',
         max_duration='2ep',
         train_subset_num_batches=5,
+        use_batch_sampler: bool = False,
+        with_eval_dataloader: bool = True,
         **kwargs,
     ):
         model = SimpleModel()
@@ -1257,18 +1259,83 @@ def get_trainer(
         eval_dataset = RandomClassificationDataset(size=12)
         train_batch_size = 2
 
-        return Trainer(
-            model=model,
-            train_dataloader=DataLoader(
-                dataset=train_dataset,
-                batch_size=train_batch_size,
-                sampler=dist.get_sampler(train_dataset),
-            ),
-            eval_dataloader=DataLoader(
-                dataset=eval_dataset,
-                batch_size=2,
-                sampler=dist.get_sampler(eval_dataset),
-            ),
+        class _DistributedBatchSampler(DistributedSampler):
+
+            def __init__(
+                self,
+                dataset: Dataset,
+                batch_size: int,
+                num_replicas: Optional[int] = None,
+                rank: Optional[int] = None,
+                shuffle: bool = True,
+                seed: int = 0,
+                drop_last: bool = False,
+            ):
+                super().__init__(
+                    dataset=dataset,
+                    num_replicas=num_replicas,
+                    rank=rank,
+                    shuffle=shuffle,
+                    seed=seed,
+                    drop_last=drop_last,
+                )
+                self._batch_size = batch_size
+
+            def __iter__(self):
+                indices = list(super().__iter__())
+                for ind_ in range(len(self)):
+                    yield indices[ind_ * self._batch_size:(ind_ + 1) * self._batch_size]
+
+            def __len__(self) -> int:
+                return self.num_samples // self._batch_size
+
+        if use_batch_sampler:
+            train_batch_sampler = _DistributedBatchSampler(
+                dataset=train_dataset,
+                drop_last=True,
+                shuffle=True,
+                num_replicas=dist.get_world_size(),
+                rank=dist.get_global_rank(),
+                batch_size=train_batch_size,
+            )
+            train_dataloader = DataLoader(
+                dataset=train_dataset,
+                batch_sampler=train_batch_sampler,
+            )
+        else:
+            train_dataloader = DataLoader(
+                dataset=train_dataset,
+                batch_size=train_batch_size,
+                sampler=dist.get_sampler(train_dataset),
+            )
+
+        if with_eval_dataloader is True:
+            if use_batch_sampler:
+                eval_batch_sampler = _DistributedBatchSampler(
+                    dataset=eval_dataset,
+                    drop_last=False,
+                    shuffle=False,
+                    num_replicas=dist.get_world_size(),
+                    rank=dist.get_global_rank(),
+                    batch_size=train_batch_size,
+                )
+                eval_dataloader = DataLoader(
+                    dataset=eval_dataset,
+                    batch_sampler=eval_batch_sampler,
+                )
+            else:
+                eval_dataloader = DataLoader(
+                    dataset=eval_dataset,
+                    batch_size=train_batch_size,
+                    sampler=dist.get_sampler(eval_dataset),
+                )
+        else:
+            eval_dataloader = None
+
+        return Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            eval_dataloader=eval_dataloader,
             device_train_microbatch_size=train_batch_size // 2,
             precision=precision,
             train_subset_num_batches=train_subset_num_batches,
@@ -1412,6 +1479,38 @@ def test_resumption(
             save_folder / 'second' / final_checkpoint,
         )
 
+    @world_size(2)
+    @pytest.mark.parametrize('max_duration', [1, 2])
+    @pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*')
+    @pytest.mark.filterwarnings(
+        'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*',
+    )
+    def test_set_dataloaders_to_cur_epoch(
+        self,
+        world_size: int,
+        max_duration: int,
+        tmp_path: pathlib.Path,
+    ):
+        # All ranks use rank 0 folder
+        tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
+        save_folder = pathlib.Path(tmp_paths[0])
+
+        trainer = self.get_trainer(
+            save_folder=os.path.join(save_folder, 'first'),
+            precision='fp32',
+            max_duration=f'{max_duration}ep',
+            train_subset_num_batches=2,
+            use_batch_sampler=True,
+            with_eval_dataloader=False,
+        )
+
+        trainer.fit()
+
+        assert isinstance(trainer.state.train_dataloader, DataLoader)
+        assert isinstance(trainer.state.train_dataloader.batch_sampler, DistributedSampler)
+        # Epoch count starts at 0
+        assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1
+
+    @pytest.mark.parametrize(
+        'world_size',
+        [
@pytest.mark.parametrize(
'world_size',
[
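The test's final assertion leans on `DistributedSampler` recording the last value passed to `set_epoch` in its `epoch` attribute, which starts at 0, so after `max_duration` epochs the trainer's last call was `set_epoch(max_duration - 1)`. A two-assert sketch of that contract:

import torch
from torch.utils.data import DistributedSampler, TensorDataset

sampler = DistributedSampler(TensorDataset(torch.arange(4).float()), num_replicas=2, rank=0)
assert sampler.epoch == 0  # the counter starts at 0
sampler.set_epoch(3)
assert sampler.epoch == 3  # set_epoch records the value used to seed shuffling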
