Commit

proper formatting
Ghelfi committed Mar 25, 2024
1 parent e7f7379 commit e496bd0
Showing 1 changed file with 37 additions and 16 deletions.
tests/trainer/test_checkpoint.py: 53 changes (37 additions, 16 deletions)
@@ -1260,32 +1260,49 @@ def get_trainer(
     train_batch_size = 2
 
     class _DistributedBatchSampler(DistributedSampler):
+
         def __init__(
-            self, dataset: Dataset, batch_size: int, num_replicas: Optional[int] = None,
-            rank: Optional[int] = None, shuffle: bool = True,
-            seed: int = 0, drop_last: bool = False,
+            self,
+            dataset: Dataset,
+            batch_size: int,
+            num_replicas: Optional[int] = None,
+            rank: Optional[int] = None,
+            shuffle: bool = True,
+            seed: int = 0,
+            drop_last: bool = False,
         ):
-            super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last)
-            self._batch_size =batch_size
+            super().__init__(
+                dataset=dataset,
+                num_replicas=num_replicas,
+                rank=rank,
+                shuffle=shuffle,
+                seed=seed,
+                drop_last=drop_last
+            )
+            self._batch_size = batch_size
 
         def __iter__(self):
             indices = list(super().__iter__())
             for ind_ in range(len(self)):
-                yield indices[ind_*self._batch_size: (ind_+1)*self._batch_size]
+                yield indices[ind_ * self._batch_size:(ind_ + 1) * self._batch_size]
 
         def __len__(self) -> int:
-            return self.num_samples//self._batch_size
+            return self.num_samples // self._batch_size
 
     train_dataloader = DataLoader(
         dataset=train_dataset,
         batch_size=train_batch_size,
         sampler=dist.get_sampler(train_dataset),
     ) if not use_batch_sampler else DataLoader(
         dataset=train_dataset,
-        batch_sampler=_DistributedBatchSampler(dataset=train_dataset, drop_last=True,
-                                               shuffle=True,
-                                               num_replicas=dist.get_world_size(),
-                                               rank=dist.get_global_rank(), batch_size=train_batch_size),
+        batch_sampler=_DistributedBatchSampler(
+            dataset=train_dataset,
+            drop_last=True,
+            shuffle=True,
+            num_replicas=dist.get_world_size(),
+            rank=dist.get_global_rank(),
+            batch_size=train_batch_size
+        ),
     )
 
     if with_eval_dataloader is True:
@@ -1295,10 +1312,14 @@ def __len__(self) -> int:
             sampler=dist.get_sampler(eval_dataset),
         ) if not use_batch_sampler else DataLoader(
             dataset=eval_dataset,
-            batch_sampler=_DistributedBatchSampler(dataset=eval_dataset, drop_last=False,
-                                                   shuffle=False,
-                                                   num_replicas=dist.get_world_size(),
-                                                   rank=dist.get_global_rank(), batch_size=train_batch_size),
+            batch_sampler=_DistributedBatchSampler(
+                dataset=eval_dataset,
+                drop_last=False,
+                shuffle=False,
+                num_replicas=dist.get_world_size(),
+                rank=dist.get_global_rank(),
+                batch_size=train_batch_size
+            ),
         )
     else:
         eval_dataloader = None
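
For context on what the reformatted sampler does: it slices the per-rank index stream produced by DistributedSampler into fixed-size lists, so a DataLoader built with batch_sampler= receives whole batches of indices rather than single indices. Below is a minimal, hypothetical sketch of that behavior in a single-process setting, assuming the _DistributedBatchSampler class from the diff above is in scope; _ToyDataset, the sizes, and num_replicas=1 / rank=0 (standing in for dist.get_world_size() and dist.get_global_rank() used in the test) are illustrative assumptions, not part of the commit.

import torch
from torch.utils.data import DataLoader, Dataset


class _ToyDataset(Dataset):
    """Hypothetical 10-sample dataset of scalar tensors."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.tensor([index], dtype=torch.float32)


dataset = _ToyDataset()
sampler = _DistributedBatchSampler(
    dataset=dataset,
    batch_size=2,
    num_replicas=1,  # stand-in for dist.get_world_size()
    rank=0,          # stand-in for dist.get_global_rank()
    shuffle=False,
    drop_last=True,
)
loader = DataLoader(dataset=dataset, batch_sampler=sampler)

for batch in loader:
    # The sampler yields lists of 2 indices, so default collation
    # stacks each batch to shape (2, 1); 5 batches total.
    print(batch.shape)  # torch.Size([2, 1])

Note that because __len__ floors num_samples // self._batch_size, a trailing partial batch is dropped by the sampler itself, independently of the drop_last flag forwarded to DistributedSampler (which only balances sample counts across replicas).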
