fix(python): fix fused optimizer with multiple param groups (#356)
wangraying authored Nov 2, 2021
1 parent 35d2229 commit 5228e75
Showing 2 changed files with 61 additions and 32 deletions.
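
For context, the bug shows up when the fused optimizer wraps an optimizer whose parameters are split across several param groups. The sketch below illustrates such a setup; it assumes Bagua's fuse_optimizer wrapper from bagua.torch_api.contrib and its fuse_step() method (both referenced in the diffs below), with an otherwise illustrative model and hyperparameters.

import torch
from bagua.torch_api.contrib import fuse_optimizer  # wrapper touched by this fix; import path assumed

model = torch.nn.Linear(16, 4)

# Two param groups with different per-group options, mirroring the grouped
# setup added to tests/contrib/test_fused_optimizer.py in this commit.
param_groups = [
    {"params": [model.weight], "weight_decay": 0.1},
    {"params": [model.bias], "weight_decay": 0.0},
]
optimizer = torch.optim.SGD(param_groups, lr=0.01)

# Before this fix, flattening could mix tensors from different param groups,
# and the internally fused optimizer dropped per-group options such as weight_decay.
optimizer = fuse_optimizer(optimizer)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.fuse_step()  # fused parameter update
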
52 changes: 29 additions & 23 deletions bagua/torch_api/contrib/fuse/optimizer.py
@@ -12,44 +12,46 @@

 def flatten_params_and_states(optimizer: torch.optim.Optimizer):
     """
-    Flatten parameter tensors in the sampe group into contiguous ones.
+    Flatten parameter tensors in the same group into contiguous ones.
     """

-    type_params = {}
     for group in optimizer.param_groups:
+        type_params = {}
+        # flatten by param group
         for param in group["params"]:

             params_of_type = type_params.get(param.type(), [])
             params_of_type.append(param)
             type_params[param.type()] = params_of_type

-    for param_type, params in type_params.items():
-        grads = [p.bagua_ensure_grad().grad for p in params]
-        state_tensors, state_scalars = get_optimizer_param_states(optimizer, params)
-
-        if state_tensors is None:
-            continue
-
-        flatten_tensors(params)
-        flatten_tensors_with_closure(
-            grads,
-            params,
-            getter_closure=lambda p: p.grad,
-            setter_closure=lambda p, new_grad: setattr(p, "grad", new_grad),
-        )
-
-        for name, tensors in state_tensors.items():
-
-            def set_state_fn(p, t):
-                optimizer.state[p][name] = t
-
-            flatten_tensors_with_closure(
-                tensors,
-                params,
-                getter_closure=lambda p: optimizer.state[p][name],
-                setter_closure=set_state_fn,
-            )
+        for param_type, params in type_params.items():
+            grads = [p.bagua_ensure_grad().grad for p in params]
+            state_tensors, state_scalars = get_optimizer_param_states(optimizer, params)
+
+            if state_tensors is None:
+                continue
+
+            flatten_tensors(params)
+            flatten_tensors_with_closure(
+                grads,
+                params,
+                getter_closure=lambda p: p.grad,
+                setter_closure=lambda p, new_grad: setattr(p, "grad", new_grad),
+            )
+
+            for name, tensors in state_tensors.items():
+
+                def set_state_fn(p, t):
+                    optimizer.state[p][name] = t
+
+                flatten_tensors_with_closure(
+                    tensors,
+                    params,
+                    getter_closure=lambda p: optimizer.state[p][name],
+                    setter_closure=set_state_fn,
+                )
     torch.cuda.empty_cache()


 def flatten_tensors(tensors: List[torch.Tensor]):
     """

@@ -422,6 +424,10 @@ def do_fuse(optimizer: torch.optim.Optimizer):
             new_params.append(param)

         new_group = {"params": new_params}
+        for k, v in group.items():
+            if k != "params":
+                new_group[k] = v
+
         _fused_optimizer.add_param_group(new_group)

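The second hunk above matters because each param group carries its own hyperparameters: a fused group rebuilt from only the "params" key would silently fall back to the optimizer-level defaults. A minimal, Bagua-free sketch of that per-group behaviour in plain PyTorch (the tensors here are hypothetical):

import torch

w = torch.nn.Parameter(torch.randn(4, 4))
b = torch.nn.Parameter(torch.zeros(4))

opt = torch.optim.SGD(
    [
        {"params": [w], "weight_decay": 0.1},
        {"params": [b], "weight_decay": 0.0},
    ],
    lr=0.01,
)

# Each group keeps its own weight_decay; a group added later with only
# {"params": [...]} would inherit the optimizer-level default (0.0) instead.
for group in opt.param_groups:
    print(group["weight_decay"], len(group["params"]))
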
41 changes: 32 additions & 9 deletions tests/contrib/test_fused_optimizer.py
@@ -41,7 +41,27 @@ def construct_model_and_optimizer(opt, flag_param, device):
     model.load_state_dict(pretrained_dict)

     model = model.to(device)
-    optimizer = opt(model.parameters(), **flag_param)
+    no_decay = ["bias"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [
+                p
+                for n, p in model.named_parameters()
+                if not any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": 0.1,
+        },
+        {
+            "params": [
+                p
+                for n, p in model.named_parameters()
+                if any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    optimizer = opt(optimizer_grouped_parameters, **flag_param)

     return model, optimizer

@@ -73,7 +93,9 @@ def train_model_fused(model, optimizer, device, num_epochs):
             optimizer.fuse_step()
         else:
             optimizer.step()
-        # logging.debug(f"#train model fused#{epoch} params: {optimizer._bagua_fused_optimizer.param_groups}")
+        # logging.debug(
+        #     f"#train model fused#{epoch} params: {optimizer._bagua_fused_optimizer.param_groups}"
+        # )
         # logging.debug(f"#train model fused#{epoch} state: {optimizer.state}")

@@ -282,14 +304,15 @@ def run_all_optimizers_once(self, fn1, fn2, device, num_epochs, fused_count):
             count += 1
             if count % 5 == 0:
                 logging.info(f"Tests Passed [{count}/{len(optimizer_list)}]")
+            # return

     def run_fused_with_bagua_wrapper(self, fn1, fn2, num_epochs, fused_count):
         self.run_all_optimizers_once(fn1, fn2, "cuda:0", num_epochs, fused_count)

     @skip_if_cuda_available()
     def test_fused_optimizer(self):
         self.run_all_optimizers_once(
-            fn1=run, fn2=run_fused, device="cpu", num_epochs=101, fused_count=51
+            fn1=run, fn2=run_fused, device="cpu", num_epochs=101, fused_count=102
         )

     @skip_if_cuda_not_available()

@@ -302,7 +325,7 @@ def test_gradient_allreduce(self):
                 p1, p2, device, num_epochs, "gradient_allreduce", True, False
             ),
             num_epochs=101,
-            fused_count=51,
+            fused_count=102,
         )
         # check: both are falttened, should not fuse
         self.run_fused_with_bagua_wrapper(

@@ -335,7 +358,7 @@ def test_bytegrad(self):
                 p1, p2, device, num_epochs, "bytegrad", True, False
             ),
             num_epochs=101,
-            fused_count=51,
+            fused_count=102,
         )

     @skip_if_cuda_not_available()

@@ -348,7 +371,7 @@ def test_decentralized(self):
                 p1, p2, device, num_epochs, "decentralized", True, False
             ),
             num_epochs=101,
-            fused_count=51,
+            fused_count=102,
         )
         self.run_fused_with_bagua_wrapper(
             fn1=run,

@@ -377,7 +400,7 @@ def test_async(self):
                 p1, p2, device, num_epochs, "async", True, False
             ),
             num_epochs=101,
-            fused_count=51,
+            fused_count=102,
         )
         self.run_fused_with_bagua_wrapper(
             fn1=run,

@@ -400,7 +423,7 @@ def test_low_prec_decentralized(self):
                 p1, p2, device, num_epochs, "low_prec_decentralized", True, False
             ),
             num_epochs=101,
-            fused_count=51,
+            fused_count=102,
         )

     @skip_if_cuda_not_available()

@@ -410,7 +433,7 @@ def test_qadam(self):
         self.run_qadam(
             device="cuda:0",
             num_epochs=101,
-            fused_count=51,
+            fused_count=102,
             optimizer_flatten=True,
             bagua_flatten=False,
         )
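
Assuming Bagua and pytest are installed, the updated tests can be exercised directly against this file, for example from the repository root:

import pytest

# Run only the fused-optimizer tests touched by this commit.
pytest.main(["tests/contrib/test_fused_optimizer.py", "-q"])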
