diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 064c22a73b..6ff090988e 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -465,6 +465,15 @@ def _generate_run_name() -> str:
     return generated_run_name
 
 
+def _get_distributed_sampler(dataloader: DataLoader) -> Optional[DistributedSampler]:
+    """Fetch a distributed sampler from a `dataloader` if it exists."""
+    if isinstance(dataloader.batch_sampler, DistributedSampler):
+        return dataloader.batch_sampler
+    if isinstance(dataloader.sampler, DistributedSampler):
+        return dataloader.sampler
+    return None
+
+
 class Trainer:
     """Train models with Composer algorithms.
 
@@ -2267,24 +2276,26 @@ def _spin_dataloaders_to_cur_epoch(self):
         """
         log.debug('Spinning the dataloaders')
 
-        # spin the evaluator dataloaders once to initialize its sampler deterministically
+        # Spin the evaluator dataloaders once to initialize its sampler deterministically
         # so it does not affect any other RNG reads
         eval_state = self.state.dataset_resumption.get('eval', {})
         for evaluator in self.state.evaluators:
             dataloader = evaluator.dataloader.dataloader
-            if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
-                dataloader.sampler.set_epoch(0)
+            sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
+            if isinstance(sampler, DistributedSampler):
+                sampler.set_epoch(0)
             if evaluator.label not in eval_state:
                 for _ in dataloader:
                     break
 
-        # spin the train dataloader's sampler to get to the state of the desired epoch
+        # Spin the train dataloader's sampler to get to the state of the desired epoch
         dataloader = self.state.dataloader
         assert dataloader is not None, 'train dataloader is set on state after FIT_START'
         if 'train' not in self.state.dataset_resumption:
+            sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
             for epoch in range(int(self.state.timestamp.epoch)):
-                if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
-                    dataloader.sampler.set_epoch(epoch)
+                if isinstance(sampler, DistributedSampler):
+                    sampler.set_epoch(epoch)
                 for _ in dataloader:
                     break
 
@@ -2366,8 +2377,9 @@ def _train_loop(self) -> None:
             self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
 
             dataloader = self.state.dataloader
-            if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
-                dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))
+            sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
+            if isinstance(sampler, DistributedSampler):
+                sampler.set_epoch(int(self.state.timestamp.epoch))
 
             for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
                 # Spin dataloader forward unless dataloader handles internally with dataset_resumption
