From e496bd0b6b1c07b6d7a7cd08498d7cc0c7811305 Mon Sep 17 00:00:00 2001
From: Alexandre Ghelfi
Date: Mon, 25 Mar 2024 11:03:38 +0100
Subject: [PATCH] proper formating

---
 tests/trainer/test_checkpoint.py | 53 ++++++++++++++++++++++----------
 1 file changed, 37 insertions(+), 16 deletions(-)

diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py
index c3e6c6477fc..2e7e9a4c25f 100644
--- a/tests/trainer/test_checkpoint.py
+++ b/tests/trainer/test_checkpoint.py
@@ -1260,21 +1260,34 @@ def get_trainer(
         train_batch_size = 2

         class _DistributedBatchSampler(DistributedSampler):
+
             def __init__(
-                self, dataset: Dataset, batch_size: int, num_replicas: Optional[int] = None,
-                rank: Optional[int] = None, shuffle: bool = True,
-                seed: int = 0, drop_last: bool = False,
+                self,
+                dataset: Dataset,
+                batch_size: int,
+                num_replicas: Optional[int] = None,
+                rank: Optional[int] = None,
+                shuffle: bool = True,
+                seed: int = 0,
+                drop_last: bool = False,
             ):
-                super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last)
-                self._batch_size =batch_size
+                super().__init__(
+                    dataset=dataset,
+                    num_replicas=num_replicas,
+                    rank=rank,
+                    shuffle=shuffle,
+                    seed=seed,
+                    drop_last=drop_last
+                )
+                self._batch_size = batch_size

             def __iter__(self):
                 indices = list(super().__iter__())
                 for ind_ in range(len(self)):
-                    yield indices[ind_*self._batch_size: (ind_+1)*self._batch_size]
-
+                    yield indices[ind_ * self._batch_size:(ind_ + 1) * self._batch_size]
+
             def __len__(self) -> int:
-                return self.num_samples//self._batch_size
+                return self.num_samples // self._batch_size

         train_dataloader = DataLoader(
             dataset=train_dataset,
@@ -1282,10 +1295,14 @@ def __len__(self) -> int:
             sampler=dist.get_sampler(train_dataset),
         ) if not use_batch_sampler else DataLoader(
             dataset=train_dataset,
-            batch_sampler=_DistributedBatchSampler(dataset=train_dataset, drop_last=True,
-                                                   shuffle=True,
-                                                   num_replicas=dist.get_world_size(),
-                                                   rank=dist.get_global_rank(), batch_size=train_batch_size),
+            batch_sampler=_DistributedBatchSampler(
+                dataset=train_dataset,
+                drop_last=True,
+                shuffle=True,
+                num_replicas=dist.get_world_size(),
+                rank=dist.get_global_rank(),
+                batch_size=train_batch_size
+            ),
         )

         if with_eval_dataloader is True:
@@ -1295,10 +1312,14 @@ def __len__(self) -> int:
                 sampler=dist.get_sampler(eval_dataset),
             ) if not use_batch_sampler else DataLoader(
                 dataset=eval_dataset,
-                batch_sampler=_DistributedBatchSampler(dataset=eval_dataset, drop_last=False,
-                                                       shuffle=False,
-                                                       num_replicas=dist.get_world_size(),
-                                                       rank=dist.get_global_rank(), batch_size=train_batch_size),
+                batch_sampler=_DistributedBatchSampler(
+                    dataset=eval_dataset,
+                    drop_last=False,
+                    shuffle=False,
+                    num_replicas=dist.get_world_size(),
+                    rank=dist.get_global_rank(),
+                    batch_size=train_batch_size
+                ),
             )
         else:
            eval_dataloader = None
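
For context, the pattern the patch reformats is a DistributedSampler subclass used as a DataLoader batch_sampler: __iter__ yields one list of indices per batch instead of single indices, and __len__ reports the number of batches per rank. Below is a minimal, self-contained sketch of that pattern, not the test itself; it assumes a single process (num_replicas=1, rank=0) so no process group or Composer dist helpers are needed, and the dataset and names are illustrative.

# Minimal single-process sketch of the batch-sampler pattern from the patch.
# Names and the toy dataset are illustrative, not taken from the test file.
from typing import Iterator, List, Optional

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler, TensorDataset


class DistributedBatchSampler(DistributedSampler):
    """DistributedSampler that yields one list of indices per batch."""

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ):
        super().__init__(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        self._batch_size = batch_size

    def __iter__(self) -> Iterator[List[int]]:
        # Take this rank's indices and cut them into fixed-size batches.
        indices = list(super().__iter__())
        for i in range(len(self)):
            yield indices[i * self._batch_size:(i + 1) * self._batch_size]

    def __len__(self) -> int:
        # Number of full batches this rank will produce.
        return self.num_samples // self._batch_size


# Toy dataset of 16 scalar samples; num_replicas=1/rank=0 avoids needing
# torch.distributed to be initialized for this sketch.
dataset = TensorDataset(torch.arange(16).float().unsqueeze(-1))
loader = DataLoader(
    dataset=dataset,
    # With batch_sampler set, batch_size/shuffle/sampler/drop_last must be
    # left at their defaults on the DataLoader itself.
    batch_sampler=DistributedBatchSampler(
        dataset=dataset, batch_size=2, num_replicas=1, rank=0, shuffle=False
    ),
)
for (batch,) in loader:
    print(batch.flatten().tolist())  # two samples per batch

Passing the sampler through batch_sampler rather than sampler is why the test's DataLoader constructions only swap that one argument: a batch_sampler takes over batching entirely, so batch_size, shuffle, sampler, and drop_last stay unset on the DataLoader.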