Multi-GPU DDP RAM multiple-cache bug #3818
👋 Hello @aprikyan, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we cannot help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements: Python 3.8 or later with all requirements.txt dependencies installed, including $ pip install -r requirements.txt

Environments: YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status: If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), testing (test.py), inference (detect.py) and export (export.py) on macOS, Windows, and Ubuntu every 24 hours and on every commit.
@aprikyan your command looks fine. Caching should only happen once per split (once for train, once for val). I would avoid increasing the swap file to accommodate --cache though; that is probably counterproductive. I would restart your system with no swap and try to reproduce at a smaller --img 320 (uses less RAM) to verify that it is actually trying to cache twice. --cache with DDP is not part of our CI, so this is untested. If the problem persists please let us know.
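A hypothetical version of that test command (the dataset, batch size, and device IDs below are placeholders, not values from this thread):

```
python -m torch.distributed.run --nproc_per_node 2 train.py \
    --data coco128.yaml --weights yolov5s.pt --img 320 --batch 32 \
    --device 0,1 --cache
```

If the bug is present, RAM use should be roughly double that of the same command run on a single GPU, since each DDP process would fill its own full cache.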
👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.

Access additional YOLOv5 🚀 resources:

Access additional Ultralytics ⚡ resources:

Feel free to inform us of any other issues you discover or feature requests that come to mind in the future. Pull Requests (PRs) are also always welcome! Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐!
Hello, I have the same problem. How can I solve it? Thank you.
I have run into the same problem. ┭┮﹏┭┮
@1757525671 @hukaixuan19970627 @aprikyan I'm able to reproduce the multi-caching bug in DDP training. I'm not sure if there's a simple solution, as the COCO training example,
If you guys have any ideas or discover a solution please let us know! |
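A standalone toy (not YOLOv5 code) illustrating the mechanism behind the multiplication: DDP launches one Python process per GPU, and with naive caching each process fills its own private cache with the full image list.

```python
import torch.multiprocessing as mp


def worker(rank, world_size, n_images):
    # Each DDP rank runs in its own process; a naive cache loads ALL images,
    # so total RAM is world_size * (one full cache) instead of 1x.
    cache = [bytearray(100_000) for _ in range(n_images)]  # stand-in for decoded images
    print(f"rank {rank}/{world_size}: cached {len(cache)} images "
          f"(only ~{n_images // world_size} are needed per rank)")


if __name__ == "__main__":
    mp.spawn(worker, args=(4, 1000), nprocs=4)  # 4 processes => 4x the cache RAM
```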
This is to address (and hopefully fix) this issue: Multi-GPU DDP RAM multiple-cache bug ultralytics#3818. This was a very serious and "blocking" issue until I could figure out what was going on. The problem was especially bad when running Multi-GPU jobs with 8 GPUs: RAM usage was 8x higher than expected (!), causing repeated OOM failures. Hopefully this fix will help others.

DDP causes each RANK to launch its own process (one for each GPU) with its own trainloader and its own RAM image cache. The DistributedSampler used by DDP (https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py) feeds only a subset of images (1/WORLD_SIZE) to each available GPU on each epoch, but since the images are shuffled between epochs, each GPU process must still cache all images.

So I created a subclass of DistributedSampler called SmartDistributedSampler that forces each GPU process to always sample the same subset (using modulo arithmetic with RANK and WORLD_SIZE) while still allowing random shuffling between epochs. I don't believe this disrupts the overall "randomness" of the sampling, and I haven't noticed any performance degradation.

Signed-off-by: davidsvaughn <davidsvaughn@gmail.com>
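A minimal sketch of the SmartDistributedSampler idea described above (the real implementation lives in YOLOv5's utils/dataloaders.py and differs in detail; this version just shows the modulo-subset trick):

```python
import torch
from torch.utils.data.distributed import DistributedSampler


class SmartDistributedSampler(DistributedSampler):
    """Each rank always draws from the same fixed 1/WORLD_SIZE subset of the dataset
    (indices congruent to rank modulo num_replicas), so its RAM cache never needs the
    other ranks' images, while still reshuffling within that subset every epoch."""

    def __iter__(self):
        # Deterministic per-epoch shuffle, same seeding scheme as the stock sampler
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        # Fixed per-rank subset via modulo arithmetic: rank, rank + n, rank + 2n, ...
        idx = torch.arange(self.rank, len(self.dataset), self.num_replicas)
        if self.shuffle:
            idx = idx[torch.randperm(len(idx), generator=g)]

        # Pad so every rank yields exactly num_samples indices, keeping DDP in lockstep
        idx = idx.tolist()
        if len(idx) < self.num_samples:
            idx += idx[: self.num_samples - len(idx)]
        return iter(idx[: self.num_samples])
```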
… issue (#10383). The squash-merge combined the following commits:

* Update dataloaders.py — the SmartDistributedSampler fix for the Multi-GPU DDP RAM multiple-cache bug #3818, as described above
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* Update dataloaders.py — move the extra parameter (rank) to the end so it won't disturb pre-existing positional args
* Update dataloaders.py — remove an extra '#'
* Update dataloaders.py — sample from the DDP index array (self.idx) in mixup mosaic
* Merge self.indices and self.idx (the DDP indices) into a single attribute (self.indices); also add SmartDistributedSampler to the segmentation dataloader
* Multiply the GB displayed by WORLD_SIZE

Signed-off-by: davidsvaughn <davidsvaughn@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
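A hypothetical usage sketch showing how such a sampler plugs into a training loop (placeholder dataset and sizes; in YOLOv5 the actual wiring happens inside create_dataloader in utils/dataloaders.py). It assumes the SmartDistributedSampler class from the sketch above:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset standing in for YOLOv5's LoadImagesAndLabels
dataset = TensorDataset(torch.zeros(1000, 3, 64, 64), torch.zeros(1000))
rank, world_size, epochs = 0, 8, 3  # normally read from the RANK/WORLD_SIZE env vars

sampler = SmartDistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

for epoch in range(epochs):
    sampler.set_epoch(epoch)  # reshuffle within this rank's fixed subset only
    for imgs, targets in loader:
        pass  # training step would go here
```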
I'm trying to train on a >100K-image dataset with large images (5000×5000). There's no results.txt in runs/exp, so I'm not sure what information I can provide about the data.

🐛 Bug
As the title says, during training the images are cached again after they are already cached. In this specific case it results in a memory overload, so I'll try with more memory and report the results, but the question is: why does this happen at all?
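For scale, a back-of-envelope estimate of what --cache needs (an approximation assuming images are cached as uint8 RGB arrays at roughly the training resolution; the 640 and 8 below are illustrative values, not from this issue):

```python
# Rough RAM estimate for --cache, not YOLOv5's exact accounting
n_images, h, w, c = 100_000, 640, 640, 3       # >100K images cached at ~640x640 RGB uint8
gb_per_rank = n_images * h * w * c / 1024**3   # ~114 GB in a single process
world_size = 8                                 # with the multi-cache bug, every rank caches everything
print(f"{gb_per_rank:.0f} GB per rank, {gb_per_rank * world_size:.0f} GB total under the bug")
```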
To Reproduce (REQUIRED)
Input:
Output:
Environment