From 2f549b96a859a5d0c3b8133930145a6a901e94e9 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 1 Sep 2020 01:01:25 +0700 Subject: [PATCH] Add InfiniteDataLoader class (#876) * Add InfiniteDataLoader Only initializes at first epoch. Saves time. * Moved class to a better location --- utils/datasets.py | 48 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index d6c220f3a820..edb6b10fa050 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -63,15 +63,51 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa batch_size = min(batch_size, len(dataset)) nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=batch_size, - num_workers=nw, - sampler=train_sampler, - pin_memory=True, - collate_fn=LoadImagesAndLabels.collate_fn) + dataloader = InfiniteDataLoader (dataset, + batch_size=batch_size, + num_workers=nw, + sampler=train_sampler, + pin_memory=True, + collate_fn=LoadImagesAndLabels.collate_fn) return dataloader, dataset +class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): + ''' + Dataloader that reuses workers. + + Uses same syntax as vanilla DataLoader. + ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) + self.iterator = super().__iter__() + + def __len__(self): + return len(self.batch_sampler.sampler) + + def __iter__(self): + for i in range(len(self)): + yield next(self.iterator) + + +class _RepeatSampler(object): + ''' + Sampler that repeats forever. + + Args: + sampler (Sampler) + ''' + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) + + class LoadImages: # for inference def __init__(self, path, img_size=640): p = str(Path(path)) # os-agnostic