Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call set_epoch on Dataloader.batch_sampler if defined #3124

Merged
merged 31 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
408179e
adding _get_dist_sampler function
Ghelfi Feb 4, 2024
514a8b1
apply _get_dist_sampler across all trainer.py
Ghelfi Feb 4, 2024
ef5540d
Switch condition to check if batch_sampler if filled first since samp…
Ghelfi Mar 17, 2024
413c4c3
adding docstring
Ghelfi Mar 17, 2024
ed8535f
removing antipattern from _get_sampler and linting
Ghelfi Mar 17, 2024
ca7873f
removed default attrs from exception in the attrs dict (#3126)
jjanezhang Mar 18, 2024
027156e
Bump coverage[toml] from 7.4.3 to 7.4.4 (#3121)
dependabot[bot] Mar 18, 2024
9f768a4
Refactor initialization (#3127)
Practicinginhell Mar 18, 2024
06b8afd
Bump databricks sdk version (#3128)
dakinggg Mar 18, 2024
c9b41f7
Update packaging requirement from <23.3,>=21.3.0 to >=21.3.0,<24.1 (#…
dependabot[bot] Mar 18, 2024
0c42c51
remove rng from save weights only ckpt (#3129)
eracah Mar 18, 2024
ccaeec5
More compression options (#3118)
mbway Mar 20, 2024
2fdbf45
Only broadcast distcp files (#3130)
mvpatel2000 Mar 20, 2024
88677c3
Bump version to 0.21 (#3132)
mvpatel2000 Mar 20, 2024
5161f59
Adding a test for checking the update of epoch on batch_sampler
Ghelfi Mar 25, 2024
c731f67
proper formatting
Ghelfi Mar 25, 2024
570558d
adding _get_dist_sampler function
Ghelfi Feb 4, 2024
ab09d43
apply _get_dist_sampler across all trainer.py
Ghelfi Feb 4, 2024
df73325
Switch condition to check if batch_sampler if filled first since samp…
Ghelfi Mar 17, 2024
61da1b8
adding docstring
Ghelfi Mar 17, 2024
12245cc
removing antipattern from _get_sampler and linting
Ghelfi Mar 17, 2024
b2c147a
Adding a test for checking the update of epoch on batch_sampler
Ghelfi Mar 25, 2024
68e6d9d
proper formatting
Ghelfi Mar 25, 2024
98c4e51
fixing test
Ghelfi Mar 27, 2024
b3e6551
Chang _get_sampler to _get_distributed_sampler to bypass torch defaul…
Ghelfi Mar 27, 2024
76b888b
Move batch_sampler instanciation for test
Ghelfi Mar 28, 2024
dd45d3e
fix batch_sampler setting in eval_loop
Ghelfi Mar 28, 2024
14eb4f9
Merge branch 'dev' into set-epoch-on-batch-sampler
mvpatel2000 Apr 1, 2024
448d531
trainer
mvpatel2000 Apr 2, 2024
522afe4
fix
mvpatel2000 Apr 2, 2024
fdb2efa
lint
mvpatel2000 Apr 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,15 @@ def _generate_run_name() -> str:
return generated_run_name


def _get_distributed_sampler(dataloader: DataLoader) -> Optional[DistributedSampler]:
"""Fetch a distributed sampler from a `dataloader` if it exists."""
if isinstance(dataloader.batch_sampler, DistributedSampler):
return dataloader.batch_sampler
if isinstance(dataloader.sampler, DistributedSampler):
return dataloader.sampler
return None


class Trainer:
"""Train models with Composer algorithms.
Expand Down Expand Up @@ -2267,24 +2276,26 @@ def _spin_dataloaders_to_cur_epoch(self):
"""
log.debug('Spinning the dataloaders')

# spin the evaluator dataloaders once to initialize its sampler deterministically
# Spin the evaluator dataloaders once to initialize its sampler deterministically
# so it does not affect any other RNG reads
eval_state = self.state.dataset_resumption.get('eval', {})
for evaluator in self.state.evaluators:
dataloader = evaluator.dataloader.dataloader
if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(0)
sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
if isinstance(sampler, DistributedSampler):
sampler.set_epoch(0)
if evaluator.label not in eval_state:
for _ in dataloader:
break

# spin the train dataloader's sampler to get to the state of the desired epoch
# Spin the train dataloader's sampler to get to the state of the desired epoch
dataloader = self.state.dataloader
assert dataloader is not None, 'train dataloader is set on state after FIT_START'
if 'train' not in self.state.dataset_resumption:
sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
for epoch in range(int(self.state.timestamp.epoch)):
if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(epoch)
if isinstance(sampler, DistributedSampler):
sampler.set_epoch(epoch)
for _ in dataloader:
break

Expand Down Expand Up @@ -2366,8 +2377,9 @@ def _train_loop(self) -> None:
self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})

dataloader = self.state.dataloader
if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))
sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
if isinstance(sampler, DistributedSampler):
sampler.set_epoch(int(self.state.timestamp.epoch))

for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
# Spin dataloader forward unless dataloader handles internally with dataset_resumption
Expand Down Expand Up @@ -3221,16 +3233,15 @@ def _eval_loop(
metric.reset()

dataloader = self.state.dataloader
dist_sampler = None
drop_last = None
dataset_len = None
last_batch = False
if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
dist_sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
if isinstance(dist_sampler, DistributedSampler) and isinstance(dataloader, DataLoader):
# The distributed sampler uses `set_epoch` to set the random seed
# Because evaluation can run on each batch, we use the batch to seed the sampler
# so each evaluation will get a proper shuffle.
# The epoch provided to `set_epoch` need not be sequential, so this is fine.
dist_sampler = dataloader.sampler
dist_sampler.set_epoch(int(self.state.timestamp.batch))
drop_last = dataloader.drop_last
# Only compute the dataset length if drop_last is False, as otherwise we don't need
Expand Down
119 changes: 109 additions & 10 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch.distributed
from packaging import version
from pytest import MonkeyPatch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from composer.algorithms import NoOpModel
from composer.callbacks import CheckpointSaver
Expand Down Expand Up @@ -1246,6 +1246,8 @@ def get_trainer(
precision='fp32',
max_duration='2ep',
train_subset_num_batches=5,
use_batch_sampler: bool = False,
with_eval_dataloader: bool = True,
**kwargs,
):
model = SimpleModel()
Expand All @@ -1257,18 +1259,83 @@ def get_trainer(
eval_dataset = RandomClassificationDataset(size=12)
train_batch_size = 2

return Trainer(
model=model,
train_dataloader=DataLoader(
class _DistributedBatchSampler(DistributedSampler):
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
dataset: Dataset,
batch_size: int,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
):
super().__init__(
dataset=dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
self._batch_size = batch_size

def __iter__(self):
indices = list(super().__iter__())
for ind_ in range(len(self)):
yield indices[ind_ * self._batch_size:(ind_ + 1) * self._batch_size]

def __len__(self) -> int:
return self.num_samples // self._batch_size

if use_batch_sampler:
train_batch_sampler = _DistributedBatchSampler(
dataset=train_dataset,
drop_last=True,
shuffle=True,
num_replicas=dist.get_world_size(),
rank=dist.get_global_rank(),
batch_size=train_batch_size,
)
train_dataloader = DataLoader(
dataset=train_dataset,
batch_sampler=train_batch_sampler,
)
else:
train_dataloader = DataLoader(
dataset=train_dataset,
batch_size=train_batch_size,
sampler=dist.get_sampler(train_dataset),
),
eval_dataloader=DataLoader(
dataset=eval_dataset,
batch_size=2,
sampler=dist.get_sampler(eval_dataset),
),
)

if with_eval_dataloader is True:
if use_batch_sampler:
eval_batch_sampler = _DistributedBatchSampler(
dataset=eval_dataset,
drop_last=False,
shuffle=False,
num_replicas=dist.get_world_size(),
rank=dist.get_global_rank(),
batch_size=train_batch_size,
)
eval_dataloader = DataLoader(
dataset=eval_dataset,
batch_sampler=eval_batch_sampler,
)
else:
eval_dataloader = DataLoader(
dataset=eval_dataset,
batch_size=train_batch_size,
sampler=dist.get_sampler(eval_dataset),
)
else:
eval_dataloader = None

return Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
device_train_microbatch_size=train_batch_size // 2,
precision=precision,
train_subset_num_batches=train_subset_num_batches,
Expand Down Expand Up @@ -1412,6 +1479,38 @@ def test_resumption(
save_folder / 'second' / final_checkpoint,
)

@world_size(2)
@pytest.mark.parametrize('max_duration', [1, 2])
@pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*')
@pytest.mark.filterwarnings(
    'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*',
)
def test_set_dataloaders_to_cur_epoch(
    self,
    world_size: int,
    max_duration: int,
    tmp_path: pathlib.Path,
):
    """Train with a ``DistributedSampler``-derived batch sampler and verify
    that the trainer called ``set_epoch`` on it for every epoch."""
    # All ranks use rank 0 folder
    tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
    save_folder = pathlib.Path(tmp_paths[0])

    # use_batch_sampler=True makes get_trainer build the train DataLoader
    # with a DistributedSampler-derived batch_sampler (and no eval loader).
    trainer = self.get_trainer(
        save_folder=os.path.join(save_folder, 'first'),
        precision='fp32',
        max_duration=f'{max_duration}ep',
        train_subset_num_batches=2,
        use_batch_sampler=True,
        with_eval_dataloader=False,
    )

    trainer.fit()

    assert isinstance(trainer.state.train_dataloader, DataLoader)
    assert isinstance(trainer.state.train_dataloader.batch_sampler, DistributedSampler)
    # Epoch count starts at 0, so the final set_epoch call used max_duration - 1.
    assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1

@pytest.mark.parametrize(
'world_size',
[
Expand Down
Loading