Commit 2c6a3b4

cleanup

mvpatel2000 committed Apr 2, 2024
1 parent 14eb4f9 commit 2c6a3b4
Showing 2 changed files with 24 additions and 36 deletions.
29 changes: 11 additions & 18 deletions composer/trainer/trainer.py
@@ -465,18 +465,13 @@ def _generate_run_name() -> str:
     return generated_run_name


-def _get_distributed_sampler(dataloader: DataLoader) -> DistributedSampler | None:
-    """Fetch a distributed sampler from a `dataloader` if it exists est returns None.
-
-    Checks first the batch_sampler, then the sampler.
-    If no DistributedSampler is found, returns None.
-    """
+def _get_distributed_sampler(dataloader: DataLoader) -> Optional[DistributedSampler]:
+    """Fetch a distributed sampler from a `dataloader` if it exists."""
     if isinstance(dataloader.batch_sampler, DistributedSampler):
         return dataloader.batch_sampler
-    if isinstance(dataloader.sampler, DistributedSampler):
+    elif isinstance(dataloader.sampler, DistributedSampler):
         return dataloader.sampler
-
-    return
+    return None


 class Trainer:
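
Note (not part of the commit): a minimal, single-process sketch of what the reworked helper returns for the two loader configurations it checks. The toy dataset and the local re-implementation below are illustrative stand-ins; only the lookup order (batch_sampler first, then sampler) mirrors the function above.

from typing import Optional

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def get_distributed_sampler(dataloader: DataLoader) -> Optional[DistributedSampler]:
    """Local mirror of the helper: check batch_sampler first, then sampler."""
    if isinstance(dataloader.batch_sampler, DistributedSampler):
        return dataloader.batch_sampler
    elif isinstance(dataloader.sampler, DistributedSampler):
        return dataloader.sampler
    return None


dataset = TensorDataset(torch.arange(8))
# Single-process stand-ins for the usual world size / rank.
dist_sampler = DistributedSampler(dataset, num_replicas=1, rank=0)

plain_loader = DataLoader(dataset, batch_size=2)
sampled_loader = DataLoader(dataset, batch_size=2, sampler=dist_sampler)

print(get_distributed_sampler(plain_loader))    # None
print(get_distributed_sampler(sampled_loader))  # the DistributedSampler instance
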
@@ -2281,7 +2276,7 @@ def _spin_dataloaders_to_cur_epoch(self):
"""
log.debug('Spinning the dataloaders')

# spin the evaluator dataloaders once to initialize its sampler deterministically
# Spin the evaluator dataloaders once to initialize its sampler deterministically
# so it does not affect any other RNG reads
eval_state = self.state.dataset_resumption.get('eval', {})
for evaluator in self.state.evaluators:
@@ -2293,12 +2288,12 @@ def _spin_dataloaders_to_cur_epoch(self):
                 for _ in dataloader:
                     break

-        # spin the train dataloader's sampler to get to the state of the desired epoch
+        # Spin the train dataloader's sampler to get to the state of the desired epoch
         dataloader = self.state.dataloader
         assert dataloader is not None, 'train dataloader is set on state after FIT_START'
         if 'train' not in self.state.dataset_resumption:
+            sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
             for epoch in range(int(self.state.timestamp.epoch)):
-                sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
                 if isinstance(sampler, DistributedSampler):
                     sampler.set_epoch(epoch)
                 for _ in dataloader:
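
Note (not part of the commit): a single-process sketch of why the spin loop calls set_epoch before each throwaway pass. A DistributedSampler's shuffle order is a pure function of (seed, epoch), so replaying the earlier epochs reproduces the order a non-resumed run would have seen; the toy dataset and epoch count below are illustrative.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# Single-process stand-ins for dist.get_world_size() / dist.get_global_rank().
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=0)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

# Spin to the start of epoch 3, the way _spin_dataloaders_to_cur_epoch does.
for epoch in range(3):
    sampler.set_epoch(epoch)
    for _ in loader:  # touching the loader once per epoch is enough here
        break

sampler.set_epoch(1)
print(list(sampler))  # some permutation of 0..7 for epoch 1
sampler.set_epoch(2)
print(list(sampler))  # typically a different permutation for epoch 2
sampler.set_epoch(1)
print(list(sampler))  # identical to the first printout: order depends only on (seed, epoch)
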
@@ -3238,24 +3233,22 @@ def _eval_loop(
                 metric.reset()

             dataloader = self.state.dataloader
-            dist_sampler = None
             drop_last = None
             dataset_len = None
             last_batch = False
-            sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
-            if isinstance(sampler, DistributedSampler) and isinstance(dataloader, DataLoader):
+            dist_sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
+            if isinstance(dist_sampler, DistributedSampler) and isinstance(dataloader, DataLoader):
                 # The distributed sampler uses `set_epoch` to set the random seed
                 # Because evaluation can run on each batch, we use the batch to seed the sampler
                 # so each evaluation will get a proper shuffle.
                 # The epoch provided to `set_epoch` need not be sequential, so this is fine.
-                dist_sampler = sampler
-                sampler.set_epoch(int(self.state.timestamp.batch))
+                dist_sampler.set_epoch(int(self.state.timestamp.batch))
                 drop_last = dataloader.drop_last
                 # Only compute the dataset length if drop_last is False, as otherwise we don't need
                 # to remove any duplicate samples.
                 if drop_last == False:
                     try:
-                        dataset_len = len(sampler.dataset)  # type: ignore
+                        dataset_len = len(dist_sampler.dataset)  # type: ignore
                     except AttributeError:
                         warnings.warn(
                             "DistributedSampler's dataset does not have length defined. When "
31 changes: 13 additions & 18 deletions tests/trainer/test_checkpoint.py
@@ -1289,40 +1289,35 @@ def __iter__(self):
         def __len__(self) -> int:
             return self.num_samples // self._batch_size

-    train_batch_sampler = _DistributedBatchSampler(
+    train_sampler = _DistributedBatchSampler(
         dataset=train_dataset,
         drop_last=True,
         shuffle=True,
         num_replicas=dist.get_world_size(),
         rank=dist.get_global_rank(),
         batch_size=train_batch_size,
-    )
-    eval_batch_sampler = _DistributedBatchSampler(
-        dataset=eval_dataset,
-        drop_last=False,
-        shuffle=False,
-        num_replicas=dist.get_world_size(),
-        rank=dist.get_global_rank(),
-        batch_size=train_batch_size,
-    )
+    ) if use_batch_sampler else dist.get_sampler(train_dataset)

     train_dataloader = DataLoader(
         dataset=train_dataset,
         batch_size=train_batch_size,
-        sampler=dist.get_sampler(train_dataset),
-    ) if not use_batch_sampler else DataLoader(
-        dataset=train_dataset,
-        batch_sampler=train_batch_sampler,
+        sampler=train_sampler,
     )

     if with_eval_dataloader is True:
+        eval_sampler = _DistributedBatchSampler(
+            dataset=eval_dataset,
+            drop_last=False,
+            shuffle=False,
+            num_replicas=dist.get_world_size(),
+            rank=dist.get_global_rank(),
+            batch_size=train_batch_size,
+        ) if use_batch_sampler else dist.get_sampler(eval_dataset)
+
         eval_dataloader = DataLoader(
             dataset=eval_dataset,
             batch_size=2,
-            sampler=dist.get_sampler(eval_dataset),
-        ) if not use_batch_sampler else DataLoader(
-            dataset=eval_dataset,
-            batch_sampler=eval_batch_sampler,
+            sampler=eval_sampler,
         )
     else:
         eval_dataloader = None
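
Note (not part of the repo's test): a generic, single-process sketch of the two DataLoader wirings a use_batch_sampler flag typically toggles between. A batch sampler yields lists of indices and replaces batch_size/sampler/drop_last on the DataLoader, while a plain distributed sampler yields single indices and is combined with batch_size. The BatchSampler wrapper, dataset, and sizes here are illustrative, not the test's _DistributedBatchSampler.

import torch
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(16))
# Single-process stand-ins for dist.get_world_size() / dist.get_global_rank().
dist_sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

use_batch_sampler = True
if use_batch_sampler:
    # The sampler forms the batches itself; DataLoader gets no batch_size.
    batch_sampler = BatchSampler(dist_sampler, batch_size=4, drop_last=True)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)
else:
    # DataLoader forms batches from the single indices the sampler yields.
    loader = DataLoader(dataset, batch_size=4, sampler=dist_sampler)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([4]) on this toy dataset
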
