
feat(python, core): support mutable bucket tensors #271

Merged
merged 38 commits into from
Oct 28, 2021
Changes from 31 commits
38 commits
0eb6994
reset tensorpy
wangraying Oct 8, 2021
83da89c
refactor
wangraying Oct 8, 2021
2c53d8f
make it ok for allreduce
wangraying Oct 9, 2021
9351840
update
wangraying Oct 9, 2021
4957026
add tests
wangraying Oct 9, 2021
b771b48
use getter closure and setter closure
wangraying Oct 11, 2021
cf73707
update
wangraying Oct 11, 2021
0222d93
tmp save
wangraying Oct 12, 2021
a060944
Merge branch 'master' into bucket-tensor
wangraying Oct 19, 2021
b007294
.
wangraying Oct 19, 2021
e94deca
f
wangraying Oct 20, 2021
fda2391
fix and add
wangraying Oct 20, 2021
de605fc
fix
wangraying Oct 20, 2021
b32eb61
support qadam
wangraying Oct 20, 2021
dde2d56
.
wangraying Oct 20, 2021
6a0a171
Merge branch 'master' into bucket-tensor
wangraying Oct 21, 2021
2f9acbd
close https://github.com/BaguaSys/bagua/issues/287
wangraying Oct 21, 2021
7f9c081
add doc
wangraying Oct 21, 2021
7e4e5da
rename
wangraying Oct 21, 2021
02f1699
add
wangraying Oct 21, 2021
419279c
remove fallback to python
wangraying Oct 22, 2021
76c54ad
add sanity check
wangraying Oct 22, 2021
51d2450
remove unwrap
wangraying Oct 22, 2021
6355b09
.
wangraying Oct 22, 2021
31db282
update doc
wangraying Oct 27, 2021
693de1c
update
wangraying Oct 27, 2021
440f832
.
wangraying Oct 27, 2021
f8c61c8
.
wangraying Oct 27, 2021
b45c201
.
wangraying Oct 27, 2021
999bd5b
.
wangraying Oct 27, 2021
3a496dc
Update tensor.py
NOBLES5E Oct 28, 2021
785aa73
Update tensor.py
NOBLES5E Oct 28, 2021
1aae792
Update tensor.py
NOBLES5E Oct 28, 2021
c03432c
Update tensor.py
NOBLES5E Oct 28, 2021
803b64f
update doc
wangraying Oct 28, 2021
c264345
Merge branch 'bucket-tensor' of https://github.com/BaguaSys/bagua int…
wangraying Oct 28, 2021
c634b0a
.
wangraying Oct 28, 2021
b98f94d
.
wangraying Oct 28, 2021
12 changes: 7 additions & 5 deletions bagua/torch_api/algorithms/async_model_average.py
@@ -95,11 +95,13 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
tensors = []
for name, param in parameters.__reversed__():
if self.step_id < self.warmup_steps:
grad = param.bagua_ensure_grad().ensure_bagua_tensor(
name, bagua_module.bagua_module_name
param = param.bagua_ensure_grad().ensure_bagua_tensor(
name,
bagua_module.bagua_module_name,
getter_closure=lambda param: param.grad,
setter_closure=lambda param, t: setattr(param, "grad", t),
)
param._bagua_grad = grad
tensors.append(grad)
tensors.append(param)
else:
p = param.ensure_bagua_tensor(name, bagua_module.bagua_module_name)
tensors.append(p)
@@ -128,7 +130,7 @@ def hook(input):
def init_backward_hook(self, bagua_module: BaguaModule):
def hook(parameter_name, parameter):
if self.step_id <= self.warmup_steps:
parameter._bagua_grad.bagua_mark_communication_ready()
parameter.bagua_mark_communication_ready()

return hook

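The diff above stops registering a separate `grad` proxy and instead registers the parameter itself with a `getter_closure`/`setter_closure` pair, so the backend always resolves the parameter's current `.grad` at communication time. A minimal, self-contained sketch of that closure pattern in plain PyTorch (the `MutableTensorProxy` class is hypothetical and only illustrates the idea, it is not Bagua's implementation):

```python
import torch

# Hypothetical proxy, illustrating the getter/setter-closure idea only.
class MutableTensorProxy:
    def __init__(self, owner, getter, setter):
        self._owner = owner      # e.g. a torch.nn.Parameter
        self._getter = getter    # e.g. lambda p: p.grad
        self._setter = setter    # e.g. lambda p, t: setattr(p, "grad", t)

    def inner(self) -> torch.Tensor:
        # Re-resolve on every access instead of caching the tensor object,
        # so a re-allocated .grad is still picked up.
        return self._getter(self._owner)

    def replace_inner(self, t: torch.Tensor) -> None:
        # Lets a bucket swap in a view over its flattened storage.
        self._setter(self._owner, t)


param = torch.nn.Parameter(torch.zeros(4))
param.grad = torch.ones(4)
proxy = MutableTensorProxy(param, lambda p: p.grad, lambda p, t: setattr(p, "grad", t))
assert proxy.inner().data_ptr() == param.grad.data_ptr()

# If the gradient tensor is later replaced, the proxy still sees the new one.
param.grad = torch.full((4,), 2.0)
assert torch.equal(proxy.inner(), param.grad)
```

This is also why `tensors.append(param)` replaces `tensors.append(grad)`: the registered object is now the parameter, and the closures tell the backend which inner tensor to read and write.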
18 changes: 11 additions & 7 deletions bagua/torch_api/algorithms/base.py
@@ -59,11 +59,14 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
parameters = bagua_module.bagua_build_params()
tensors = []
for name, param in parameters.__reversed__():
grad = param.bagua_ensure_grad().ensure_bagua_tensor(
name, bagua_module.bagua_module_name
param = param.bagua_ensure_grad().ensure_bagua_tensor(
name,
bagua_module.bagua_module_name,
getter_closure=lambda param: param.grad,
setter_closure=lambda param, t: setattr(param, "grad", t),
)
param._bagua_grad = grad
tensors.append(grad)
tensors.append(param)

self._communication_tensor_names = set(name for name, _ in parameters)
assert len(self._communication_tensor_names) == len(
tensors
@@ -123,9 +126,10 @@ def init_backward_hook(self, bagua_module: BaguaModule):
def hook(parameter_name, parameter):
if parameter_name in self._communication_tensor_names:
assert (
parameter._bagua_grad.data_ptr() == parameter.grad.data_ptr()
), "bagua grad data_ptr should match parameter grad"
parameter._bagua_grad.bagua_mark_communication_ready()
parameter.bagua_backend_tensor().data_ptr()
== parameter.grad.data_ptr()
), "bagua backend tensor data_ptr should match parameter grad"
parameter.bagua_mark_communication_ready()

return hook

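The updated hook asserts that `bagua_backend_tensor()` and `param.grad` start at the same storage address before marking the tensor ready for communication. A bare-PyTorch sketch of that sanity check (the function name is illustrative):

```python
import torch

def assert_backend_aliases_grad(backend_view: torch.Tensor, param: torch.nn.Parameter) -> None:
    # The tensor handed to the communication backend must alias the gradient
    # autograd just produced; otherwise stale memory would be reduced.
    assert backend_view.data_ptr() == param.grad.data_ptr(), (
        "backend tensor should start at the same storage address as param.grad"
    )

p = torch.nn.Parameter(torch.zeros(8))
p.grad = torch.randn(8)
backend_view = p.grad.view(-1)   # stand-in for the registered backend tensor
assert_backend_aliases_grad(backend_view, p)
```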
53 changes: 30 additions & 23 deletions bagua/torch_api/algorithms/q_adam.py
@@ -46,15 +46,12 @@ def __init__(
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(QAdamOptimizer, self).__init__(params, defaults)

self.params_in_group = []
self.exp_avgs_in_group = []
self.step_id = 0
self.warmup_steps = warmup_steps

# initialize momentum and variance
for group_id, group in enumerate(self.param_groups):
params_with_grad = []
exp_avgs = []
for p in group["params"]:
params_with_grad.append(p)
state = self.state[p]
@@ -65,9 +62,6 @@ def __setstate__(self, state):
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
self.params_in_group.append(params_with_grad)
self.exp_avgs_in_group.append(exp_avgs)

def __setstate__(self, state):
super(QAdamOptimizer, self).__setstate__(state)
@@ -140,23 +134,30 @@ def init_tensors(self, bagua_module: BaguaModule):
param._q_adam_idx = idx

tensor_groups = []
for param_group, m_group in zip(
self.optimizer.params_in_group, self.optimizer.exp_avgs_in_group
):
for param, exp_avgs in zip(param_group, m_group):
for group in self.optimizer.param_groups:
for param in group["params"]:
if self.optimizer.step_id < self.warmup_steps:
# register grad
registered_tensor = param.bagua_ensure_grad().ensure_bagua_tensor(
param._q_adam_name, bagua_module.bagua_module_name
param._q_adam_name,
bagua_module.bagua_module_name,
getter_closure=lambda param: param.grad,
setter_closure=lambda param, t: setattr(param, "grad", t),
)
param._q_adam_grad = registered_tensor
registered_tensor._q_adam_idx = param._q_adam_idx
else:
registered_tensor = exp_avgs.ensure_bagua_tensor(
param._q_adam_name, bagua_module.bagua_module_name
# register first momentum
def set_momentum_fn(param, t):
self.optimizer.state[param]["exp_avg"] = t

registered_tensor = param.bagua_ensure_grad().ensure_bagua_tensor(
param._q_adam_name,
bagua_module.bagua_module_name,
getter_closure=lambda param: self.optimizer.state[param][
"exp_avg"
],
setter_closure=set_momentum_fn,
)
registered_tensor._q_adam_grad = param.bagua_ensure_grad()
param._q_adam_momentum = registered_tensor
registered_tensor._q_adam_idx = param._q_adam_idx

tensor_groups.append(registered_tensor)
tensor_groups.sort(key=lambda x: x._q_adam_idx)
return tensor_groups
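After warmup, QAdam registers the optimizer's first-moment state (`exp_avg`) instead of the gradient, so the closures index `self.optimizer.state[param]` rather than `param.grad`. A stripped-down sketch of that state-dict registration (the `DummyOptimizer` class is illustrative only, standing in for a real `torch.optim` optimizer):

```python
import torch

class DummyOptimizer:
    """Illustrative stand-in holding per-parameter state the way torch.optim does."""
    def __init__(self, params):
        self.state = {p: {"exp_avg": torch.zeros_like(p)} for p in params}

p = torch.nn.Parameter(torch.randn(4))
opt = DummyOptimizer([p])

def get_momentum(param):
    return opt.state[param]["exp_avg"]

def set_momentum(param, t):
    # Mirrors set_momentum_fn above: replacing the state entry lets the bucket
    # later point "exp_avg" at a view over its flattened storage.
    opt.state[param]["exp_avg"] = t

set_momentum(p, torch.ones_like(p))
assert torch.equal(get_momentum(p), torch.ones(4))
```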
@@ -190,7 +191,9 @@ def init_operations(
def calculate_momentum(*args):
beta1, beta2 = self.optimizer.param_groups[0]["betas"]
for tensor in bucket.tensors:
tensor.mul_(beta1).add_(tensor._q_adam_grad, alpha=1 - beta1)
tensor.bagua_getter_closure().mul_(beta1).add_(
tensor.grad, alpha=1 - beta1
)

bucket.append_python_op(calculate_momentum, group=self.process_group)
bucket.append_centralized_synchronous_op(
@@ -203,13 +206,17 @@ def calculate_momentum(*args):

def init_backward_hook(self, bagua_module: BaguaModule):
def hook_momentum(parameter_name, parameter):
parameter._q_adam_momentum.bagua_mark_communication_ready()
assert (
parameter.bagua_backend_tensor().data_ptr()
== self.optimizer.state[parameter]["exp_avg"].data_ptr()
), "bagua backend tensor data_ptr should match _q_adam_momentum data_ptr"
parameter.bagua_mark_communication_ready()

def hook_grad(parameter_name, parameter):
assert (
parameter.grad.data_ptr() == parameter._q_adam_grad.data_ptr()
), "gradient data_ptr should match _q_adam_grad data_ptr"
parameter._q_adam_grad.bagua_mark_communication_ready()
parameter.bagua_backend_tensor().data_ptr() == parameter.grad.data_ptr()
), "bagua backend tensor data_ptr should match _q_adam_grad data_ptr"
parameter.bagua_mark_communication_ready()

return (
hook_grad if self.optimizer.step_id < self.warmup_steps else hook_momentum
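The return statement above selects `hook_grad` during warmup (raw gradients are communicated) and `hook_momentum` afterwards (only the first moment is). For reference, the update that `calculate_momentum` applies to each registered tensor, written out on plain tensors:

```python
import torch

beta1 = 0.9
grad = torch.randn(4)
exp_avg = torch.zeros(4)

# m_t = beta1 * m_{t-1} + (1 - beta1) * g_t, i.e. Adam's first-moment update
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
```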
70 changes: 42 additions & 28 deletions bagua/torch_api/bucket.py
@@ -10,10 +10,8 @@
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.utils import check_contiguous
from bagua.torch_api.communication import (
broadcast,
BaguaProcessGroup,
_bagua_backend_comm,
_rank_not_in_comm,
)


@@ -36,7 +34,7 @@ def __init__(
"""
self.tensors = tensors
"""
The tensors contained within the bucket.
The Bagua tensors contained in the bucket.
"""
self.bagua_module_name = tensors[0].bagua_module_name
for tensor in self.tensors:
@@ -57,7 +55,10 @@ def __init__(
# padding tensor must be of name bagua_padding_tensor, so that they are always marked as ready for communication in the backend
self.padding_tensor = torch.zeros(
padding, dtype=self.tensors[0].dtype, device=self.tensors[0].device
).to_bagua_tensor("bagua_padding_tensor_bucket_" + name)
).to_bagua_tensor(
"bagua_padding_tensor_bucket_" + name,
module_name=self.bagua_module_name,
)

self._all_tensors = (
self.tensors + [self.padding_tensor]
@@ -72,31 +73,37 @@ def __init__(
torch.cuda.empty_cache()

self.backend_bucket = B.BaguaBucketPy(
name, [tensor._bagua_backend_tensor for tensor in self._all_tensors]
name,
[tensor.bagua_backend_tensor() for tensor in self._all_tensors],
)

for tensor in self._all_tensors:
tensor._bagua_bucket = self

def flattened_tensor(self) -> BaguaTensor:
def flattened_tensor(self) -> torch.Tensor:
"""
Returns a tensor contiguous in memory which contains the same data as :attr:`self` tensors and padding tensor (if exists).
Returns a tensor contiguous in memory which contains the same data as inner tensors, i.e.
returned by calling :meth:`~bagua.torch_api.tensor.BaguaTensor.bagua_getter_closure` on
:attr:`self` tensors and padding tensor (if exists).
"""

all_registered_tensors = [
tensor.bagua_getter_closure() for tensor in self._all_tensors
]
total_size = 0
for tensor in self._all_tensors:
for tensor in all_registered_tensors:
total_size += tensor.numel()

flatten_tensor = torch.zeros(
total_size,
dtype=self._all_tensors[0].dtype,
device=self._all_tensors[0].device,
dtype=all_registered_tensors[0].dtype,
device=all_registered_tensors[0].device,
)

offset = 0
for tensor in self._all_tensors:
for tensor in all_registered_tensors:
# copy data
flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1)
flatten_tensor[offset : offset + tensor.numel()] = tensor.reshape(-1)
offset += tensor.numel()
return flatten_tensor

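`flattened_tensor` now concatenates whatever each tensor's `bagua_getter_closure()` returns rather than the wrapper objects themselves. The copy loop in isolation, with plain tensors standing in for the registered views:

```python
import torch

def flatten_views(views):
    # Allocate one contiguous buffer and copy each registered view into it,
    # mirroring the offset bookkeeping in flattened_tensor above.
    total = sum(v.numel() for v in views)
    flat = torch.zeros(total, dtype=views[0].dtype, device=views[0].device)
    offset = 0
    for v in views:
        flat[offset : offset + v.numel()] = v.reshape(-1)
        offset += v.numel()
    return flat

views = [torch.randn(3), torch.randn(2, 2)]
assert flatten_views(views).numel() == 7
```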
Expand All @@ -110,18 +117,20 @@ def _flatten_(self):
flatten_tensor = self.flattened_tensor()

if self.check_flatten():
flatten_tensor.set_(self._all_tensors[0].storage(), 0, flatten_tensor.shape)
flatten_tensor.set_(
self._all_tensors[0].bagua_getter_closure().storage(),
0,
flatten_tensor.shape,
)
self.backend_tensor = flatten_tensor
return

flatten_storage = flatten_tensor.storage()

offset = 0

for tensor in self._all_tensors:
# copy data
flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1)
tensor.bagua_set_storage(flatten_storage, offset)
offset += tensor.numel()
offset += tensor.bagua_getter_closure().numel()

# set backend tensor
self.backend_tensor = flatten_tensor
@@ -131,9 +140,11 @@ def _flatten_(self):
def check_flatten(self) -> bool:
"""
Returns:
True if the bucket's tensors are contiguous in memory.
True if inner tensors are contiguous in memory.
"""
return check_contiguous(self._all_tensors)
return check_contiguous(
[tensor.bagua_getter_closure() for tensor in self._all_tensors]
)

def append_python_op(
self,
@@ -252,15 +263,15 @@ def append_decentralized_synchronous_op(
_bagua_backend_comm(group.get_intra_node_communicator()),
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
peer_weight=peer_weight._bagua_backend_tensor,
peer_weight=peer_weight.bagua_backend_tensor(),
)
else:
return self.backend_bucket.append_decentralized_synchronous_op(
_bagua_backend_comm(group.get_global_communicator()),
None,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
peer_weight=peer_weight._bagua_backend_tensor,
peer_weight=peer_weight.bagua_backend_tensor(),
)

def append_low_precision_decentralized_synchronous_op(
@@ -304,9 +315,9 @@ def append_low_precision_decentralized_synchronous_op(
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
weight=weight._bagua_backend_tensor,
left_peer_weight=left_peer_weight._bagua_backend_tensor,
right_peer_weight=right_peer_weight._bagua_backend_tensor,
weight=weight.bagua_backend_tensor(),
left_peer_weight=left_peer_weight.bagua_backend_tensor(),
right_peer_weight=right_peer_weight.bagua_backend_tensor(),
)
else:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
@@ -315,9 +326,9 @@ def append_low_precision_decentralized_synchronous_op(
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
weight=weight._bagua_backend_tensor,
left_peer_weight=left_peer_weight._bagua_backend_tensor,
right_peer_weight=right_peer_weight._bagua_backend_tensor,
weight=weight.bagua_backend_tensor(),
left_peer_weight=left_peer_weight.bagua_backend_tensor(),
right_peer_weight=right_peer_weight.bagua_backend_tensor(),
)

def append_asynchronous_model_average_op(
Expand Down Expand Up @@ -361,4 +372,7 @@ def clear_ops(self) -> BaguaBucket:

def bytes(self) -> int:
"""Returns the total number of bytes occupied by the bucket."""
return sum(tensor.numel() * tensor.element_size() for tensor in self.tensors)
registered_tensors = [tensor.bagua_getter_closure() for tensor in self.tensors]
return sum(
tensor.numel() * tensor.element_size() for tensor in registered_tensors
)
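`bytes()` and `check_flatten` above both operate on the registered views. A quick sketch of the same bookkeeping on bare tensors (`views_are_contiguous` is a hypothetical re-implementation of what a contiguity check like `check_contiguous` verifies, not the library's code):

```python
import torch

def bucket_bytes(views):
    # Sum the sizes of the registered views, as bytes() does above.
    return sum(v.numel() * v.element_size() for v in views)

def views_are_contiguous(views):
    # Each view must start exactly where the previous one ends in memory.
    expected = views[0].data_ptr()
    for v in views:
        if v.data_ptr() != expected:
            return False
        expected += v.numel() * v.element_size()
    return True

flat = torch.zeros(10)
views = [flat[:4], flat[4:10]]
assert bucket_bytes(views) == 10 * flat.element_size()
assert views_are_contiguous(views)
```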
12 changes: 8 additions & 4 deletions bagua/torch_api/communication.py
@@ -28,7 +28,7 @@
"allreduce", "allreduce_inplace", "allgather", "allgather_inplace",
"gather", "gather_inplace", "scatter", "scatter_inplace",
"reduce_scatter", "reduce_scatter_inplace", "alltoall", "alltoall_inplace",
"barrier"
"barrier", "BaguaProcessGroup"
]

# Process group's global rank to local rank mapping
@@ -193,6 +193,7 @@ def from_torch_group(group, stream: Optional[torch.cuda.Stream] = None):


class BaguaProcessGroup:
"""Definition of Bagua process group."""
def __init__(self, ranks, stream, group_name):
self.ranks = ranks
self.stream = stream
@@ -213,13 +214,16 @@ def __init__(self, ranks, stream, group_name):

logging.debug(f"Initialize Bagua process group of ranks {self.ranks}")

def get_global_communicator(self):
def get_global_communicator(self) -> B.BaguaSingleCommunicatorPy:
"""Returns the global communicator of current process group."""
return get_communicator(self.group_name, "global")

def get_inter_node_communicator(self):
def get_inter_node_communicator(self) -> B.BaguaSingleCommunicatorPy:
"""Returns the inter-node communicator of current process group."""
return get_communicator(self.group_name, "inter")

def get_intra_node_communicator(self):
def get_intra_node_communicator(self) -> B.BaguaSingleCommunicatorPy:
"""Returns the intra-node communicator of current process group."""
return get_communicator(self.group_name, "intra")


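`BaguaProcessGroup` exposes a global communicator plus intra-node and inter-node communicators for hierarchical operations. A pure-Python illustration (no Bagua calls) of how global ranks split into those two kinds of groups, assuming `nprocs_per_node` processes per node:

```python
def split_ranks(world_size: int, nprocs_per_node: int):
    # Intra-node groups: consecutive ranks living on the same machine.
    intra = [
        list(range(node * nprocs_per_node, (node + 1) * nprocs_per_node))
        for node in range(world_size // nprocs_per_node)
    ]
    # Inter-node groups: one group per local rank, e.g. all local-rank-0 processes.
    inter = [
        list(range(local, world_size, nprocs_per_node))
        for local in range(nprocs_per_node)
    ]
    return intra, inter

intra, inter = split_ranks(world_size=8, nprocs_per_node=4)
assert intra == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert inter == [[0, 4], [1, 5], [2, 6], [3, 7]]
```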
12 changes: 10 additions & 2 deletions bagua/torch_api/distributed.py
@@ -36,8 +36,15 @@ class BaguaModule:
:ivar bagua_optimizers: The optimizers passed in by :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua`.
:vartype bagua_optimizers: List[torch.optim.Optimizer]

:ivar bagua_algorithm: The algorithm passed in by :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua`.
:vartype bagua_algorithm: bagua.torch_api.algorithms.Algorithm
:ivar bagua_algorithm: The algorithm implementation used by the module, reified by the algorithm passed in
by :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua`.
:vartype bagua_algorithm: bagua.torch_api.algorithms.AlgorithmImpl

:ivar process_group: The process group used by the module.
:vartype process_group: bagua.torch_api.communication.BaguaProcessGroup

:ivar bagua_module_name: The module's name. Bagua uses the module name to distinguish different modules.
:vartype bagua_module_name: str

:ivar parameters_to_ignore: The parameter names in ``"{module_name}.{param_name}"`` format to ignore
when calling ``self.bagua_build_params()``.
@@ -308,6 +315,7 @@ def with_bagua( # pytype: disable=module-attr
self._bagua_reset_module()

if _rank_not_in_group(self._bagua_process_group):
# return if not a participant
return self

self.parameters_to_ignore = (
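For context on how the attributes documented above get populated: the usual entry point is `with_bagua`, roughly as in Bagua's documented usage of this era (import paths and the chosen algorithm are illustrative; verify against the release you actually use, and run under Bagua's distributed launcher):

```python
import torch
import bagua.torch_api as bagua
from bagua.torch_api.algorithms import gradient_allreduce

bagua.init_process_group()

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# After this call the ivars above exist: bagua_optimizers, bagua_algorithm
# (the reified AlgorithmImpl), process_group, and bagua_module_name.
model = model.with_bagua(
    [optimizer],
    gradient_allreduce.GradientAllReduceAlgorithm(),
)
```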