Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 committed Apr 2, 2024
1 parent 2c6a3b4 commit c3d7539
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 13 deletions.
2 changes: 1 addition & 1 deletion composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def _get_distributed_sampler(dataloader: DataLoader) -> Optional[DistributedSamp
"""Fetch a distributed sampler from a `dataloader` if it exists."""
if isinstance(dataloader.batch_sampler, DistributedSampler):
return dataloader.batch_sampler
elif isinstance(dataloader.sampler, DistributedSampler):
if isinstance(dataloader.sampler, DistributedSampler):
return dataloader.sampler
return None

Expand Down
15 changes: 3 additions & 12 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,7 @@ def __len__(self) -> int:

train_dataloader = DataLoader(
dataset=train_dataset,
batch_size=train_batch_size,
batch_size=1,
sampler=train_sampler,
)

Expand All @@ -1316,7 +1316,7 @@ def __len__(self) -> int:

eval_dataloader = DataLoader(
dataset=eval_dataset,
batch_size=2,
batch_size=1,
sampler=eval_sampler,
)
else:
Expand Down Expand Up @@ -1470,13 +1470,6 @@ def test_resumption(
)

@world_size(2)
@pytest.mark.parametrize(
'device',
[
pytest.param('gpu', marks=pytest.mark.gpu),
pytest.param('cpu'),
],
)
@pytest.mark.parametrize('max_duration', [1, 2])
@pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*')
@pytest.mark.filterwarnings(
Expand All @@ -1485,7 +1478,6 @@ def test_resumption(
def test_set_dataloaders_to_cur_epoch(
self,
world_size: int,
device: str,
max_duration: int,
tmp_path: pathlib.Path,
):
Expand All @@ -1495,7 +1487,6 @@ def test_set_dataloaders_to_cur_epoch(

trainer = self.get_trainer(
save_folder=os.path.join(save_folder, 'first'),
device=device,
precision='fp32',
max_duration=f'{max_duration}ep',
train_subset_num_batches=2,
Expand All @@ -1507,7 +1498,7 @@ def test_set_dataloaders_to_cur_epoch(

assert isinstance(trainer.state.train_dataloader, DataLoader)
assert isinstance(trainer.state.train_dataloader.batch_sampler, DistributedSampler)
# Epochs count starts at 0
# Epoch count starts at 0
assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1

@pytest.mark.parametrize(
Expand Down

0 comments on commit c3d7539

Please sign in to comment.