feat(python, core): support mutable bucket tensors (#271)
BREAKING CHANGE: `BaguaTensor::bagua_ensure_grad` now returns the tensor itself
wangraying authored and NOBLES5E committed Oct 28, 2021
1 parent d4b7dd7 commit 2fe6eeb
Showing 11 changed files with 552 additions and 240 deletions.
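
The pattern every call site in this commit follows: `bagua_ensure_grad()` now returns the parameter itself, so registration chains directly on the parameter, and the tensor the backend actually communicates is resolved through a getter/setter closure instead of a stored `_bagua_grad` reference. A minimal sketch of the new call shape, mirroring the hunks below (`name` and `module_name` stand in for the real arguments):

# Sketch only -- mirrors the registration pattern in the hunks below.
param = param.bagua_ensure_grad().ensure_bagua_tensor(
    name,
    module_name,
    # the backend reads the effective tensor through the getter ...
    getter_closure=lambda p: p.grad,
    # ... and may replace it through the setter (e.g. when flattening buckets)
    setter_closure=lambda p, t: setattr(p, "grad", t),
)
# the parameter itself is now the registered Bagua tensor
param.bagua_mark_communication_ready()
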
bagua/torch_api/algorithms/async_model_average.py (12 changes: 7 additions & 5 deletions)

@@ -95,11 +95,13 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
         tensors = []
         for name, param in parameters.__reversed__():
             if self.step_id < self.warmup_steps:
-                grad = param.bagua_ensure_grad().ensure_bagua_tensor(
-                    name, bagua_module.bagua_module_name
+                param = param.bagua_ensure_grad().ensure_bagua_tensor(
+                    name,
+                    bagua_module.bagua_module_name,
+                    getter_closure=lambda param: param.grad,
+                    setter_closure=lambda param, t: setattr(param, "grad", t),
                 )
-                param._bagua_grad = grad
-                tensors.append(grad)
+                tensors.append(param)
             else:
                 p = param.ensure_bagua_tensor(name, bagua_module.bagua_module_name)
                 tensors.append(p)

@@ -128,7 +130,7 @@ def hook(input):
     def init_backward_hook(self, bagua_module: BaguaModule):
         def hook(parameter_name, parameter):
             if self.step_id <= self.warmup_steps:
-                parameter._bagua_grad.bagua_mark_communication_ready()
+                parameter.bagua_mark_communication_ready()

         return hook

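With the parameter itself registered as the Bagua tensor, the warmup branch above simply marks the parameter ready and the backend resolves its gradient through the getter closure; the `_bagua_grad` side reference disappears. A plain-Python illustration (hypothetical names, not the Bagua API) of why resolving through a getter keeps the backend pointed at the current gradient buffer even after it is swapped out:

# Hypothetical helper, not the Bagua API: the getter is evaluated at use time,
# so a grad buffer swapped in through the setter stays visible to the backend.
import torch

class RegisteredTensor:
    def __init__(self, owner, getter, setter):
        self.owner, self.getter, self.setter = owner, getter, setter

    def effective(self) -> torch.Tensor:
        return self.getter(self.owner)  # resolved now, not at registration time

p = torch.nn.Parameter(torch.zeros(4))
p.grad = torch.ones(4)
reg = RegisteredTensor(p, lambda p: p.grad, lambda p, t: setattr(p, "grad", t))

reg.setter(p, torch.full((4,), 2.0))  # pretend the backend swapped in a new buffer
assert reg.effective().data_ptr() == p.grad.data_ptr()
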
bagua/torch_api/algorithms/base.py (18 changes: 11 additions & 7 deletions)

@@ -59,11 +59,14 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
         parameters = bagua_module.bagua_build_params()
         tensors = []
         for name, param in parameters.__reversed__():
-            grad = param.bagua_ensure_grad().ensure_bagua_tensor(
-                name, bagua_module.bagua_module_name
+            param = param.bagua_ensure_grad().ensure_bagua_tensor(
+                name,
+                bagua_module.bagua_module_name,
+                getter_closure=lambda param: param.grad,
+                setter_closure=lambda param, t: setattr(param, "grad", t),
             )
-            param._bagua_grad = grad
-            tensors.append(grad)
+            tensors.append(param)
+
         self._communication_tensor_names = set(name for name, _ in parameters)
         assert len(self._communication_tensor_names) == len(
             tensors

@@ -123,9 +126,10 @@ def init_backward_hook(self, bagua_module: BaguaModule):
         def hook(parameter_name, parameter):
             if parameter_name in self._communication_tensor_names:
                 assert (
-                    parameter._bagua_grad.data_ptr() == parameter.grad.data_ptr()
-                ), "bagua grad data_ptr should match parameter grad"
-                parameter._bagua_grad.bagua_mark_communication_ready()
+                    parameter.bagua_backend_tensor().data_ptr()
+                    == parameter.grad.data_ptr()
+                ), "bagua backend tensor data_ptr should match parameter grad"
+                parameter.bagua_mark_communication_ready()

         return hook

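The backward hook above now asserts that `bagua_backend_tensor()` and the parameter's current `.grad` share the same `data_ptr()`, i.e. that the tensor handed to the backend aliases the gradient autograd just produced. A short self-contained aside on what that check means in plain PyTorch:

# data_ptr() equality detects aliasing of the first element, which is what the
# hook's assertion relies on.
import torch

g = torch.zeros(8)
alias = g.view(2, 4)  # same storage, same offset -> same data_ptr
copy = g.clone()      # fresh storage -> different data_ptr

assert alias.data_ptr() == g.data_ptr()
assert copy.data_ptr() != g.data_ptr()
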
bagua/torch_api/algorithms/q_adam.py (53 changes: 30 additions & 23 deletions)

@@ -46,15 +46,12 @@ def __init__(
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(QAdamOptimizer, self).__init__(params, defaults)

-        self.params_in_group = []
-        self.exp_avgs_in_group = []
         self.step_id = 0
         self.warmup_steps = warmup_steps

         # initialize momentum and variance
         for group_id, group in enumerate(self.param_groups):
             params_with_grad = []
-            exp_avgs = []
             for p in group["params"]:
                 params_with_grad.append(p)
                 state = self.state[p]

@@ -65,9 +62,6 @@ def __init__(
                     state["exp_avg_sq"] = torch.zeros_like(
                         p, memory_format=torch.preserve_format
                     )
-                exp_avgs.append(state["exp_avg"])
-            self.params_in_group.append(params_with_grad)
-            self.exp_avgs_in_group.append(exp_avgs)

     def __setstate__(self, state):
         super(QAdamOptimizer, self).__setstate__(state)

@@ -140,23 +134,30 @@ def init_tensors(self, bagua_module: BaguaModule):
             param._q_adam_idx = idx

         tensor_groups = []
-        for param_group, m_group in zip(
-            self.optimizer.params_in_group, self.optimizer.exp_avgs_in_group
-        ):
-            for param, exp_avgs in zip(param_group, m_group):
+        for group in self.optimizer.param_groups:
+            for param in group["params"]:
                 if self.optimizer.step_id < self.warmup_steps:
+                    # register grad
                     registered_tensor = param.bagua_ensure_grad().ensure_bagua_tensor(
-                        param._q_adam_name, bagua_module.bagua_module_name
+                        param._q_adam_name,
+                        bagua_module.bagua_module_name,
+                        getter_closure=lambda param: param.grad,
+                        setter_closure=lambda param, t: setattr(param, "grad", t),
                     )
-                    param._q_adam_grad = registered_tensor
-                    registered_tensor._q_adam_idx = param._q_adam_idx
                 else:
-                    registered_tensor = exp_avgs.ensure_bagua_tensor(
-                        param._q_adam_name, bagua_module.bagua_module_name
+                    # register first momentum
+                    def set_momentum_fn(param, t):
+                        self.optimizer.state[param]["exp_avg"] = t
+
+                    registered_tensor = param.bagua_ensure_grad().ensure_bagua_tensor(
+                        param._q_adam_name,
+                        bagua_module.bagua_module_name,
+                        getter_closure=lambda param: self.optimizer.state[param][
+                            "exp_avg"
+                        ],
+                        setter_closure=set_momentum_fn,
                     )
-                    registered_tensor._q_adam_grad = param.bagua_ensure_grad()
-                    param._q_adam_momentum = registered_tensor
-                    registered_tensor._q_adam_idx = param._q_adam_idx
+
                 tensor_groups.append(registered_tensor)
         tensor_groups.sort(key=lambda x: x._q_adam_idx)
         return tensor_groups

@@ -190,7 +191,9 @@ def init_operations(
         def calculate_momentum(*args):
             beta1, beta2 = self.optimizer.param_groups[0]["betas"]
             for tensor in bucket.tensors:
-                tensor.mul_(beta1).add_(tensor._q_adam_grad, alpha=1 - beta1)
+                tensor.bagua_getter_closure().mul_(beta1).add_(
+                    tensor.grad, alpha=1 - beta1
+                )

         bucket.append_python_op(calculate_momentum, group=self.process_group)
         bucket.append_centralized_synchronous_op(

@@ -203,13 +206,17 @@ def calculate_momentum(*args):

     def init_backward_hook(self, bagua_module: BaguaModule):
         def hook_momentum(parameter_name, parameter):
-            parameter._q_adam_momentum.bagua_mark_communication_ready()
+            assert (
+                parameter.bagua_backend_tensor().data_ptr()
+                == self.optimizer.state[parameter]["exp_avg"].data_ptr()
+            ), "bagua backend tensor data_ptr should match _q_adam_momentum data_ptr"
+            parameter.bagua_mark_communication_ready()

         def hook_grad(parameter_name, parameter):
             assert (
-                parameter.grad.data_ptr() == parameter._q_adam_grad.data_ptr()
-            ), "gradient data_ptr should match _q_adam_grad data_ptr"
-            parameter._q_adam_grad.bagua_mark_communication_ready()
+                parameter.bagua_backend_tensor().data_ptr() == parameter.grad.data_ptr()
+            ), "bagua backend tensor data_ptr should match _q_adam_grad data_ptr"
+            parameter.bagua_mark_communication_ready()

         return (
             hook_grad if self.optimizer.step_id < self.warmup_steps else hook_momentum
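
After warmup, QAdam registers the first momentum instead of the gradient, and that tensor lives in `self.optimizer.state[param]["exp_avg"]`; the setter closure therefore has to write replacements back into the state dict, otherwise later optimizer steps would keep updating the old buffer. A plain-PyTorch sketch of that write-back, with hypothetical names standing in for the Bagua machinery:

# Sketch only (hypothetical names, not the Bagua API): a setter closure over
# optimizer state makes a replacement tensor (e.g. a flattened bucket slice)
# the tensor that subsequent optimizer steps actually update.
import torch

param = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([param])
opt.state[param]["exp_avg"] = torch.zeros(4)

def get_momentum_fn(p):
    return opt.state[p]["exp_avg"]

def set_momentum_fn(p, t):  # same shape as the closure in the diff above
    opt.state[p]["exp_avg"] = t

flattened_view = torch.zeros(4)  # stand-in for a slice of a flattened bucket
set_momentum_fn(param, flattened_view)
assert get_momentum_fn(param).data_ptr() == flattened_view.data_ptr()
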
bagua/torch_api/bucket.py (72 changes: 43 additions & 29 deletions)

@@ -10,10 +10,8 @@
 from bagua.torch_api.tensor import BaguaTensor
 from bagua.torch_api.utils import check_contiguous
 from bagua.torch_api.communication import (
-    broadcast,
     BaguaProcessGroup,
     _bagua_backend_comm,
-    _rank_not_in_comm,
 )


@@ -36,7 +34,7 @@ def __init__(
         """
         self.tensors = tensors
         """
-        The tensors contained within the bucket.
+        The Bagua tensors contained in the bucket.
         """
         self.bagua_module_name = tensors[0].bagua_module_name
         for tensor in self.tensors:

@@ -57,7 +55,10 @@ def __init__(
             # padding tensor must be of name bagua_padding_tensor, so that they are always marked as ready for communication in the backend
             self.padding_tensor = torch.zeros(
                 padding, dtype=self.tensors[0].dtype, device=self.tensors[0].device
-            ).to_bagua_tensor("bagua_padding_tensor_bucket_" + name)
+            ).to_bagua_tensor(
+                "bagua_padding_tensor_bucket_" + name,
+                module_name=self.bagua_module_name,
+            )

         self._all_tensors = (
             self.tensors + [self.padding_tensor]

@@ -72,56 +73,64 @@ def __init__(
             torch.cuda.empty_cache()

         self.backend_bucket = B.BaguaBucketPy(
-            name, [tensor._bagua_backend_tensor for tensor in self._all_tensors]
+            name,
+            [tensor.bagua_backend_tensor() for tensor in self._all_tensors],
         )

         for tensor in self._all_tensors:
             tensor._bagua_bucket = self

-    def flattened_tensor(self) -> BaguaTensor:
+    def flattened_tensor(self) -> torch.Tensor:
         """
-        Returns a tensor contiguous in memory which contains the same data as :attr:`self` tensors and padding tensor (if exists).
+        Returns a tensor contiguous in memory which contains the same data as effective tensors, i.e.
+        returned by calling :meth:`~bagua.torch_api.tensor.BaguaTensor.bagua_getter_closure` on
+        :attr:`self` tensors and padding tensor (if exists).
         """
+
+        all_registered_tensors = [
+            tensor.bagua_getter_closure() for tensor in self._all_tensors
+        ]
         total_size = 0
-        for tensor in self._all_tensors:
+        for tensor in all_registered_tensors:
             total_size += tensor.numel()

         flatten_tensor = torch.zeros(
             total_size,
-            dtype=self._all_tensors[0].dtype,
-            device=self._all_tensors[0].device,
+            dtype=all_registered_tensors[0].dtype,
+            device=all_registered_tensors[0].device,
         )

         offset = 0
-        for tensor in self._all_tensors:
+        for tensor in all_registered_tensors:
             # copy data
-            flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1)
+            flatten_tensor[offset : offset + tensor.numel()] = tensor.reshape(-1)
             offset += tensor.numel()
         return flatten_tensor

     def _flatten_(self):
         """
-        Flatten inner tensors in place.
+        Flatten effective tensors in place.
         """
         if len(self._all_tensors) == 0:
             return

         flatten_tensor = self.flattened_tensor()

         if self.check_flatten():
-            flatten_tensor.set_(self._all_tensors[0].storage(), 0, flatten_tensor.shape)
+            flatten_tensor.set_(
+                self._all_tensors[0].bagua_getter_closure().storage(),
+                0,
+                flatten_tensor.shape,
+            )
             self.backend_tensor = flatten_tensor
             return

         flatten_storage = flatten_tensor.storage()

         offset = 0
-
         for tensor in self._all_tensors:
-            # copy data
-            flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1)
             tensor.bagua_set_storage(flatten_storage, offset)
-            offset += tensor.numel()
+            offset += tensor.bagua_getter_closure().numel()

         # set backend tensor
         self.backend_tensor = flatten_tensor

@@ -131,9 +140,11 @@ def _flatten_(self):
     def check_flatten(self) -> bool:
         """
         Returns:
-            True if the bucket's tensors are contiguous in memory.
+            True if effective tensors are contiguous in memory.
         """
-        return check_contiguous(self._all_tensors)
+        return check_contiguous(
+            [tensor.bagua_getter_closure() for tensor in self._all_tensors]
+        )

     def append_python_op(
         self,

@@ -252,15 +263,15 @@ def append_decentralized_synchronous_op(
                 _bagua_backend_comm(group.get_intra_node_communicator()),
                 hierarchical=hierarchical,
                 peer_selection_mode=peer_selection_mode,
-                peer_weight=peer_weight._bagua_backend_tensor,
+                peer_weight=peer_weight.bagua_backend_tensor(),
             )
         else:
             return self.backend_bucket.append_decentralized_synchronous_op(
                 _bagua_backend_comm(group.get_global_communicator()),
                 None,
                 hierarchical=hierarchical,
                 peer_selection_mode=peer_selection_mode,
-                peer_weight=peer_weight._bagua_backend_tensor,
+                peer_weight=peer_weight.bagua_backend_tensor(),
             )

     def append_low_precision_decentralized_synchronous_op(

@@ -304,9 +315,9 @@ def append_low_precision_decentralized_synchronous_op(
                 hierarchical=hierarchical,
                 peer_selection_mode="ring",
                 compression=compression,
-                weight=weight._bagua_backend_tensor,
-                left_peer_weight=left_peer_weight._bagua_backend_tensor,
-                right_peer_weight=right_peer_weight._bagua_backend_tensor,
+                weight=weight.bagua_backend_tensor(),
+                left_peer_weight=left_peer_weight.bagua_backend_tensor(),
+                right_peer_weight=right_peer_weight.bagua_backend_tensor(),
             )
         else:
             self.backend_bucket.append_low_precision_decentralized_synchronous_op(

@@ -315,9 +326,9 @@ def append_low_precision_decentralized_synchronous_op(
                 hierarchical=hierarchical,
                 peer_selection_mode="ring",
                 compression=compression,
-                weight=weight._bagua_backend_tensor,
-                left_peer_weight=left_peer_weight._bagua_backend_tensor,
-                right_peer_weight=right_peer_weight._bagua_backend_tensor,
+                weight=weight.bagua_backend_tensor(),
+                left_peer_weight=left_peer_weight.bagua_backend_tensor(),
+                right_peer_weight=right_peer_weight.bagua_backend_tensor(),
             )

     def append_asynchronous_model_average_op(

@@ -361,4 +372,7 @@ def clear_ops(self) -> BaguaBucket:

     def bytes(self) -> int:
         """Returns the total number of bytes occupied by the bucket."""
-        return sum(tensor.numel() * tensor.element_size() for tensor in self.tensors)
+        registered_tensors = [tensor.bagua_getter_closure() for tensor in self.tensors]
+        return sum(
+            tensor.numel() * tensor.element_size() for tensor in registered_tensors
+        )
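
The bucket now flattens and measures the effective tensors returned by `bagua_getter_closure()` rather than the registered wrappers themselves; `flattened_tensor` copies them into one contiguous buffer, `_flatten_` repoints each tensor at a slice of that buffer's storage, and `check_flatten` verifies the pointers line up. A runnable plain-PyTorch sketch of the same mechanics (not the Bagua bucket itself):

# Plain-PyTorch sketch of the flatten-and-repoint idea used by the bucket.
import torch

tensors = [torch.randn(3), torch.randn(5), torch.randn(2)]

# 1. copy everything into one contiguous buffer (cf. flattened_tensor)
flat = torch.zeros(sum(t.numel() for t in tensors), dtype=tensors[0].dtype)
offset = 0
for t in tensors:
    flat[offset : offset + t.numel()] = t.reshape(-1)
    offset += t.numel()

# 2. repoint each tensor at a slice of the shared storage (cf. bagua_set_storage)
offset = 0
for t in tensors:
    t.set_(flat.storage(), offset, t.shape)
    offset += t.numel()

# 3. contiguity check in the spirit of check_flatten: consecutive data_ptrs line up
offset = 0
for t in tensors:
    assert t.data_ptr() == flat.data_ptr() + offset * flat.element_size()
    offset += t.numel()
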
(The remaining 7 changed files are not shown here.)
