
Add generator and worker seed #8602

Merged — 6 commits merged into ultralytics:master on Jul 22, 2022

Conversation

@UnglvKitDe (Contributor) commented Jul 16, 2022

Worker seed and generator inserted into the dataloader as described in #8601

πŸ› οΈ PR Summary

Made with ❀️ by Ultralytics Actions

🌟 Summary

Enhance randomness control in YOLOv5 dataloaders for more consistent training results. πŸ”„

πŸ“Š Key Changes

  • Added a seed_worker function that initializes the random seeds for each dataloader worker.
  • Set a manual seed and a torch.Generator on the PyTorch DataLoader (a minimal sketch follows this list).
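
A minimal, hedged sketch of the approach described above (the PR's commits touch dataloaders.py; the helper name matches the summary, but the dataset and seed value below are illustrative stand-ins):

import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Derive a per-worker seed from PyTorch's per-worker initial seed and use it
    # to seed NumPy and Python's random module, so worker-side augmentations
    # become reproducible.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


generator = torch.Generator()
generator.manual_seed(0)  # illustrative fixed seed, not necessarily the PR's value

dataset = TensorDataset(torch.arange(64).float())  # stand-in for the YOLOv5 dataset
loader = DataLoader(dataset,
                    batch_size=16,
                    shuffle=True,
                    num_workers=4,
                    worker_init_fn=seed_worker,
                    generator=generator)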

🎯 Purpose & Impact

  • 🎲 Improve Consistency: Ensure that each worker in a dataloader initializes with a specific seed for better reproducibility.
  • ✨ Controlled Randomness: The changes allow for more controlled randomness during data loading, which can contribute to more consistent results across different runs.
  • πŸ‘©β€πŸ”¬ Research Friendly: Facilitate experimental consistency for researchers and developers, enabling fair comparison of models and training regimes.

@glenn-jocher glenn-jocher linked an issue Jul 17, 2022 that may be closed by this pull request
@glenn-jocher (Member)

@UnglvKitDe I'm seeing identical reproducible results with master and torch>=1.12.0. It doesn't seem we need any further modifications, e.g. running python train.py --epochs 3 in Colab:
[Screenshot, 2022-07-19: Colab training output showing identical results across runs]

@glenn-jocher (Member)

@UnglvKitDe are these DDP-specific requirements for reproducibility?

@UnglvKitDe (Contributor, Author)

@glenn-jocher Hmm, I implemented it in my version partly because it is the recommended practice and I didn't want to take any risks with edge cases, and partly because I once read about a bug of this kind. That bug seems to have been fixed in the meantime, and I could not reproduce it myself, but tutorials (like the one linked here) still describe how the problem can occur, especially with many workers.
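
For context, the failure mode referred to above is the well-known one where fork-started dataloader workers inherit identical NumPy RNG state, so per-sample "random" draws can repeat across workers. A minimal sketch (not from this PR) that can exhibit it on platforms that start workers with fork:

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class NaiveRandomDataset(Dataset):
    # Draws from NumPy's global RNG inside __getitem__, the pattern that triggers the issue.
    def __getitem__(self, index):
        return np.random.randint(0, 1000)

    def __len__(self):
        return 8


# Without a worker_init_fn, fork-started workers may begin from the same NumPy RNG
# state, so the printed batches can contain repeated values across workers.
loader = DataLoader(NaiveRandomDataset(), batch_size=2, num_workers=4)
for batch in loader:
    print(batch)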

@glenn-jocher (Member)

@UnglvKitDe I can verify that I am not seeing reproducible results with DDP trainings, we probably do need this PR merged then.

I'm worried about the same seed on all workers though, I think this may impact augmentation by repeating the same augmentations. Perhaps we should set worker seeds equal to their RANK?
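
A hedged sketch of what rank-aware seeding could look like (RANK is the standard DDP environment variable; the base seed is an arbitrary placeholder, not what this PR or any follow-up uses):

import os

import torch

# Sketch only: offset the DataLoader generator seed by the DDP rank so every
# process draws a distinct shuffling/augmentation stream.
base_seed = 0  # arbitrary placeholder
rank = int(os.getenv('RANK', 0))  # 0 when not running under DDP
generator = torch.Generator()
generator.manual_seed(base_seed + rank)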

@glenn-jocher glenn-jocher merged commit 1c5e92a into ultralytics:master Jul 22, 2022
@glenn-jocher (Member)

@UnglvKitDe PR is merged. Thank you for your contributions to YOLOv5 πŸš€ and Vision AI ⭐

@glenn-jocher (Member)

@UnglvKitDe I tested DDP training but I do not see reproducible results after this PR, unfortunately. It seems something else is missing.

@UnglvKitDe (Contributor, Author)

> @UnglvKitDe I can verify that I am not seeing reproducible results with DDP trainings, we probably do need this PR merged then.
>
> I'm worried about the same seed on all workers though, I think this may impact augmentation by repeating the same augmentations. Perhaps we should set worker seeds equal to their RANK?

@glenn-jocher Each worker has a different worker_seed. Here is a small example:

import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __getitem__(self, index):
        # Value sampled with NumPy's RNG (always 0 with these bounds; the seeds
        # printed by seed_worker are the point of the demo).
        return np.random.randint(0, 1, 1)

    def __len__(self):
        return 16


def seed_worker(worker_id):
    # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
    # torch.initial_seed() already differs per worker process, so the derived
    # NumPy/random seeds differ per worker as well.
    init_seed = torch.initial_seed()
    worker_seed = init_seed % 2**32
    print(os.getpid(), worker_seed, init_seed)
    np.random.seed(worker_seed)
    random.seed(worker_seed)


dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=4,
                        worker_init_fn=seed_worker)
for epoch in range(1):
    print(f"epoch: {epoch}")
    for batch in dataloader:
        print(batch)
    print("-" * 25)
[Screenshot: script output showing a distinct worker_seed per worker process]

@glenn-jocher (Member)

@UnglvKitDe oh perfect, got it! I implemented an additional PR #8688, but I still don't get reproducible DDP results even afterwards. Not sure why, unfortunately.

ctjanuhowski pushed a commit to ctjanuhowski/yolov5 that referenced this pull request Sep 8, 2022
* Add generator and worker seed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataloaders.py

* Update dataloaders.py

* Update dataloaders.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Successfully merging this pull request may close these issues.

Reproducibility in multi-process data loading