feat: support qadam in fused optimizer (#477)
wangraying authored Jan 11, 2022
1 parent c6b1c92 commit b1183c7
Showing 6 changed files with 137 additions and 66 deletions.
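The example updates in this commit move `bagua.contrib.fuse_optimizer` to after `model.with_bagua`, which is what allows QAdam to be used together with the fused optimizer. A minimal usage sketch based on those examples (the model, learning rate, and warmup_steps values are placeholders, and a Bagua process group launched via bagua.distributed.launch is assumed):

import torch
import torch.nn as nn
import bagua.torch_api as bagua
from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer

torch.cuda.set_device(bagua.get_local_rank())
bagua.init_process_group()

model = nn.Linear(10, 2).cuda()
optimizer = QAdamOptimizer(model.parameters(), lr=1e-3, warmup_steps=100)
algorithm = QAdamAlgorithm(optimizer, hierarchical=False)

# Wrap with Bagua first; do_flatten=False because the fused optimizer does
# its own flattening (mirrors `do_flatten=not args.fuse_optimizer` in the
# updated examples).
model = model.with_bagua([optimizer], algorithm, do_flatten=False)

# Fuse only after with_bagua, matching the reordering in
# examples/mnist/main.py and examples/benchmark/synthetic_benchmark.py.
optimizer = bagua.contrib.fuse_optimizer(optimizer)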
65 changes: 38 additions & 27 deletions bagua/torch_api/algorithms/q_adam.py
@@ -43,25 +43,16 @@ def __init__(
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
if warmup_steps <= 0:
raise ValueError(
"Invalid warmup_steps parameter, must be larger than 0: {}".format(
warmup_steps
)
)

super(QAdamOptimizer, self).__init__(params, defaults)
# TODO: qadam optimizer maintain `step_id` in its state
self.step_id = 0
self.warmup_steps = warmup_steps

# initialize momentum and variance
for group_id, group in enumerate(self.param_groups):
params_with_grad = []
for p in group["params"]:
params_with_grad.append(p)
state = self.state[p]
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

def __setstate__(self, state):
super(QAdamOptimizer, self).__setstate__(state)

@@ -71,7 +62,6 @@ def step(self, closure=None):
with torch.enable_grad():
loss = closure()

self.step_id += 1
for group_id, group in enumerate(self.param_groups):

lr = group["lr"]
@@ -82,14 +72,30 @@
for param_id, param in enumerate(group["params"]):
state = self.state[param]

if self.step_id < self.warmup_steps:
state["exp_avg"].mul_(beta1).add_(param.grad, alpha=1 - beta1)
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)

state["step"] += 1
step_id = state["step"]

grad = param.grad
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)

if step_id < self.warmup_steps:
state["exp_avg"].mul_(beta1).add_(grad, alpha=1 - beta1)
state["exp_avg_sq"].mul_(beta2).addcmul_(
param.grad, param.grad, value=1 - beta2
grad, grad, value=1 - beta2
)

bias_correction1 = 1 - beta1 ** self.step_id
bias_correction2 = 1 - beta2 ** self.step_id
bias_correction1 = 1 - beta1 ** step_id
bias_correction2 = 1 - beta2 ** step_id

denom = (state["exp_avg_sq"].sqrt() / math.sqrt(bias_correction2)).add_(
eps
@@ -123,10 +129,15 @@ def __init__(
self.optimizer = q_adam_optimizer
self.warmup_steps = self.optimizer.warmup_steps

@property
def optimizer_step_id(self):
param = self.optimizer.param_groups[0]["params"][0]
return self.optimizer.state[param].get("step", 0)

def need_reset(self):
if self.optimizer.step_id == self.warmup_steps:
if self.optimizer_step_id == self.warmup_steps:
print(
"QAdam starts to compress from step {}".format(self.optimizer.step_id)
"QAdam starts to compress from step {}".format(self.optimizer_step_id)
)
return True
else:
@@ -142,7 +153,7 @@ def init_tensors(self, bagua_ddp: BaguaDistributedDataParallel):
tensor_groups = []
for group in self.optimizer.param_groups:
for param in group["params"]:
if self.optimizer.step_id < self.warmup_steps:
if self.optimizer_step_id < self.warmup_steps:
# register grad
registered_tensor = param.bagua_ensure_grad().ensure_bagua_tensor(
param._q_adam_name,
@@ -188,7 +199,7 @@ def init_operations(
bucket: BaguaBucket,
):
bucket.clear_ops()
if self.optimizer.step_id < self.warmup_steps:
if self.optimizer_step_id < self.warmup_steps:
bucket.append_centralized_synchronous_op(
hierarchical=False,
average=True,
@@ -227,7 +238,7 @@ def hook_grad(parameter_name, parameter):
parameter.bagua_mark_communication_ready()

return (
hook_grad if self.optimizer.step_id < self.warmup_steps else hook_momentum
hook_grad if self.optimizer_step_id < self.warmup_steps else hook_momentum
)


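With the change above, the QAdam step counter now lives in each parameter's optimizer state instead of a `self.step_id` attribute, and the algorithm reads it back through the new `optimizer_step_id` property. A standalone sketch of that pattern (the helper below is illustrative, not part of the Bagua API; it exercises the warm-up branch on CPU, as the new unit test does):

import torch
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer

def optimizer_step_id(opt):
    # Mirrors QAdamAlgorithm.optimizer_step_id: the per-parameter "step"
    # counter is created lazily in step(), so it reads as 0 before the
    # first optimizer step.
    param = opt.param_groups[0]["params"][0]
    return opt.state[param].get("step", 0)

model = torch.nn.Linear(4, 4)
optimizer = QAdamOptimizer(model.parameters(), lr=1e-3, warmup_steps=10)

model(torch.randn(2, 4)).sum().backward()
optimizer.step()
assert optimizer_step_id(optimizer) == 1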
71 changes: 40 additions & 31 deletions bagua/torch_api/contrib/fuse/optimizer.py
@@ -4,6 +4,7 @@
import logging
import itertools
from bagua.torch_api.utils import check_contiguous, get_flattened_tensor
from collections import defaultdict
import gorilla


@@ -303,12 +304,14 @@ def make_optimizer_instance(optimizer: torch.optim.Optimizer):
setattr(new_optimizer, attr, getattr(optimizer, attr))
optimizer._bagua_cloned_attrs[attr] = getattr(optimizer, attr)

# Note: optimizer and fused optimizer share the same state, but different param groups
# Note: fused optimizer has its own copy of param groups and param state
new_optimizer.param_groups = []
for group in optimizer.param_groups:
new_group = {"params": list(group["params"])}
new_optimizer.add_param_group(new_group)

new_optimizer.state = defaultdict(dict)

return new_optimizer


@@ -351,6 +354,7 @@ def do_fuse(optimizer: torch.optim.Optimizer):
weights = [p.data for p in params]
grads = [p.grad for p in params]

# Find original param state
state_tensors, state_scalars = get_optimizer_param_states(optimizer, params)

if state_tensors is None:
@@ -418,18 +422,18 @@ def do_fuse(optimizer: torch.optim.Optimizer):
new_params.append(fp)
active_param_ids.append(id(fp))

# sync state tensors
# sync state tensors for fused optimizer
for name, tensors in state_tensors.items():
grouped_state = group_tensors(tensors, indices)

if fp not in optimizer.state or not _can_reuse_tensor(
optimizer.state[fp][name], *grouped_state
if fp not in _fused_optimizer.state or not _can_reuse_tensor(
_fused_optimizer.state[fp].get(name, None), *grouped_state
):
optimizer.state[fp][name] = _create_tensor(*grouped_state)
_fused_optimizer.state[fp][name] = _create_tensor(*grouped_state)

# sync state scalars
# sync state scalars for fused optimizer
for name, scalar in state_scalars.items():
optimizer.state[fp][name] = scalar
_fused_optimizer.state[fp][name] = scalar

# add non-contiguous params
grouped_indices_flat = list(itertools.chain.from_iterable(grouped_indices))
@@ -438,11 +442,14 @@ def do_fuse(optimizer: torch.optim.Optimizer):
new_params.append(param)
active_param_ids.append(id(param))

for name, v in optimizer.state[param].items():
_fused_optimizer.state[param][name] = v

# clear outdated states
for fp in fused_group["params"]:
if id(fp) not in active_param_ids:
if id(fp) not in active_param_ids and fp in _fused_optimizer.state:
logging.debug("delete outdated params")
del optimizer.state[fp]
del _fused_optimizer.state[fp]

fused_group["params"] = new_params

@@ -469,36 +476,38 @@ def sync_param_group_scalars(src_group, dst_group):
def sync_optimizer_state(optimizer):
# write back state for original params
# Note: we should make sure every module parameter in original params groups has the right state
_fused_optimizer = optimizer._bagua_fused_optimizer
for group, fused_group in zip(
optimizer.param_groups, optimizer._bagua_fused_optimizer.param_groups
optimizer.param_groups, _fused_optimizer.param_groups
):

params = group["params"]
fused_params = fused_group["params"]

for fp in fused_params:
if not hasattr(fp, "_bagua_fused_param_ids"):
# skip original params
continue

original_params = [params[i] for i in fp._bagua_fused_param_ids]

for name, v in optimizer.state[fp].items():
if isinstance(v, torch.Tensor):
state_tensors = infer_state_tensors(
optimizer, fp, original_params, name
)

if state_tensors is not None:
for p, state in zip(original_params, state_tensors):
optimizer.state[p][name] = state
else:
for p in original_params:
optimizer.state[p][name] = v

for p in original_params:
if len(optimizer.state[p]) != len(optimizer.state[fp]):
logging.warning("Something went wrong with optimizer state.")
for name, v in _fused_optimizer.state[fp].items():
optimizer.state[fp][name] = v

else:
original_params = [params[i] for i in fp._bagua_fused_param_ids]

for name, v in _fused_optimizer.state[fp].items():
if isinstance(v, torch.Tensor):
state_tensors = infer_state_tensors(
_fused_optimizer, fp, original_params, name
)

if state_tensors is not None:
for p, state in zip(original_params, state_tensors):
optimizer.state[p][name] = state
else:
for p in original_params:
optimizer.state[p][name] = v

for p in original_params:
if len(optimizer.state[p]) != len(_fused_optimizer.state[fp]):
logging.warning("Something went wrong with optimizer state.")


def get_tensor_state(optimizer, param, name):
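After this change the fused optimizer keeps its own state (a fresh `defaultdict(dict)`), and `sync_optimizer_state` writes that state back onto the original parameters: tensor states are split per original parameter, scalar states are copied as-is. A simplified standalone sketch of that split (function and argument names are illustrative, not the Bagua implementation):

import torch

def write_back_state(fused_state, original_params, split_fn):
    # fused_state: the state dict of one fused parameter (e.g. exp_avg, step).
    # split_fn: splits a fused state tensor back into per-parameter pieces.
    per_param = {p: {} for p in original_params}
    for name, value in fused_state.items():
        if isinstance(value, torch.Tensor):
            # Tensor states are split so each original parameter gets its slice.
            for p, piece in zip(original_params, split_fn(value, original_params)):
                per_param[p][name] = piece
        else:
            # Scalar states (such as "step") are copied to every original parameter.
            for p in original_params:
                per_param[p][name] = value
    return per_param

# Example: split a flat 6-element state across Linear(2, 2)'s weight and bias.
params = list(torch.nn.Linear(2, 2).parameters())
states = write_back_state(
    {"exp_avg": torch.zeros(6), "step": 3},
    params,
    lambda t, ps: torch.split(t, [p.numel() for p in ps]),
)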
5 changes: 3 additions & 2 deletions examples/benchmark/synthetic_benchmark.py
@@ -123,8 +123,6 @@
model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01 * bagua.get_world_size())
if args.fuse_optimizer:
optimizer = bagua.contrib.fuse_optimizer(optimizer)

if args.algorithm == "gradient_allreduce":
from bagua.torch_api.algorithms import gradient_allreduce
@@ -161,6 +159,9 @@

model = model.with_bagua([optimizer], algorithm, do_flatten=not args.fuse_optimizer)

if args.fuse_optimizer:
optimizer = bagua.contrib.fuse_optimizer(optimizer)

# Set up fixed fake data
data = torch.randn(args.batch_size, 3, 224, 224)
target = torch.LongTensor(args.batch_size).random_() % 1000
6 changes: 3 additions & 3 deletions examples/mnist/main.py
@@ -223,9 +223,6 @@ def main():
model = Net().cuda()
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

if args.fuse_optimizer:
optimizer = bagua.contrib.fuse_optimizer(optimizer)

if args.algorithm == "gradient_allreduce":
from bagua.torch_api.algorithms import gradient_allreduce

@@ -264,6 +261,9 @@
do_flatten=not args.fuse_optimizer,
)

if args.fuse_optimizer:
optimizer = bagua.contrib.fuse_optimizer(optimizer)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
if args.algorithm == "async":
4 changes: 1 addition & 3 deletions tests/contrib/test_fused_optimizer.py
@@ -128,7 +128,7 @@ def bagua_init(model, optimizer, algorithm, do_flatten):
elif algorithm == "qadam":
from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer

optimizer = QAdamOptimizer(model.parameters(), warmup_steps=1)
optimizer = QAdamOptimizer(optimizer.param_groups, warmup_steps=1)
bagua_algorithm = QAdamAlgorithm(optimizer, hierarchical=False)
else:
raise ValueError("unsupported algorithm")
@@ -304,7 +304,6 @@ def run_all_optimizers_once(self, fn1, fn2, device, num_epochs, fused_count):
count += 1
if count % 5 == 0:
logging.info(f"Tests Passed [{count}/{len(optimizer_list)}]")
# return

def run_fused_with_bagua_wrapper(self, fn1, fn2, num_epochs, fused_count):
self.run_all_optimizers_once(fn1, fn2, "cuda:0", num_epochs, fused_count)
@@ -428,7 +427,6 @@ def test_low_prec_decentralized(self):

@skip_if_cuda_not_available()
def test_qadam(self):
return
setup_bagua_env()
self.run_qadam(
device="cuda:0",
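The updated test now builds the QAdamOptimizer from an existing optimizer's param_groups rather than from model.parameters(), which preserves per-group hyperparameters such as lr and weight_decay. A short sketch of that construction (the model and values are placeholders):

import torch
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer

model = torch.nn.Linear(4, 4)
base = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

# torch.optim optimizers accept an iterable of param-group dicts, so the
# groups of an existing optimizer can seed QAdam directly; values present
# in each group override QAdam's defaults.
qadam = QAdamOptimizer(base.param_groups, warmup_steps=100)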
52 changes: 52 additions & 0 deletions tests/torch_api/test_qadam.py
@@ -0,0 +1,52 @@
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer
from tests import skip_if_cuda_available


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(2, 10, bias=False)
self.fc2 = nn.Linear(10, 50, bias=True)
self.fc3 = nn.Linear(50, 4, bias=False)
self.relu = nn.ReLU()

def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return F.softmax(x, dim=1)


def run_step(opt_cls, opt_flags, seed):
torch.manual_seed(seed)
model = Net()
optimizer = opt_cls(model.parameters(), **opt_flags)

for step in range(1000):
data = torch.randn(4, 2)
target = torch.randn(4, 4)

optimizer.zero_grad()
output = model(data)
loss = nn.MSELoss()(output, target)

loss.backward()
optimizer.step()

return loss


class TestQAdam(unittest.TestCase):
@skip_if_cuda_available()
def test_qadam_optimizer(self):
loss1 = run_step(torch.optim.Adam, {"lr": 0.001, "weight_decay": 0.1}, seed=13)
loss2 = run_step(
QAdamOptimizer,
{"lr": 0.001, "weight_decay": 0.1, "warmup_steps": 2000},
seed=13,
)
self.assertEqual(loss1.item(), loss2.item())
