how to use the new api in cross validation #392

Closed
raven4752 opened this issue Jul 14, 2019 · 7 comments

@raven4752

Hi,
As indicated in the documentation, amp.initialize() should be called only once. But in cross-validation, models need to be created for different folds, and creating them all before CV requires extra effort. So what is the suggested way to use amp.initialize() with cross-validation? I tried calling initialize() multiple times, but GPU memory seems to leak.

@ptrblck
Contributor

ptrblck commented Jul 31, 2019

Hi @raven4752,

do you have a code snippet to reproduce the memory leak?

Regarding the CV: would it be possible to reset the model to the random parameters before starting to train it on a particular fold, and thus just use a single model?
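
A minimal sketch of that single-model pattern (this assumes opt_level="O1", so the saved FP32 state_dict can be reloaded directly; get_fold_loader is a hypothetical data helper):

import copy
import torch.nn as nn
from torch.optim import Adam
from apex import amp

model = nn.Linear(128, 128).cuda()
optimizer = Adam(model.parameters())
# called exactly once, before cross-validation starts
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

initial_state = copy.deepcopy(model.state_dict())    # snapshot of the random init

for fold in range(5):
    model.load_state_dict(initial_state)             # reset weights for this fold
    for x, y in get_fold_loader(fold):               # hypothetical fold loader
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x.cuda()), y.cuda())
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()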

@raven4752
Author

Thank you for your reply.
I tried to create a minimal reproducing snippet, running on Ubuntu 16.04 with Python 3.6.8, apex 0.1, torch 1.1.0, CUDA 10.0, and a Titan Xp.

import torch
import torch.nn as nn
from apex import amp
from torch.optim import Adam


def get_gpu_memory_usage(device_id):
    # peak allocated memory in MB
    return round(torch.cuda.max_memory_allocated(device_id) / 1000 / 1000)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(128, 128)

    def forward(self, x):
        return self.fc(x)


for i in range(10):
    # a fresh model/optimizer pair goes through amp.initialize() on every
    # iteration, mimicking a new cross-validation fold
    model = Model()
    model.to('cuda')
    optimizer = Adam(model.parameters())
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2", verbosity=0)
    for j in range(100):
        t = torch.zeros(100000, 128).to('cuda')
        z = model(t).mean()
        # z.backward()  # plain FP32 backward shows no growth
        with amp.scale_loss(z, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print('memory {}'.format(get_gpu_memory_usage(0)))

I can see the reported memory increasing from 154 to 155 MB, while the occupancy stays constant with FP32 training. The leak is much more significant in my real workload.
Regarding the CV, thank you for your suggestion. It is OK to reset the model in each fold, but it requires some extra effort to redesign the training routine; compared with the old API, I need to make more changes.

@ptrblck
Contributor

ptrblck commented Jul 31, 2019

Thanks for the code!
The leak is reproducible. Since the initialization of multiple models is currently not supported, we would ask you to stick to the second approach (i.e. resetting a single model) for now.

@JohnGiorgi

JohnGiorgi commented Aug 15, 2019

Hi, I am having almost the exact same problem. I am using cross-validation and in each fold I call

model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

which was causing memory to accumulate, eventually leading to an OOM. To fix this, I changed my code as suggested above, i.e. I only make this call once (before cross-validation begins) and then, at the end of each fold, reset the model's parameters rather than create a whole new model per fold.

My question is: what should I do with the optimizer? Is there a way to reset its state in PyTorch? If I just re-initialize the optimizer at the end of each fold, it won't have been initialized with amp.initialize().
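
For reference, one possible way to clear an optimizer's per-parameter state in place (so the optimizer object that went through amp.initialize() is kept) is sketched below; whether this interacts cleanly with Apex's patched optimizer is not verified in this thread:

from collections import defaultdict

def reset_optimizer_state(optimizer):
    # torch.optim.Optimizer keeps per-parameter buffers (e.g. Adam's exp_avg,
    # exp_avg_sq and step counters) in optimizer.state; replacing it with an
    # empty defaultdict restarts the optimizer without constructing a new one
    optimizer.state = defaultdict(dict)

# usage at the end of each fold, alongside the model reset:
# reset_optimizer_state(optimizer)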

@raven4752
Author

I tried calling optimizer.load_state_dict() with a fresh optimizer's state, but memory still seems to leak. I switched to the old FP16_Optimizer API and call model.half() manually as a workaround.
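
A rough sketch of that older-API workaround (based on apex.fp16_utils.FP16_Optimizer; treat the exact arguments and training-loop calls as assumptions and check the apex docs for your version; get_fold_loader is a hypothetical data helper):

import torch.nn as nn
from torch.optim import Adam
from apex.fp16_utils import FP16_Optimizer

for fold in range(5):
    model = nn.Linear(128, 128).cuda().half()        # manual cast to FP16
    optimizer = FP16_Optimizer(Adam(model.parameters()),
                               dynamic_loss_scale=True)
    for x, y in get_fold_loader(fold):               # hypothetical fold loader
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x.cuda().half()), y.cuda().half())
        optimizer.backward(loss)                     # wrapper handles loss scaling
        optimizer.step()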

@JohnGiorgi

JohnGiorgi commented Aug 15, 2019

@raven4752 I see, thanks for the workaround.

@ptrblck Is it safe to say that we shouldn't be using the new amp API in a cross-validation setting? Because I can only call amp.initialize() once, I can't re-initialize the model and optimizer in a cross-validation loop, so I need to reset them instead. Resetting parameters can be non-trivial for complicated models, and as far as I know there's no good way to reset the optimizer's state.

@mcarilli
Contributor

mcarilli commented Apr 6, 2020

The recently merged native Amp API should fix this issue and straightforwardly accommodate cross-validation:
https://pytorch.org/docs/master/amp.html
https://pytorch.org/docs/master/notes/amp_examples.html

See #439 (comment) for details.
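
For illustration, a minimal cross-validation loop with the native API (torch.cuda.amp); get_fold_loader is a hypothetical data helper:

import torch
import torch.nn as nn
from torch.optim import Adam

# autocast/GradScaler carry no global state tied to one particular model,
# so a fresh model, optimizer, and scaler can be created for every fold
for fold in range(5):
    model = nn.Linear(128, 128).cuda()
    optimizer = Adam(model.parameters())
    scaler = torch.cuda.amp.GradScaler()

    for x, y in get_fold_loader(fold):               # hypothetical fold loader
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = nn.functional.mse_loss(model(x.cuda()), y.cuda())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()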
