feat: support qadam in fused optimizer #477

Merged: 3 commits, Jan 11, 2022
65 changes: 38 additions & 27 deletions bagua/torch_api/algorithms/q_adam.py
@@ -43,25 +43,16 @@ def __init__(
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
if warmup_steps <= 0:
raise ValueError(
"Invalid warmup_steps parameter, must be larger than 0: {}".format(
warmup_steps
)
)

super(QAdamOptimizer, self).__init__(params, defaults)
# TODO: qadam optimizer maintain `step_id` in its state
self.step_id = 0
self.warmup_steps = warmup_steps

# initialize momentum and variance
for group_id, group in enumerate(self.param_groups):
params_with_grad = []
for p in group["params"]:
params_with_grad.append(p)
state = self.state[p]
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

def __setstate__(self, state):
super(QAdamOptimizer, self).__setstate__(state)

@@ -71,7 +62,6 @@ def step(self, closure=None):
with torch.enable_grad():
loss = closure()

self.step_id += 1
for group_id, group in enumerate(self.param_groups):

lr = group["lr"]
@@ -82,14 +72,30 @@
for param_id, param in enumerate(group["params"]):
state = self.state[param]

if self.step_id < self.warmup_steps:
state["exp_avg"].mul_(beta1).add_(param.grad, alpha=1 - beta1)
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)

state["step"] += 1
step_id = state["step"]

grad = param.grad
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)

if step_id < self.warmup_steps:
state["exp_avg"].mul_(beta1).add_(grad, alpha=1 - beta1)
state["exp_avg_sq"].mul_(beta2).addcmul_(
param.grad, param.grad, value=1 - beta2
grad, grad, value=1 - beta2
)

bias_correction1 = 1 - beta1 ** self.step_id
bias_correction2 = 1 - beta2 ** self.step_id
bias_correction1 = 1 - beta1 ** step_id
bias_correction2 = 1 - beta2 ** step_id

denom = (state["exp_avg_sq"].sqrt() / math.sqrt(bias_correction2)).add_(
eps
@@ -123,10 +129,15 @@ def __init__(
self.optimizer = q_adam_optimizer
self.warmup_steps = self.optimizer.warmup_steps

@property
def optimizer_step_id(self):
param = self.optimizer.param_groups[0]["params"][0]
return self.optimizer.state[param].get("step", 0)

def need_reset(self):
if self.optimizer.step_id == self.warmup_steps:
if self.optimizer_step_id == self.warmup_steps:
print(
"QAdam starts to compress from step {}".format(self.optimizer.step_id)
"QAdam starts to compress from step {}".format(self.optimizer_step_id)
)
return True
else:
@@ -142,7 +153,7 @@ def init_tensors(self, bagua_ddp: BaguaDistributedDataParallel):
tensor_groups = []
for group in self.optimizer.param_groups:
for param in group["params"]:
if self.optimizer.step_id < self.warmup_steps:
if self.optimizer_step_id < self.warmup_steps:
# register grad
registered_tensor = param.bagua_ensure_grad().ensure_bagua_tensor(
param._q_adam_name,
@@ -188,7 +199,7 @@ def init_operations(
bucket: BaguaBucket,
):
bucket.clear_ops()
if self.optimizer.step_id < self.warmup_steps:
if self.optimizer_step_id < self.warmup_steps:
bucket.append_centralized_synchronous_op(
hierarchical=False,
average=True,
@@ -227,7 +238,7 @@ def hook_grad(parameter_name, parameter):
parameter.bagua_mark_communication_ready()

return (
hook_grad if self.optimizer.step_id < self.warmup_steps else hook_momentum
hook_grad if self.optimizer_step_id < self.warmup_steps else hook_momentum
)


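For orientation, the net effect of this file's changes: each parameter now carries its own `step`, `exp_avg`, and `exp_avg_sq` in `optimizer.state` (as `torch.optim.Adam` does), weight decay is folded into the gradient before the moving averages, and the algorithm reads the step count back through the new `optimizer_step_id` property instead of a shared `self.step_id`. A minimal sketch of what this looks like from the caller's side; the toy model is hypothetical, only the `QAdamOptimizer` arguments and the `state["step"]` lookup come from the diff:

import torch
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer

# Hypothetical toy setup, CPU-only, just to show the per-parameter step counter.
model = torch.nn.Linear(4, 2)
optimizer = QAdamOptimizer(model.parameters(), lr=1e-3, warmup_steps=100)

model(torch.randn(8, 4)).sum().backward()
optimizer.step()

# The step count now lives in per-parameter state rather than in optimizer.step_id;
# this is what the algorithm's optimizer_step_id property reads back.
param = optimizer.param_groups[0]["params"][0]
print(optimizer.state[param].get("step", 0))  # 1 after a single step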
71 changes: 40 additions & 31 deletions bagua/torch_api/contrib/fuse/optimizer.py
@@ -4,6 +4,7 @@
import logging
import itertools
from bagua.torch_api.utils import check_contiguous, get_flattened_tensor
from collections import defaultdict
import gorilla


@@ -303,12 +304,14 @@ def make_optimizer_instance(optimizer: torch.optim.Optimizer):
setattr(new_optimizer, attr, getattr(optimizer, attr))
optimizer._bagua_cloned_attrs[attr] = getattr(optimizer, attr)

# Note: optimizer and fused optimizer share the same state, but different param groups
# Note: fused optimizer has its own copy of param groups and param state
new_optimizer.param_groups = []
for group in optimizer.param_groups:
new_group = {"params": list(group["params"])}
new_optimizer.add_param_group(new_group)

new_optimizer.state = defaultdict(dict)

return new_optimizer


@@ -351,6 +354,7 @@ def do_fuse(optimizer: torch.optim.Optimizer):
weights = [p.data for p in params]
grads = [p.grad for p in params]

# Find original param state
state_tensors, state_scalars = get_optimizer_param_states(optimizer, params)

if state_tensors is None:
@@ -418,18 +422,18 @@ def do_fuse(optimizer: torch.optim.Optimizer):
new_params.append(fp)
active_param_ids.append(id(fp))

# sync state tensors
# sync state tensors for fused optimizer
for name, tensors in state_tensors.items():
grouped_state = group_tensors(tensors, indices)

if fp not in optimizer.state or not _can_reuse_tensor(
optimizer.state[fp][name], *grouped_state
if fp not in _fused_optimizer.state or not _can_reuse_tensor(
_fused_optimizer.state[fp].get(name, None), *grouped_state
):
optimizer.state[fp][name] = _create_tensor(*grouped_state)
_fused_optimizer.state[fp][name] = _create_tensor(*grouped_state)

# sync state scalars
# sync state scalars for fused optimizer
for name, scalar in state_scalars.items():
optimizer.state[fp][name] = scalar
_fused_optimizer.state[fp][name] = scalar

# add non-contiguous params
grouped_indices_flat = list(itertools.chain.from_iterable(grouped_indices))
@@ -438,11 +442,14 @@ def do_fuse(optimizer: torch.optim.Optimizer):
new_params.append(param)
active_param_ids.append(id(param))

for name, v in optimizer.state[param].items():
_fused_optimizer.state[param][name] = v

# clear outdated states
for fp in fused_group["params"]:
if id(fp) not in active_param_ids:
if id(fp) not in active_param_ids and fp in _fused_optimizer.state:
logging.debug("delete outdated params")
del optimizer.state[fp]
del _fused_optimizer.state[fp]

fused_group["params"] = new_params

@@ -469,36 +476,38 @@ def sync_param_group_scalars(src_group, dst_group):
def sync_optimizer_state(optimizer):
# write back state for original params
# Note: we should make sure every module parameter in original params groups has the right state
_fused_optimizer = optimizer._bagua_fused_optimizer
for group, fused_group in zip(
optimizer.param_groups, optimizer._bagua_fused_optimizer.param_groups
optimizer.param_groups, _fused_optimizer.param_groups
):

params = group["params"]
fused_params = fused_group["params"]

for fp in fused_params:
if not hasattr(fp, "_bagua_fused_param_ids"):
# skip original params
continue

original_params = [params[i] for i in fp._bagua_fused_param_ids]

for name, v in optimizer.state[fp].items():
if isinstance(v, torch.Tensor):
state_tensors = infer_state_tensors(
optimizer, fp, original_params, name
)

if state_tensors is not None:
for p, state in zip(original_params, state_tensors):
optimizer.state[p][name] = state
else:
for p in original_params:
optimizer.state[p][name] = v

for p in original_params:
if len(optimizer.state[p]) != len(optimizer.state[fp]):
logging.warning("Something went wrong with optimizer state.")
for name, v in _fused_optimizer.state[fp].items():
optimizer.state[fp][name] = v

else:
original_params = [params[i] for i in fp._bagua_fused_param_ids]

for name, v in _fused_optimizer.state[fp].items():
if isinstance(v, torch.Tensor):
state_tensors = infer_state_tensors(
_fused_optimizer, fp, original_params, name
)

if state_tensors is not None:
for p, state in zip(original_params, state_tensors):
optimizer.state[p][name] = state
else:
for p in original_params:
optimizer.state[p][name] = v

for p in original_params:
if len(optimizer.state[p]) != len(_fused_optimizer.state[fp]):
logging.warning("Something went wrong with optimizer state.")


def get_tensor_state(optimizer, param, name):
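Taken together, these changes mean the fused optimizer no longer shares per-parameter state with the original optimizer: `make_optimizer_instance` gives it its own param groups and an empty `state = defaultdict(dict)`, `do_fuse` creates and reuses flattened state tensors under `_fused_optimizer.state`, and `sync_optimizer_state` writes them back onto the original parameters afterwards. The sketch below is a deliberately simplified illustration of that separation using plain `torch.optim.Adam`; in bagua the cloned instance additionally replaces contiguous parameters with flattened fused tensors:

from collections import defaultdict

import torch

# Two optimizers over the same parameters: "fused" keeps its own, initially empty,
# per-parameter state, so stepping it never touches the original optimizer's state.
model = torch.nn.Linear(4, 2)
original = torch.optim.Adam(model.parameters(), lr=1e-3)
fused = torch.optim.Adam(model.parameters(), lr=1e-3)
fused.state = defaultdict(dict)  # same structure torch.optim uses, just reset

model(torch.randn(8, 4)).sum().backward()
fused.step()

# fused built exp_avg / exp_avg_sq for both parameters; the original optimizer's
# state stays empty until something like sync_optimizer_state copies it back.
assert len(fused.state) == 2 and len(original.state) == 0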
5 changes: 3 additions & 2 deletions examples/benchmark/synthetic_benchmark.py
@@ -123,8 +123,6 @@
model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01 * bagua.get_world_size())
if args.fuse_optimizer:
optimizer = bagua.contrib.fuse_optimizer(optimizer)

if args.algorithm == "gradient_allreduce":
from bagua.torch_api.algorithms import gradient_allreduce
@@ -161,6 +159,9 @@

model = model.with_bagua([optimizer], algorithm, do_flatten=not args.fuse_optimizer)

if args.fuse_optimizer:
optimizer = bagua.contrib.fuse_optimizer(optimizer)

# Set up fixed fake data
data = torch.randn(args.batch_size, 3, 224, 224)
target = torch.LongTensor(args.batch_size).random_() % 1000
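The only change to the example scripts is ordering: `fuse_optimizer` is now called after `with_bagua`, so the model (with `do_flatten=not args.fuse_optimizer`) is set up before the optimizer is fused. A hedged sketch of the resulting order, with a hypothetical single-layer model standing in for the benchmark's network and assuming the bagua process group is already initialized:

import torch
import bagua.torch_api as bagua
from bagua.torch_api.algorithms import gradient_allreduce

# Hypothetical minimal setup; only the call order mirrors the change above.
fuse_optimizer = True  # stands in for args.fuse_optimizer
model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * bagua.get_world_size())
algorithm = gradient_allreduce.GradientAllReduceAlgorithm()

# Wrap the model with bagua first, then fuse the optimizer.
model = model.with_bagua([optimizer], algorithm, do_flatten=not fuse_optimizer)
if fuse_optimizer:
    optimizer = bagua.contrib.fuse_optimizer(optimizer)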
6 changes: 3 additions & 3 deletions examples/mnist/main.py
@@ -223,9 +223,6 @@ def main():
model = Net().cuda()
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

if args.fuse_optimizer:
optimizer = bagua.contrib.fuse_optimizer(optimizer)

if args.algorithm == "gradient_allreduce":
from bagua.torch_api.algorithms import gradient_allreduce

@@ -264,6 +261,9 @@
do_flatten=not args.fuse_optimizer,
)

if args.fuse_optimizer:
optimizer = bagua.contrib.fuse_optimizer(optimizer)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
if args.algorithm == "async":
4 changes: 1 addition & 3 deletions tests/contrib/test_fused_optimizer.py
@@ -128,7 +128,7 @@ def bagua_init(model, optimizer, algorithm, do_flatten):
elif algorithm == "qadam":
from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer

optimizer = QAdamOptimizer(model.parameters(), warmup_steps=1)
optimizer = QAdamOptimizer(optimizer.param_groups, warmup_steps=1)
bagua_algorithm = QAdamAlgorithm(optimizer, hierarchical=False)
else:
raise ValueError("unsupported algorithm")
@@ -304,7 +304,6 @@ def run_all_optimizers_once(self, fn1, fn2, device, num_epochs, fused_count):
count += 1
if count % 5 == 0:
logging.info(f"Tests Passed [{count}/{len(optimizer_list)}]")
# return

def run_fused_with_bagua_wrapper(self, fn1, fn2, num_epochs, fused_count):
self.run_all_optimizers_once(fn1, fn2, "cuda:0", num_epochs, fused_count)
@@ -428,7 +427,6 @@ def test_low_prec_decentralized(self):

@skip_if_cuda_not_available()
def test_qadam(self):
return
setup_bagua_env()
self.run_qadam(
device="cuda:0",
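The functional change in this test is that `QAdamOptimizer` is now built from `optimizer.param_groups` instead of `model.parameters()`. Since `torch.optim.Optimizer` accepts an iterable of param-group dicts, this keeps the group structure (and any per-group options such as `lr`) of the optimizer being replaced. A small sketch under that assumption, with a hypothetical two-group model:

import torch
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer

# Hypothetical model with two parameter groups; passing param_groups keeps them intact.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
base = torch.optim.SGD(
    [
        {"params": model[0].parameters()},
        {"params": model[1].parameters(), "lr": 0.1},
    ],
    lr=0.01,
)

qadam = QAdamOptimizer(base.param_groups, warmup_steps=1)
assert len(qadam.param_groups) == len(base.param_groups)  # group structure preserved
assert qadam.param_groups[1]["lr"] == 0.1                  # per-group lr carried over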
52 changes: 52 additions & 0 deletions tests/torch_api/test_qadam.py
@@ -0,0 +1,52 @@
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer
from tests import skip_if_cuda_available


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(2, 10, bias=False)
self.fc2 = nn.Linear(10, 50, bias=True)
self.fc3 = nn.Linear(50, 4, bias=False)
self.relu = nn.ReLU()

def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return F.softmax(x, dim=1)


def run_step(opt_cls, opt_flags, seed):
torch.manual_seed(seed)
model = Net()
optimizer = opt_cls(model.parameters(), **opt_flags)

for step in range(1000):
data = torch.randn(4, 2)
target = torch.randn(4, 4)

optimizer.zero_grad()
output = model(data)
loss = nn.MSELoss()(output, target)

loss.backward()
optimizer.step()

return loss


class TestQAdam(unittest.TestCase):
@skip_if_cuda_available()
def test_qadam_optimizer(self):
loss1 = run_step(torch.optim.Adam, {"lr": 0.001, "weight_decay": 0.1}, seed=13)
loss2 = run_step(
QAdamOptimizer,
{"lr": 0.001, "weight_decay": 0.1, "warmup_steps": 2000},
seed=13,
)
self.assertEqual(loss1.item(), loss2.item())
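The new test relies on QAdam's warmup branch being a plain Adam update: L2 weight decay folded into the gradient, exponential moving averages of the gradient and its square, and bias-corrected step sizes, which is why 1000 steps with `warmup_steps=2000` are expected to match `torch.optim.Adam` exactly. Below is a hand-rolled single step under those conventions, purely as a reading aid; the final parameter update is written in `torch.optim.Adam`'s form, which the visible part of the diff omits but the test implies:

import math

import torch

def adam_warmup_step(param, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    # Illustrative re-derivation of the warmup branch above, not bagua's implementation.
    state.setdefault("step", 0)
    state.setdefault("exp_avg", torch.zeros_like(param))
    state.setdefault("exp_avg_sq", torch.zeros_like(param))

    state["step"] += 1
    grad = param.grad
    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)  # classic L2 decay, as in the diff

    beta1, beta2 = betas
    state["exp_avg"].mul_(beta1).add_(grad, alpha=1 - beta1)
    state["exp_avg_sq"].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction1 = 1 - beta1 ** state["step"]
    bias_correction2 = 1 - beta2 ** state["step"]
    denom = (state["exp_avg_sq"].sqrt() / math.sqrt(bias_correction2)).add_(eps)
    param.data.addcdiv_(state["exp_avg"], denom, value=-lr / bias_correction1)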