From 1c5e92aba11f0dd007716821e7cd151d532342a8 Mon Sep 17 00:00:00 2001 From: UnglvKitDe <100289696+UnglvKitDe@users.noreply.github.com> Date: Sat, 23 Jul 2022 01:25:17 +0200 Subject: [PATCH] Add generator and worker seed (#8602) * Add generator and worker seed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataloaders.py * Update dataloaders.py * Update dataloaders.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- utils/dataloaders.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 4f1c98fd880d..85a39ab52f82 100755 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -91,6 +91,13 @@ def exif_transpose(image): return image +def seed_worker(worker_id): + # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader + worker_seed = torch.initial_seed() % 2 ** 32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + def create_dataloader(path, imgsz, batch_size, @@ -130,13 +137,17 @@ def create_dataloader(path, nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates + generator = torch.Generator() + generator.manual_seed(0) return loader(dataset, batch_size=batch_size, shuffle=shuffle and sampler is None, num_workers=nw, sampler=sampler, pin_memory=True, - collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset + collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn, + worker_init_fn=seed_worker, + generator=generator), dataset class InfiniteDataLoader(dataloader.DataLoader):