@@ -3221,16 +3233,15 @@ def _eval_loop(
             metric.reset()
 
         dataloader = self.state.dataloader
-        dist_sampler = None
         drop_last = None
         dataset_len = None
         last_batch = False
-        if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
+        dist_sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
+        if isinstance(dist_sampler, DistributedSampler) and isinstance(dataloader, DataLoader):
             # The distributed sampler uses `set_epoch` to set the random seed
             # Because evaluation can run on each batch, we use the batch to seed the sampler
             # so each evaluation will get a proper shuffle.
             # The epoch provided to `set_epoch` need not be sequential, so this is fine.
-            dist_sampler = dataloader.sampler
             dist_sampler.set_epoch(int(self.state.timestamp.batch))
             drop_last = dataloader.drop_last
             # Only compute the dataset length if drop_last is False, as otherwise we don't need
diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py
index 797dc4b112..c6712c8f43 100644
--- a/tests/trainer/test_checkpoint.py
+++ b/tests/trainer/test_checkpoint.py
@@ -20,7 +20,7 @@
 import torch.distributed
 from packaging import version
 from pytest import MonkeyPatch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
 
 from composer.algorithms import NoOpModel
 from composer.callbacks import CheckpointSaver
@@ -1246,6 +1246,8 @@ def get_trainer(
         precision='fp32',
         max_duration='2ep',
         train_subset_num_batches=5,
+        use_batch_sampler: bool = False,
+        with_eval_dataloader: bool = True,
         **kwargs,
     ):
         model = SimpleModel()
@@ -1257,18 +1259,83 @@ def get_trainer(
         eval_dataset = RandomClassificationDataset(size=12)
         train_batch_size = 2
 
-        return Trainer(
-            model=model,
-            train_dataloader=DataLoader(
+        class _DistributedBatchSampler(DistributedSampler):
+
+            def __init__(
+                self,
+                dataset: Dataset,
+                batch_size: int,
+                num_replicas: Optional[int] = None,
+                rank: Optional[int] = None,
+                shuffle: bool = True,
+                seed: int = 0,
+                drop_last: bool = False,
+            ):
+                super().__init__(
+                    dataset=dataset,
+                    num_replicas=num_replicas,
+                    rank=rank,
+                    shuffle=shuffle,
+                    seed=seed,
+                    drop_last=drop_last,
+                )
+                self._batch_size = batch_size
+
+            def __iter__(self):
+                indices = list(super().__iter__())
+                for ind_ in range(len(self)):
+                    yield indices[ind_ * self._batch_size:(ind_ + 1) * self._batch_size]
+
+            def __len__(self) -> int:
+                return self.num_samples // self._batch_size
+
+        if use_batch_sampler:
+            train_batch_sampler = _DistributedBatchSampler(
+                dataset=train_dataset,
+                drop_last=True,
+                shuffle=True,
+                num_replicas=dist.get_world_size(),
+                rank=dist.get_global_rank(),
+                batch_size=train_batch_size,
+            )
+            train_dataloader = DataLoader(
+                dataset=train_dataset,
+                batch_sampler=train_batch_sampler,
+            )
+        else:
+            train_dataloader = DataLoader(
                 dataset=train_dataset,
                 batch_size=train_batch_size,
                 sampler=dist.get_sampler(train_dataset),
-            ),
-            eval_dataloader=DataLoader(
-                dataset=eval_dataset,
-                batch_size=2,
-                sampler=dist.get_sampler(eval_dataset),
-            ),
+            )
+
+        if with_eval_dataloader is True:
+            if use_batch_sampler:
+                eval_batch_sampler = _DistributedBatchSampler(
+                    dataset=eval_dataset,
+                    drop_last=False,
+                    shuffle=False,
+                    num_replicas=dist.get_world_size(),
+                    rank=dist.get_global_rank(),
+                    batch_size=train_batch_size,
+                )
+                eval_dataloader = DataLoader(
+                    dataset=eval_dataset,
+                    batch_sampler=eval_batch_sampler,
+                )
+            else:
+                eval_dataloader = DataLoader(
+                    dataset=eval_dataset,
+                    batch_size=train_batch_size,
+                    sampler=dist.get_sampler(eval_dataset),
+                )
+        else:
+            eval_dataloader = None
+
+        return Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            eval_dataloader=eval_dataloader,
             device_train_microbatch_size=train_batch_size // 2,
             precision=precision,
             train_subset_num_batches=train_subset_num_batches,
@@ -1412,6 +1479,38 @@ def test_resumption(
             save_folder / 'second' / final_checkpoint,
         )
 
+    @world_size(2)
+    @pytest.mark.parametrize('max_duration', [1, 2])
+    @pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*')
+    @pytest.mark.filterwarnings(
+        'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*',
+    )
+    def test_set_dataloaders_to_cur_epoch(
+        self,
+        world_size: int,
+        max_duration: int,
+        tmp_path: pathlib.Path,
+    ):
+        # All ranks use rank 0 folder
+        tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
+        save_folder = pathlib.Path(tmp_paths[0])
+
+        trainer = self.get_trainer(
+            save_folder=os.path.join(save_folder, 'first'),
+            precision='fp32',
+            max_duration=f'{max_duration}ep',
+            train_subset_num_batches=2,
+            use_batch_sampler=True,
+            with_eval_dataloader=False,
+        )
+
+        trainer.fit()
+
+        assert isinstance(trainer.state.train_dataloader, DataLoader)
+        assert isinstance(trainer.state.train_dataloader.batch_sampler, DistributedSampler)
+        # Epoch count starts at 0
+        assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1
+
     @pytest.mark.parametrize(
         'world_size',
         [