diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 4f1c98fd880d..85a39ab52f82 100755 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -91,6 +91,13 @@ def exif_transpose(image): return image +def seed_worker(worker_id): + # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader + worker_seed = torch.initial_seed() % 2 ** 32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + def create_dataloader(path, imgsz, batch_size, @@ -130,13 +137,17 @@ def create_dataloader(path, nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates + generator = torch.Generator() + generator.manual_seed(0) return loader(dataset, batch_size=batch_size, shuffle=shuffle and sampler is None, num_workers=nw, sampler=sampler, pin_memory=True, - collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset + collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn, + worker_init_fn=seed_worker, + generator=generator), dataset class InfiniteDataLoader(dataloader.DataLoader):