Why does FusedOptimizer have a huge impact on model precision? #335
Comments
Thanks for opening the issue. FusedOptimizer is expected to give the same result as the original one for any torch optimizer. We'll investigate this case.
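For reference, the expected equivalence can be checked directly by stepping two identical models, one with the plain optimizer and one wrapped in FusedOptimizer, and comparing the parameters afterwards (a minimal sketch; the model, data, and learning rate here are arbitrary):

import copy
import torch
import torch.nn as nn
from bagua.torch_api.contrib import FusedOptimizer

# Minimal sketch: run the same training steps with a plain Adam optimizer and
# with the FusedOptimizer wrapper, then compare the resulting parameters.
torch.manual_seed(0)
model_a = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 1)).to(0)
model_b = copy.deepcopy(model_a)

opt_a = torch.optim.Adam(model_a.parameters(), lr=0.1)
opt_b = FusedOptimizer(torch.optim.Adam(model_b.parameters(), lr=0.1), do_flatten=True)

x = torch.randn(8, 10, device=0)
y = torch.randn(8, 1, device=0)

for _ in range(10):
    for model, opt in ((model_a, opt_a), (model_b, opt_b)):
        opt.zero_grad()
        (model(x) - y).pow(2).sum().backward()
        opt.step()

for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    assert torch.allclose(pa, pb), "fused and plain updates diverged"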
I reproduced this problem on a very simple example, with fixed model params and input, and got the same result. When using the Lamb optimizer, the param update result in each step differs from the result without FusedOptimizer. When using the Adam optimizer, the param updates are the same. So I think it's probably related to the Lamb optimizer.

import torch
from torch.nn.modules.loss import CrossEntropyLoss
from utils.LAMB_pt import LAMB
from bagua.torch_api.contrib import FusedOptimizer
import torch.nn as nn
import torch.optim

if __name__ == '__main__':
    # load fixed model, input, and label so every run uses identical data
    input = torch.load('input.pt')
    label = torch.load('label.pt')
    model = torch.load('model.pt')
    # model = nn.Sequential(
    #     nn.Linear(10, 5),
    #     nn.Linear(5, 2),
    #     nn.Linear(2, 1),
    # )
    # optimizer = torch.optim.Adam(
    #     params=model.parameters(),
    #     lr=0.1,
    #     betas=(0.9, 0.999),
    #     eps=1e-06,
    #     weight_decay=0
    # )
    optimizer = LAMB(
        params=model.parameters(),
        lr=0.1,
        betas=(0.9, 0.999),
        eps=1e-06,
        weight_decay=0
    )
    model.to(0)
    # wrap the original optimizer with Bagua's FusedOptimizer
    optimizer = FusedOptimizer(optimizer, do_flatten=True)
    input = input.to(0)
    label = label.to(0)
    print('original:')
    print(optimizer.param_groups[0]['params'][0])
    for i in range(10):
        print('running new step')
        optimizer.zero_grad()
        output = model(input)
        loss = (output - label).pow(2).sum()
        loss.backward()
        optimizer.step()
        print(optimizer.param_groups[0]['params'][0])
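For completeness, the input.pt / label.pt / model.pt fixtures loaded above can be regenerated with a small seeded script along these lines (a sketch; the seed, batch size, and tensor shapes are assumptions chosen to match the commented-out Sequential model):

import torch
import torch.nn as nn

# Sketch for producing deterministic fixtures; the seed and shapes are
# assumptions that fit the commented-out model above.
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2), nn.Linear(2, 1))
input = torch.randn(16, 10)
label = torch.randn(16, 1)

torch.save(model, 'model.pt')
torch.save(input, 'input.pt')
torch.save(label, 'label.pt')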
Thanks! An example is super useful for us to debug. Could you also provide the .pt files?
It seems that the Lamb optimizer uses:

weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
if group['weight_decay'] != 0:
    adam_step.add_(p.data, alpha=group['weight_decay'])
adam_norm = adam_step.pow(2).sum().sqrt()
if weight_norm == 0 or adam_norm == 0:
    trust_ratio = 1
else:
    trust_ratio = weight_norm / adam_norm
state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm
state['trust_ratio'] = trust_ratio
p.data.add_(adam_step, alpha=-step_size * trust_ratio)
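For context, a quick check shows that these norms are 0-dim tensors, so the entries stored in state have a different shape than the parameter itself (a minimal standalone snippet, not taken from lamb.py):

import torch

# weight_norm / adam_norm are scalar (0-dim) tensors, while the parameter is not.
p = torch.nn.Parameter(torch.randn(5, 3))
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

print(weight_norm.shape)  # torch.Size([])
print(p.shape)            # torch.Size([5, 3])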
@wangraying is working on a less intrusive way to implement the fused optimizer in #207. Let's see whether that works in this case. In the worst case we can still go with the easiest solution (that would be disabling fusing for
oh, that's embarrassing. I'll look into this problem soon.
The fused optimizer assumes that a parameter and its state tensors have the same data type and size (which is the case for all official PyTorch optimizers). The Lamb optimizer you linked breaks this assumption because it stores weight_norm and adam_norm, which are 0-dim tensors, in the optimizer state. However, we can easily make it compliant by changing the following two lines in the code you provided:

state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm

to

state['weight_norm'] = weight_norm.item()
state['adam_norm'] = adam_norm.item()

Note that by doing this the recorded norms become Python scalars rather than tensors. Let us know if it works! Thanks
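A small helper along these lines can flag state entries that break the assumption before wrapping an optimizer (a sketch; the function name is made up and is not a Bagua API):

import torch

# Report optimizer state tensors whose dtype or shape differs from their
# parameter, which is what trips the fused optimizer's flattening.
def find_unfusible_state(optimizer):
    offenders = []
    for group in optimizer.param_groups:
        for p in group['params']:
            for name, value in optimizer.state[p].items():
                if torch.is_tensor(value) and (value.dtype != p.dtype or value.shape != p.shape):
                    offenders.append((name, tuple(value.shape), tuple(p.shape)))
    return offenders  # e.g. [('weight_norm', (), (5, 3))] for the unmodified Lamb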
I'll close this if no more problems are raised.
I wrapped my custom optimizer with FusedOptimizer and the precision was much worse than without it. I think FusedOptimizer shouldn't affect model precision. Or is there something wrong with my custom optimizer?
Here is the optimizer I use:
https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py