Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 committed Apr 2, 2024
1 parent 522afe4 commit fdb2efa
Showing 1 changed file with 38 additions and 33 deletions.
71 changes: 38 additions & 33 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,41 +1289,46 @@ def __iter__(self):
def __len__(self) -> int:
return self.num_samples // self._batch_size

train_batch_sampler = _DistributedBatchSampler(
dataset=train_dataset,
drop_last=True,
shuffle=True,
num_replicas=dist.get_world_size(),
rank=dist.get_global_rank(),
batch_size=train_batch_size,
)
eval_batch_sampler = _DistributedBatchSampler(
dataset=eval_dataset,
drop_last=False,
shuffle=False,
num_replicas=dist.get_world_size(),
rank=dist.get_global_rank(),
batch_size=train_batch_size,
)

train_dataloader = DataLoader(
dataset=train_dataset,
batch_size=train_batch_size,
sampler=dist.get_sampler(train_dataset),
) if not use_batch_sampler else DataLoader(
dataset=train_dataset,
batch_sampler=train_batch_sampler,
)
if use_batch_sampler:
train_batch_sampler = _DistributedBatchSampler(
dataset=train_dataset,
drop_last=True,
shuffle=True,
num_replicas=dist.get_world_size(),
rank=dist.get_global_rank(),
batch_size=train_batch_size,
)
train_dataloader = DataLoader(
dataset=train_dataset,
batch_sampler=train_batch_sampler,
)
else:
train_dataloader = DataLoader(
dataset=train_dataset,
batch_size=train_batch_size,
sampler=dist.get_sampler(train_dataset),
)

if with_eval_dataloader is True:
eval_dataloader = DataLoader(
dataset=eval_dataset,
batch_size=2,
sampler=dist.get_sampler(eval_dataset),
) if not use_batch_sampler else DataLoader(
dataset=eval_dataset,
batch_sampler=eval_batch_sampler,
)
if use_batch_sampler:
eval_batch_sampler = _DistributedBatchSampler(
dataset=eval_dataset,
drop_last=False,
shuffle=False,
num_replicas=dist.get_world_size(),
rank=dist.get_global_rank(),
batch_size=train_batch_size,
)
eval_dataloader = DataLoader(
dataset=eval_dataset,
batch_sampler=eval_batch_sampler,
)
else:
eval_dataloader = DataLoader(
dataset=eval_dataset,
batch_size=train_batch_size,
sampler=dist.get_sampler(eval_dataset),
)
else:
eval_dataloader = None

Expand Down

0 comments on commit fdb2efa

Please sign in to comment.