
GPU memory issues (leak?) #439

Open
psinger opened this issue Aug 16, 2019 · 43 comments

Comments

@psinger

psinger commented Aug 16, 2019

I am running a loop where I initialize a new model in each iteration and train it. I am using NVIDIA Apex for mixed precision training. My current issue is that there seem to be unwanted memory allocations that persist across iterations of the loop: GPU memory accumulates, and after a few iterations CUDA runs out of memory.

I have debugged everything, monitored memory, and deleted every single thing possible. Only after removing Apex does memory allocation stay consistent. I am doing nothing other than adding the three lines of code from the tutorial for initialization and the backward pass.

Any ideas?

@ptrblck
Contributor

ptrblck commented Aug 16, 2019

Hi @psinger,

we currently support running amp.initialize only once (docs).
Would it be possible in your use case to reset the model's parameters or initialize the models in a single amp.initialize call?
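One way to follow the "reset the model's parameters" suggestion without a second amp.initialize call is to re-initialize the existing modules in place. The sketch below is plain PyTorch, not Apex API: the helper name reset_all_parameters is hypothetical, and it relies on the reset_parameters() method that built-in layers such as nn.Linear, nn.Conv2d and nn.BatchNorm2d provide. Custom modules without such a method are skipped, optimizer state is not touched, and the sketch has only been reasoned about for the O1 case discussed here.

```python
import torch
import torch.nn as nn


def reset_all_parameters(model: nn.Module) -> int:
    """Hypothetical helper: re-initialize every submodule that defines reset_parameters().

    Returns how many submodules were reset. The parameter tensors are modified in
    place, so the same objects that were passed to the optimizer are reused.
    """
    count = 0
    for module in model.modules():
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()
            count += 1
    return count


model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
with torch.no_grad():
    for p in model.parameters():
        p.zero_()  # pretend a previous run drove the weights somewhere

n_reset = reset_all_parameters(model)
print(n_reset)  # 2: only the two nn.Linear layers define reset_parameters()
```

Unlike reloading a saved state_dict (as in the example further below), this draws fresh random weights, so runs do not start from identical values.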

@psinger
Author

psinger commented Aug 18, 2019

Thanks @ptrblck. I changed my code to run amp.initialize only once and reload the model weights in each iteration. I pass multiple optimizers to the initialization and then use a separate optimizer in each iteration. Unfortunately, I still have memory issues, and after a few iterations of the loop CUDA runs out of memory.

@ptrblck
Contributor

ptrblck commented Aug 18, 2019

Thanks for the update!
I believe the increase in memory might not be related to amp, but is probably caused by storing the internal parameters of the optimizers.
Have a look at this dummy example:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from apex import amp

import copy


use_amp = False
clean_opt = False

device='cuda'

model = models.resnet18()
model.to(device)
state_dict = copy.deepcopy(model.state_dict())
optimizers = [optim.Adam(model.parameters(), lr=1e-3) for _ in range(3)]
if use_amp:
    model, optimizers = amp.initialize(model, optimizers, opt_level='O1')

dataset = datasets.FakeData(transform=transforms.ToTensor())
loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    shuffle=True
)
criterion = nn.CrossEntropyLoss()

print('Memory allocated {:.3f}MB'.format(
    torch.cuda.memory_allocated() / 1024**2))

for opt_idx, optimizer in enumerate(optimizers):
    # reset model
    model.load_state_dict(state_dict)
    
    # Train
    for epoch in range(5):
        for data, target in loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            
        print('OptIdx {}, epoch {}, loss {}, mem allocated {:.3f}MB'.format(
            opt_idx, epoch, loss.item(), torch.cuda.memory_allocated()/1024**2))

    if clean_opt:
        optimizers[opt_idx] = None

If you leave the default settings use_amp = False, clean_opt = False, you will see constant memory usage during training and an increase after switching to the next optimizer.
Setting clean_opt=True will delete the optimizers and thus clean the additional memory.
However, this cleanup doesn't seem to work properly using amp at the moment.

Thanks for reporting! We'll have a look.
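To see why each optimizer that is kept alive costs extra memory, the sketch below sums the tensors stored in optimizer.state. It is plain PyTorch on the CPU with no Apex involved, and the helper name optimizer_state_bytes is our own. After its first step, Adam keeps two moment buffers per parameter, i.e. roughly twice the parameter memory, which matches the per-optimizer growth described above.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def optimizer_state_bytes(optimizer: optim.Optimizer) -> int:
    """Sum the sizes of all tensors held in an optimizer's per-parameter state."""
    total = 0
    for param_state in optimizer.state.values():
        for value in param_state.values():
            if torch.is_tensor(value):
                total += value.numel() * value.element_size()
    return total


model = nn.Linear(1024, 1024)
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
optimizer = optim.Adam(model.parameters(), lr=1e-3)

initial_bytes = optimizer_state_bytes(optimizer)  # Adam creates its state lazily
model(torch.randn(8, 1024)).sum().backward()
optimizer.step()
state_bytes = optimizer_state_bytes(optimizer)

# exp_avg and exp_avg_sq are each as large as the parameters, so the ratio is about 2
print(initial_bytes, state_bytes / param_bytes)
```

Dropping the last reference to such an optimizer (which is what clean_opt=True does) is what allows this memory to be reclaimed.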

@psinger
Author

psinger commented Aug 19, 2019

Thanks again @ptrblck. I am actually doing what you suggest, namely deleting the optimizers at the end, but as you correctly state, this does not seem to work with amp (and it might not work for other variables either). I hope you can take a look.

@BramVanroy
Contributor

BramVanroy commented Oct 8, 2019

This is an important bug for me as well. I am testing multiple hyperparameter settings sequentially, which means that in each iteration I set up everything from scratch: the model and optimizer are recreated, and amp.initialize is called again. After a model has been trained, I try to clear caches by running

torch.cuda.empty_cache()
gc.collect()

Unfortunately to no avail. After some iterations, I run into a CUDA OOM error when using AMP. Any form of parameter tuning doesn't seem possible, then. (For reference: I could test 5 parameter settings when fine-tuning RoBERTa before I got an OOM on an RTX 2080 Ti.)

I was running the same process (but fine-tuning another model) on 2x V100s. That process is still running (they have 16GB of memory, whereas the RTX 2080 Ti only has 11GB), but looking at its memory usage via nvidia-smi I can see that the main device is using much more memory than the secondary GPU. Perhaps that's useful information:

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     19000      C   python                                     12379MiB |
|    1     19000      C   python                                      7021MiB |
+-----------------------------------------------------------------------------+
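In the snippet above, torch.cuda.empty_cache() runs before gc.collect(). empty_cache() can only release cached blocks that no live tensor occupies, so it generally helps to drop references and collect garbage first. The sketch below shows that order; the helper name free_cuda_memory is hypothetical. Note that no cleanup order can release memory that some library still references, which is what this thread suspects is happening with Apex.

```python
import gc

import torch


def free_cuda_memory() -> int:
    """Collect garbage first, then release unused cached CUDA blocks.

    Returns the bytes still allocated by live tensors (0 when no GPU is present).
    """
    gc.collect()  # frees tensors that were only kept alive by reference cycles
    if not torch.cuda.is_available():
        return 0
    torch.cuda.empty_cache()  # hands now-unused cached blocks back to the driver
    return torch.cuda.memory_allocated()


model = torch.nn.Linear(1024, 1024)
if torch.cuda.is_available():
    model = model.cuda()
del model  # drop our own reference before cleaning up
print(free_cuda_memory())
```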

@ptrblck
Contributor

ptrblck commented Oct 8, 2019

@BramVanroy would it be possible to reuse the model and optimizer and just reinitialize them?
If not, could you post a (minimal) code snippet to reproduce this issue?

@BramVanroy
Contributor

BramVanroy commented Oct 9, 2019

@ptrblck Thank you for your time.

Reusing is not possible in my case, since depending on the parameters the architecture can change (e.g. no intermediate linear layer). I distilled the scripts that I use into a working example that only uses DistilBERT, and included test data of 2000 sentences (train/dev). You can find the repo here. Documentation is bad, but the only thing you need is gh_predict.py as the entry point. Requirements are Python 3.6+, torch 1.1+, transformers, and tqdm.

Training shouldn't take too long (dataset is quite small). The script will train on four different parameters. You can see the difference between just running python gh_predict.py and python gh_predict.py --fp16 O1.

Initialising amp happens here

https://github.com/BramVanroy/apex_mem_issue/blob/2b5513bf01cf6f123f5d9da78daf95179aac500f/gh_trainer.py#L79-L87

Below you can find the nvidia-smi results when each new parameter test started, first without using AMP, and then when using AMP.

Without AMP

Parameter setting #1 
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      9681      C   python                                      4903MiB |
|    1      9681      C   python                                      3971MiB |
|    2      9681      C   python                                      3971MiB |
|    3      9681      C   python                                      3971MiB |
+-----------------------------------------------------------------------------+

Parameter setting #2
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      9681      C   python                                      4911MiB |
|    1      9681      C   python                                      3973MiB |
|    2      9681      C   python                                      3973MiB |
|    3      9681      C   python                                      3973MiB |
+-----------------------------------------------------------------------------+

Parameter setting #3
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      9681      C   python                                      4949MiB |
|    1      9681      C   python                                      3979MiB |
|    2      9681      C   python                                      3979MiB |
|    3      9681      C   python                                      3979MiB |
+-----------------------------------------------------------------------------+

Parameter setting #4
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      9681      C   python                                      5001MiB |
|    1      9681      C   python                                      3991MiB |
|    2      9681      C   python                                      3991MiB |
|    3      9681      C   python                                      3991MiB |
+-----------------------------------------------------------------------------+

With AMP

Parameter setting #1
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     13185      C   python                                      4235MiB |
|    1     13185      C   python                                      3237MiB |
|    2     13185      C   python                                      3237MiB |
|    3     13185      C   python                                      3237MiB |
+-----------------------------------------------------------------------------+

Parameter setting #2
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     13185      C   python                                      4589MiB |
|    1     13185      C   python                                      3245MiB |
|    2     13185      C   python                                      3245MiB |
|    3     13185      C   python                                      3245MiB |
+-----------------------------------------------------------------------------+

Parameter setting #3
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     13185      C   python                                      5039MiB |
|    1     13185      C   python                                      3251MiB |
|    2     13185      C   python                                      3251MiB |
|    3     13185      C   python                                      3251MiB |
+-----------------------------------------------------------------------------+

Parameter setting #4
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     13185      C   python                                      5603MiB |
|    1     13185      C   python                                      3309MiB |
|    2     13185      C   python                                      3309MiB |
|    3     13185      C   python                                      3309MiB |
+-----------------------------------------------------------------------------+

The important part is not the absolute numbers, but the relative increase in memory usage on GPU:0 at each new parameter setting. With AMP it increases by roughly 10% every time, whereas without AMP it increases by only about 1% or less.
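The per-setting growth can be checked directly from the GPU:0 rows of the two sets of nvidia-smi tables above. The numbers are copied from those tables; the helper name relative_increases is our own.

```python
# GPU:0 usage (MiB) at the start of parameter settings #1 to #4, copied from the tables above
without_amp = [4903, 4911, 4949, 5001]
with_amp = [4235, 4589, 5039, 5603]


def relative_increases(usage):
    """Percentage growth from each parameter setting to the next."""
    return [100.0 * (b - a) / a for a, b in zip(usage, usage[1:])]


for label, usage in (("without AMP", without_amp), ("with AMP", with_amp)):
    steps = ", ".join(f"{pct:.1f}%" for pct in relative_increases(usage))
    print(f"{label}: {steps}")
# without AMP: 0.2%, 0.8%, 1.1%
# with AMP: 8.4%, 9.8%, 11.2%
```

So the AMP run grows by roughly 8 to 11% per setting, while the run without AMP stays at about 1% or less.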

@psinger
Author

psinger commented Oct 9, 2019

@ptrblck Based on my findings and your example above, reusing the model does not help, because deleting e.g. the optimizer does not release its memory.

@BramVanroy
Contributor

I've been casually reading through the source code, and I was wondering how hard it would be to implement a 'destroy' method that frees all AMP objects, ready for garbage collection. It seems that most things go through the AmpHandle, _amp_state, and master_params. Is there a way to clear those out?

@hatzel

hatzel commented Oct 25, 2019

I ran into this issue as well; sadly, it results in memory leaks that are pretty hard to debug. I'd suggest showing a warning when initialize is called twice. I'd also love to see @BramVanroy's idea implemented, but I have no idea how hard it would actually be.

In my case I can actually reuse the model, but it's not convenient. I run hyperparameter optimization, and the whole model initialization is encapsulated away from the main optimization loop, so passing around the model and optimizer states is cumbersome.

@yangyiben

This issue seriously affects me.

@mcarilli
Contributor

Are you using O1 or O2 when you see these memory leak issues? O1 does not create any additional parameters, so it should not leak memory.
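The distinction matters because O2 keeps FP32 master copies of the parameters in addition to the FP16 model weights, whereas O1 leaves the model's parameters as they are. As a back-of-the-envelope sketch (plain PyTorch, an illustration of the idea rather than Apex's actual code), an extra FP32 master copy costs twice the FP16 parameter memory:

```python
import torch.nn as nn


def tensor_bytes(tensors):
    """Total storage of an iterable of tensors, in bytes."""
    return sum(t.numel() * t.element_size() for t in tensors)


model = nn.Linear(1024, 1024).half()  # FP16 model weights, as in an O2-style setup

# Illustration only: one FP32 "master" copy per parameter, similar in spirit to
# what O2 keeps for the optimizer
master_params = [p.detach().clone().float() for p in model.parameters()]

fp16_bytes = tensor_bytes(model.parameters())
master_bytes = tensor_bytes(master_params)
print(fp16_bytes, master_bytes)  # 2099200 4198400
```

If such copies outlive their run, the leak would scale with model size, which is why the O1 versus O2 question is relevant here.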

@yangyiben

> Are you using O1 or O2 when you see these memory leak issues? O1 does not create any additional parameters, so it should not leak memory.

I am using O1; the memory usage keeps going up whenever I recreate the model and optimizer.

@yangyiben

yangyiben commented Oct 25, 2019

> Are you using O1 or O2 when you see these memory leak issues? O1 does not create any additional parameters, so it should not leak memory.

Also, if I turn off fp16 mode, there is no memory issue, so it must be Apex's problem.

@mcarilli
Contributor

Is the memory consumption going up on the CPU, GPU, or both?

@yangyiben

> Is the memory consumption going up on the CPU, GPU, or both?

I am not monitoring the CPU. On the GPU, it is about 10% more memory each time I recreate the model and optimizer.

@psinger
Author

psinger commented Oct 25, 2019

This bug actually hinders me from using Apex most of the time. I am constantly running multiple models in single notebooks / scripts.

@hatzel

hatzel commented Oct 26, 2019

Just in case you have trouble reproducing this, this is what I used for debugging:

from collections import Counter
import gc
import torch
from apex import amp
import torch.nn as nn

cached = []
allocated = []


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        x = self.linear(x)
        return x


def get_tensor_sizes():
    ctr = Counter()
    gc.collect()
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                ctr.update([tuple(obj.size())])
        except (OSError, RuntimeError):
            pass  # some objects raise when inspected; skip them
    return ctr


for _ in range(1000):
    gc.collect()
    print(get_tensor_sizes())
    model = Net().to("cuda")
    optimizer = torch.optim.SGD(model.parameters(), 0.1)
    model(torch.rand(1024).to("cuda"))
    amp.initialize(
        model,
        optimizer,
        opt_level="O1",
    )
    allocated.append(torch.cuda.memory_allocated())
    cached.append(torch.cuda.memory_cached())
    print(f"Cached: {cached}")
    print(f"Allocated: {allocated}")

As you will see, CUDA memory usage rises over time, even though the model clearly should be deleted (which the gc.collect call should enforce). With a sufficiently large model you will eventually run into a CUDA out-of-memory error.

I am not 100% sure whether this example could be minimized further, but the model() call seems to be required.
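A complementary check, sketched below, asks whether the model object itself survives: a weak reference does not keep the model alive, so if it still resolves after del and gc.collect(), something else is holding on to the model. The _registry list is a hypothetical stand-in for such library-level state; to probe Apex, one could call amp.initialize inside the factory function instead of appending to it.

```python
import gc
import weakref

import torch.nn as nn

_registry = []  # hypothetical stand-in for global state that keeps a reference


def is_collected(make_model, keep_reference=False):
    """Build a model, drop our reference, and report whether it was garbage-collected."""
    model = make_model()
    probe = weakref.ref(model)  # does not keep the model alive
    if keep_reference:
        _registry.append(model)  # simulate a library holding on to the model
    del model
    gc.collect()
    return probe() is None


print(is_collected(lambda: nn.Linear(1024, 1024)))                       # True
print(is_collected(lambda: nn.Linear(1024, 1024), keep_reference=True))  # False
```

Unlike counting tensors by size, this pinpoints whether the Python object is leaked, and gc.get_referrers() can then be used on the surviving object to find out who holds it.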

@BramVanroy
Contributor

> Are you using O1 or O2 when you see these memory leak issues? O1 does not create any additional parameters, so it should not leak memory.

Please see my previous comment #439 (comment), where I provide a repo to test. The numbers given there are the memory usage increases with O1.

hatzel added a commit to hatzel/neural-spoiler-detection that referenced this issue Nov 2, 2019
As documented here (and in the official documentation)
NVIDIA/apex#439 we shouldn't call
apex.initialize twice. To avoid this we retain the original model,
loading state dicts of new optimizers and models for each run.
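The approach in this commit (initialize once, then reload state dicts for every run) can be sketched in plain PyTorch. The toy regression task and the name run_trial are illustrative assumptions; the only Apex-specific point, noted as a comment, is that amp.initialize would be called a single time right after the model and optimizer are created.

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(16, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# With Apex, amp.initialize(model, optimizer, ...) would be called exactly once, here.

# Snapshot the freshly initialized states once, before any training.
initial_model_state = copy.deepcopy(model.state_dict())
initial_optimizer_state = copy.deepcopy(optimizer.state_dict())


def run_trial(lr, steps=5):
    """Hypothetical trial: reset model and optimizer in place, then train briefly."""
    model.load_state_dict(initial_model_state)
    optimizer.load_state_dict(initial_optimizer_state)  # also discards Adam's moment buffers
    for group in optimizer.param_groups:
        group["lr"] = lr  # apply this trial's hyperparameter
    for _ in range(steps):
        x, y = torch.randn(32, 16), torch.randn(32, 1)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()


results = {lr: run_trial(lr) for lr in (1e-3, 1e-2, 1e-1)}
print(results)
```

As hatzel notes, this only works when every trial shares one architecture; trials that change the model's structure still need a fresh model.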
nalourie-ai2 added a commit to allenai/scruples that referenced this issue Nov 14, 2019
apex creates a memory leak (see NVIDIA/apex#439)
where models cannot be garbage collected. This makes apex very
difficult to use at the moment during hyper-parameter tuning.
@MrHuff

MrHuff commented Nov 26, 2019

Hi!

Bumping this post: I am experiencing exactly the same issue with O1 mode. I use the hyperopt library and make sequential calls that initialize new models with different hyperparameters, which are then wrapped in amp.initialize. GPU memory grows by roughly 10% per iteration.

Thank you for your help!
Best regards,
Robert

@sooperset

I also have this problem with 'O1' and 'O2' modes.

@BramVanroy
Contributor

I think @MrHuff's scenario is common: hyperparameter optimization requires multiple initializations. I expect that most of the effort on Apex now goes into integrating it upstream, so it would be great if, by default, the upstream amp initialization were NOT indestructible.

@psinger
Author

psinger commented Jan 28, 2020

This is still a major bug. Tons of people complaining about the same.

@BramVanroy
Contributor

> This is still a major bug. Tons of people complaining about the same.

It's not a bug per se, I think. AMP was simply never intended to be initialised more than once per session. That being said, it would be nice to see an improvement on this aspect when it gets upstream support.

@avostryakov

Just for the record: we hit the same memory-leak problem with Apex. Without Apex, the same code runs without leaks, as everyone above has said.

@uint64t

uint64t commented Mar 12, 2020

For the record as well: I hit the same problem when I use AMP in the training steps, and there seems to be a memory leak when I run evaluation.

@jgsch

jgsch commented Mar 23, 2020

Same issue using Ray Tune for hyperparameters optimization.

@Moradnejad

I have the same issue on Kaggle when testing a few models. Can someone tell us the best practice?

@BramVanroy
Contributor

AMP is now available in PyTorch's nightlies. The brave experimentalists can try it out now. Documentation looks good already!

https://pytorch.org/docs/master/notes/amp_examples.html#amp-examples

@psinger
Author

psinger commented Apr 3, 2020

> AMP is now available in PyTorch's nightlies. The brave experimentalists can try it out now. Documentation looks good already!
>
> https://pytorch.org/docs/master/notes/amp_examples.html#amp-examples

If someone testing it could check if it fixes this issue I would be grateful.

@BramVanroy
Contributor

It's probably a bit too early to make exhaustive comments on the functionality: it has only just been introduced as part of a nightly version.

@psinger
Author

psinger commented Apr 6, 2020

> It's probably a bit too early for that to make exhaustive comments on the functionality: the functionality has only just been introduced as part of a nightly version.

I have been testing it a bit, and it looks way better. Could not replicate the memory issues yet. As you say, too early to say, but looks promising.

@mcarilli
Contributor

mcarilli commented Apr 6, 2020

Glad you're trying the upstream API!! I anticipate it'll solve any memory leak issues associated with amp.initialize, since that's not a thing anymore. (NB: Right now topk is broken with FP16 on master. I'd give it a week or two to be sure the fix goes in.)

One thing that may not be obvious from the upstream docs:
If you're performing multiple convergence runs in the same script, you should use a new GradScaler instance for each run. GradScaler is a lightweight, self-contained object, so you can construct a new one anytime with the usual

scaler = torch.cuda.amp.GradScaler() # replaces the old scaler instance

I could add a scaler.reset() method to serve the same purpose, but in the meantime, the above should work.

It's also permissible to have multiple GradScalers constructed at once, as long as

  • each convergence run gets a fresh scaler, and uses only that scaler for all epochs
  • each scaler is used for only one convergence run

(in other words, there is a 1-1 correspondence between convergence runs and GradScaler instances). Their memory use is small, but not zero: if you create new GradScalers over time, just don’t keep references to dozens or hundreds of stale ones after you're done using them.
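The one-scaler-per-run rule above can be sketched with the native API as follows. This is a hedged example written against current PyTorch (torch.autocast exists since 1.10; newer releases also expose the scaler as torch.amp.GradScaler("cuda")). The name train_one_run and the toy regression task are our own assumptions, and on a machine without a GPU both autocast and the scaler are simply disabled.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # mixed precision is only enabled on the GPU in this sketch


def train_one_run(lr, steps=10):
    """One convergence run with its own model, optimizer and GradScaler."""
    model = nn.Linear(64, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # fresh scaler for this run only
    for _ in range(steps):
        x = torch.randn(32, 64, device=device)
        y = torch.randn(32, 1, device=device)
        optimizer.zero_grad()
        with torch.autocast(device_type=device, enabled=use_amp):
            loss = nn.functional.mse_loss(model(x), y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return loss.item()


# Each run builds and discards its own scaler, so no stale scalers accumulate.
losses = [train_one_run(lr) for lr in (0.1, 0.01)]
print(losses)
```

Because the scaler is a local variable, it becomes unreachable together with the model and optimizer when the function returns.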

@BramVanroy
Contributor

Hey @mcarilli thanks for the reply and your work on integrating AMP with upstream!

A reset method seems useful but (to me) it seems that this would only be useful if it is faster than creating a new instance from scratch. Otherwise recreating the object is just as easy.

Finally, the docs might benefit from a pros/cons comparison. Discussing the cons (if any) might be useful, because people might wonder what the downside is if AMP is not enabled by default.

Thanks again

@mcarilli
Contributor

mcarilli commented Apr 6, 2020

> A reset method seems useful but (to me) it seems that this would only be useful if it is faster than creating a new instance from scratch.

It's so lightweight (a few ints/floats/dicts and a couple of one-element tensors) that it shouldn't make a difference either way.

> Finally, the docs might benefit from a pros/cons comparison.

which docs? native docs? do you mean pros and cons of Apex Amp vs native amp? I'd rather not mention Apex at all in the native docs, I want native docs to be self-contained.

@BramVanroy
Contributor

After thinking about it a bit more, I think a reset method would be useful after all if you initialised the scaler outside the loop.

Yes, the native docs. But I meant a comparison of using native AMP vs. not using AMP at all, so a bit more general. This can be one simple line, but it would help beginners understand the benefits of AMP and its downsides, if any. The latter matters because one may wonder: if there are no downsides, why is it not enabled by default?

@psinger
Author

psinger commented Apr 7, 2020

@mcarilli Thanks a lot for your contribution! This is a lifesaver for many!

I think reinitializing the GradScaler should not be a problem. What I am wondering about is when to properly call scaler.update() in the case of gradient accumulation: after each scale() call or after each step() call? I would expect the second one to be correct. Thanks!

@mcarilli
Copy link
Contributor

mcarilli commented Apr 7, 2020

@BramVanroy

> I think a reset method would be useful after all if you initialised the scaler outside the loop.

Why would overwriting the local reference named scaler not be sufficient? Are you imagining a case where the same scaler instance is used/passed to many local scopes, and saying

scaler = torch.cuda.amp.GradScaler() # replaces scaler in current scope

would only overwrite it in that scope, while scaler.reset() would affect scaler globally, in all scopes that held a reference to it?

@psinger

> What I am wondering about is when to properly call scaler.update() in case of gradient accumulation? After each scale or after each step call? I would expect the second one to be correct

You're right, the second, you should only call update() on iterations where you actually step()ed. I'll add gradient accumulation to the examples.
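Gradient accumulation under this rule can be sketched as below: backward() runs on every micro-batch, while scaler.step() and scaler.update() run only on the iterations that actually step. The accumulation factor, the toy model and the random data are assumptions for illustration; without a GPU, autocast and the scaler are disabled and the loop degrades to ordinary FP32 accumulation.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
accumulation_steps = 4  # hypothetical: number of micro-batches per optimizer step

model = nn.Linear(64, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

num_optimizer_steps = 0
for i in range(12):  # 12 micro-batches -> 3 optimizer steps
    x = torch.randn(16, 64, device=device)
    y = torch.randn(16, 1, device=device)
    with torch.autocast(device_type=device, enabled=use_amp):
        # divide so the accumulated gradient averages over the micro-batches
        loss = nn.functional.mse_loss(model(x), y) / accumulation_steps
    scaler.scale(loss).backward()  # gradients accumulate across micro-batches

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)  # may skip the step if inf/NaN gradients were found
        scaler.update()         # only on iterations that actually step
        optimizer.zero_grad()
        num_optimizer_steps += 1

print(num_optimizer_steps)  # 3
```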

@BramVanroy
Contributor

I think that it is cleaner. For instance, assume that you have a Trainer class with a self.scaler property that you set in the constructor, then using a reset in a method is a lot cleaner than reassigning imo.

import torch


class Trainer:
    def __init__(self):
        self.scaler = torch.cuda.amp.GradScaler()

    def my_loop(self, configs):
        for config in configs:
            # train models with different configurations
            self.scaler.reset()  # proposed method; GradScaler does not currently have reset()
            # seems a lot cleaner than here: self.scaler = torch.cuda.amp.GradScaler()
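GradScaler currently has no reset() method. If an in-place reset is preferred over reassignment, a hedged alternative is to snapshot the documented state_dict() right after construction and restore it with load_state_dict() before each configuration. The reset_scaler helper is hypothetical, and we have only reasoned that this restores the scale and growth tracker, not verified every internal field of the scaler.

```python
import torch


class Trainer:
    def __init__(self):
        self.scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
        # Snapshot of the freshly constructed scaler (an empty dict if the scaler is disabled)
        self._initial_scaler_state = self.scaler.state_dict()

    def reset_scaler(self):
        """Hypothetical helper: restore the initial scale in place instead of reassigning.

        Every object holding a reference to self.scaler keeps using the same instance.
        """
        self.scaler.load_state_dict(self._initial_scaler_state)

    def my_loop(self, configs):
        for config in configs:
            self.reset_scaler()
            # ... train a model with this configuration ...
```

Reassigning a new GradScaler remains the simpler option when nothing else holds a reference to the old scaler.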

facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Apr 17, 2020
Summary:
Several people have asked me about proper Amp usage with gradient accumulation.  In particular, it's [unclear to people](NVIDIA/apex#439 (comment)) that you should only call `scaler.unscale_()` (if desired) and `scaler.update()` in iterations where you actually plan to step.  This PR adds a minimal accumulation example.

I built the docs locally and it looks free from sphinx errors, at least.
Pull Request resolved: #36601

Differential Revision: D21082295

Pulled By: ngimel

fbshipit-source-id: b2faa6c02b9f7e1972618a0f1d5360a03f0450ac
@songhuiming

Same here. The original code runs successfully in fp32, but fails with Apex mixed precision because of insufficient GPU memory.

@Daemon-ser

> Same here. Run successfully with fp32 in the original code, but failed with apex mixed precision because of gpu memory insufficient.

Hello, I got the same problem. But I found that GradScaler() has no reset() method, which the following comment relies on:

> I think that it is cleaner. For instance, assume that you have a Trainer class with a self.scaler property that you set in the constructor, then using a reset in a method is a lot cleaner than reassigning imo.


@zzz123xyz

So what is the memory-leak solution for the original amp? Should amp.initialize() go inside the training loop, initializing the model and optimizer in each training iteration? Currently I call amp.initialize() only once at the beginning, right after the model is constructed, and I get the out-of-memory problem as well. (By the way, I want to train on multiple GPUs with memory spread evenly across them, because the model is bigger than the memory of a single card, but I could not make it work because of the out-of-memory problem.)

@Daemon-ser

Daemon-ser commented May 17, 2021 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests