GPU Memory Leak on Loading Pre-Trained Checkpoint #6515

Closed
2 tasks done
bilzard opened this issue Feb 2, 2022 · 8 comments · Fixed by #6516
Labels
bug Something isn't working

Comments

@bilzard
Contributor

bilzard commented Feb 2, 2022

Search before asking

  • I have searched the YOLOv5 issues and found no similar bug report.

YOLOv5 Component

Training

Bug

Training YOLO from a checkpoint (*.pt) consumes more GPU memory than training from pre-trained weights (i.e. yolov5l).

Environment

  • YOLO: YOLOv5 (latest; how to check the yolo version?)
  • CUDA: 11.6 (Tesla T4, 15360MiB)
  • OS: Ubuntu 18.04.6 LTS (Bionic Beaver)
  • Python: 3.8.12

Minimal Reproducible Example

In the training commands below, case 2 requires more GPU memory than case 1.

# 1. train from pre-trained model
train.py ... --weights yolov5l

# 2. train from pre-trained checkpoint
train.py ... --weights pre_trained_checkpoint.pt

Additional

As reported on the PyTorch forum[1], loading a state dict directly onto a CUDA device causes a memory leak. We should load it into CPU memory instead:

state_dict = torch.load(directory, map_location=lambda storage, loc: storage)

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!
@bilzard bilzard added the bug Something isn't working label Feb 2, 2022
@glenn-jocher
Member

@bilzard thanks for the PR! Would it make more sense (less code, or easier to understand) to just load directly on CPU with one of these other options? e.g.:

state_dict = torch.load(directory, map_location=lambda storage, loc: storage)  # option 1
state_dict = torch.load(directory)  # option 2
state_dict = torch.load(directory, map_location=torch.device('cpu'))  # option 3
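To see what each option does in practice, here is a minimal sketch that saves a small state dict and reports which device each loading option puts the tensors on. The helper name `loaded_devices` and the toy `nn.Linear` checkpoint are assumptions for illustration only. On a machine with a GPU, options 1 and 3 should report `cpu` while option 2 reports `cuda`, because without `map_location` the tensors are restored to the device they were saved from.

```python
import os
import tempfile

import torch
import torch.nn as nn


def loaded_devices(path: str) -> dict:
    """Hypothetical helper: device types of the tensors produced by each loading option."""
    options = {
        # option 1: the callable receives the storage already deserialized on CPU;
        # returning it unchanged keeps it on CPU
        "option 1": dict(map_location=lambda storage, loc: storage),
        # option 2: no map_location, so tensors go back to the device they were saved from
        "option 2": dict(),
        # option 3: explicit CPU device
        "option 3": dict(map_location=torch.device("cpu")),
    }
    result = {}
    for name, kwargs in options.items():
        state_dict = torch.load(path, **kwargs)
        result[name] = {t.device.type for t in state_dict.values()}
    return result


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    # Save from the GPU when one is available, mimicking a checkpoint written during training
    torch.save(nn.Linear(4, 2).to(device).state_dict(), path)
    print(loaded_devices(path))
```

On a CPU-only machine all three options report `cpu`, so the difference in option 2 is only visible when the checkpoint was saved from a GPU.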

@bilzard
Contributor Author

bilzard commented Feb 4, 2022

@glenn-jocher Option 2 shouldn't work because, by default, tensors saved from the GPU are loaded back onto the GPU.
Option 3 seems OK according to the official documentation[1]. Let me check whether it works.

When you call torch.load() on a file which contains GPU tensors, those tensors will be loaded to GPU by default. You can call torch.load(.., map_location='cpu') and then load_state_dict() to avoid GPU RAM surge when loading a model checkpoint.
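The pattern described in the documentation (load on CPU, then `load_state_dict()`) can be sketched as below. This is a minimal illustration, not the code from the PR: it assumes a plain `state_dict` checkpoint, and `load_model_from_checkpoint` is a hypothetical helper name. YOLOv5's own `.pt` checkpoints may store more than a bare state dict, so the actual change may look different.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_model_from_checkpoint(path: str, model: nn.Module, device: torch.device) -> nn.Module:
    """Hypothetical helper: load a plain state_dict checkpoint without touching GPU memory first."""
    # Deserialize every tensor into host (CPU) memory, regardless of where it was saved.
    state_dict = torch.load(path, map_location="cpu")
    # Copy the CPU tensors into the model's existing parameters and buffers.
    model.load_state_dict(state_dict)
    # Move the model to the target device after loading; the CPU state_dict can then be freed.
    return model.to(device)
```

If the model already lives on the GPU, `load_state_dict` copies the CPU tensors into the existing GPU parameters in place, so no second GPU copy of the weights is created.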

@bilzard
Contributor Author

bilzard commented Feb 4, 2022

I confirmed that option 3 works on my server (GPU memory usage did not increase).

$ watch nvidia-smi # map_location='cpu', --weights='yolov5l'
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.39.01    Driver Version: 510.39.01    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:00:1E.0 Off |                    0 |
| N/A   61C    P0    41W /  70W |  14603MiB / 15360MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      4509      C   python                          14599MiB |
+-----------------------------------------------------------------------------+

$ watch nvidia-smi # map_location='cpu', --weights=path_to_pretrained.pt
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.39.01    Driver Version: 510.39.01    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:00:1E.0 Off |                    0 |
| N/A   35C    P0    71W /  70W |  14541MiB / 15360MiB |     91%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2669      C   python                          14537MiB |
+-----------------------------------------------------------------------------+
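To compare the two runs without reading the tables by eye, the per-process figure can be pulled out of the `Processes:` rows. This is a hypothetical helper written only against the row layout shown in the snapshots above; for scripting, `nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits` gives a machine-readable value instead.

```python
import re

# Matches process rows like "|    0   N/A  N/A      4509      C   python      14599MiB |"
# Groups: GPU index, PID, process name, memory in MiB.
PROCESS_ROW = re.compile(r"^\|\s+(\d+)\s+\S+\s+\S+\s+(\d+)\s+\S+\s+(.+?)\s+(\d+)MiB\s+\|$")


def process_memory_mib(nvidia_smi_output: str) -> dict:
    """Map PID -> GPU memory (MiB) from the 'Processes:' table of nvidia-smi output."""
    usage = {}
    for line in nvidia_smi_output.splitlines():
        match = PROCESS_ROW.match(line.strip())
        if match:
            usage[int(match.group(2))] = int(match.group(4))
    return usage
```

Applied to the two snapshots above, it gives 14599 MiB for PID 4509 (`--weights='yolov5l'`) and 14537 MiB for PID 2669 (the checkpoint run), i.e. with `map_location='cpu'` the checkpoint run no longer uses more memory than the baseline.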

@bilzard
Contributor Author

bilzard commented Feb 4, 2022

@glenn-jocher

I think we should also fix the model-loading code for PyTorch Hub, but I don't know how to test it.
What should I do about that?

https://github.com/ultralytics/yolov5/blob/master/hubconf.py#L52

@bilzard
Contributor Author

bilzard commented Feb 4, 2022

FYI: I updated the code to use option 3.

@glenn-jocher
Member

@glenn-jocher

I think we should also fix code for loading model from torch hub, but I don't know how to test. What should I do for that?

https://github.com/ultralytics/yolov5/blob/master/hubconf.py#L52

Good question. For PyTorch Hub we may want to leave it as is to keep startup fast. Since PyTorch Hub models may be used in APIs like https://ultralytics.com/yolov5 that are only called once, response time may be more important than slightly reducing CUDA memory usage.

Another point is that simple inference uses much less CUDA memory than training, maybe only about 1/3 or 1/2 of training memory. But I'm not sure either; it would need some study.

@bilzard
Contributor Author

bilzard commented Feb 5, 2022

OK. We would need to study the response time of the changed loading method first, so I will leave it as is for now.

@glenn-jocher
Member

@bilzard yes, that's correct. For training, an extra second of initialization won't matter, nor will it in val.py, but for detect.py and PyTorch Hub we probably want to prioritize the fastest time to first results.
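The trade-off discussed here (startup time versus GPU memory) could be measured with a small timing harness like the one below. It is only a sketch: `time_to_first_result` is a hypothetical helper, the toy `nn.Linear` model stands in for a real YOLOv5 model, and the timings depend entirely on the hardware, so no results are claimed here.

```python
import os
import tempfile
import time

import torch
import torch.nn as nn


def time_to_first_result(path: str, map_location, device: torch.device) -> float:
    """Hypothetical helper: seconds from loading the checkpoint until the first forward pass finishes."""
    start = time.perf_counter()
    state_dict = torch.load(path, map_location=map_location)
    model = nn.Linear(512, 512)  # stand-in for the real model architecture
    model.load_state_dict(state_dict)
    model.to(device).eval()
    with torch.no_grad():
        model(torch.zeros(1, 512, device=device))
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return time.perf_counter() - start


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    torch.save(nn.Linear(512, 512).to(device).state_dict(), path)
    # Warm-up run absorbs one-time CUDA/context initialization so it isn't charged to either option
    time_to_first_result(path, "cpu", device)
    for label, loc in [("default (saved device)", None), ("map_location='cpu'", "cpu")]:
        print(f"{label}: {time_to_first_result(path, loc, device):.4f} s")
```

A real comparison would repeat each measurement several times on the actual model and checkpoint before drawing conclusions.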


2 participants