Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add low precision decentralized algorithm #103

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 82 additions & 6 deletions bagua/torch_api/algorithms/decentralized.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3

from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.distributed import BaguaModule
Expand All @@ -11,8 +10,8 @@
class DecentralizedAlgorithm(Algorithm):
def __init__(
self,
hierarchical: bool = True,
peer_selection_mode: str = "all",
compression: str = None,
communication_interval: int = 1,
):
"""
Expand All @@ -21,16 +20,15 @@ def __init__(
algorithm.

Args:
hierarchical (bool): Enable hierarchical communication.
peer_selection_mode (str): Can be "all" or "shift_one". "all" means all workers'
weights are averaged in each communication step. "shift_one" means each worker
selects a different peer to do weights average in each communication step.
compression (str): Not supported yet.
communication_interval (int): Number of iterations between two communication steps.
"""
self.hierarchical = hierarchical
self.peer_selection_mode = peer_selection_mode
self.compression = compression
self.communication_interval = communication_interval
self.tensor_groups = []

def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
parameters = bagua_module.bagua_build_params()
Expand Down Expand Up @@ -68,7 +66,85 @@ def init_operations(
):
bucket.clear_ops()
bucket.append_decentralized_synchronous_op(
hierarchical=True,
hierarchical=self.hierarchical,
peer_selection_mode=self.peer_selection_mode,
communication_interval=self.communication_interval,
)


class LowPrecisionDecentralizedAlgorithm(Algorithm):
def __init__(self, hierarchical: bool = True, communication_interval: int = 1):
"""
Create an instance of the
`Difference Compression Decentralized <https://arxiv.org/pdf/1803.06443.pdf>`_
algorithm.

Args:
hierarchical (bool): Enable hierarchical communication.
communication_interval (int): Number of iterations between two communication steps.
"""
self.hierarchical = hierarchical
self.communication_interval = communication_interval

def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
parameters = bagua_module.bagua_build_params()
self.tensors = [
param.ensure_bagua_tensor(name) for name, param in parameters.__reversed__()
wangraying marked this conversation as resolved.
Show resolved Hide resolved
]
return self.tensors

def init_backward_hook(self, bagua_module: BaguaModule):
def hook(parameter_name, parameter):
pass

return hook

def init_post_backward_hook(self, bagua_module: BaguaModule):
def hook():
pass

return hook

def init_post_optimizer_step_hook(self, bagua_module: BaguaModule):
def hook(optimizer: torch.optim.Optimizer):
wangraying marked this conversation as resolved.
Show resolved Hide resolved
for group in optimizer.param_groups:
for param in group["params"]:
if param.is_bagua_tensor():
param.bagua_mark_communication_ready()

bagua_module._bagua_backend.wait_pending_comm_ops()
wangraying marked this conversation as resolved.
Show resolved Hide resolved

return hook

def _init_states(self, bucket: BaguaBucket):
bucket_flattened_tensor = bucket.flattened_tensor()

weight_tensor = bucket_flattened_tensor.detach().clone()
left_peer_weight_tensor = bucket_flattened_tensor.detach().clone()
right_peer_weight_tensor = bucket_flattened_tensor.detach().clone()

bucket._weight = weight_tensor.to_bagua_tensor("weight")
bucket._left_peer_weight = left_peer_weight_tensor.to_bagua_tensor(
"left_peer_weight"
)
bucket._right_peer_weight = right_peer_weight_tensor.to_bagua_tensor(
"right_peer_weight"
)

def init_operations(
self,
bagua_module: BaguaModule,
bucket: BaguaBucket,
):
self._init_states(bucket)
torch.cuda.synchronize()
bucket.clear_ops()
bucket.append_decentralized_synchronous_op(
hierarchical=self.hierarchical,
peer_selection_mode="ring",
communication_interval=self.communication_interval,
compression="MinMaxUInt8",
weight=bucket._weight,
left_peer_weight=bucket._left_peer_weight,
right_peer_weight=bucket._right_peer_weight,
)
97 changes: 81 additions & 16 deletions bagua/torch_api/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def __init__(
tensors: A list of Bagua tensors to be put in the
bucket.
name: The unique name of the bucket.
flatten: If True, flatten the input tensors so that they are
flatten: If ``True``, flatten the input tensors so that they are
contiguous in memory.
alignment: If alignment > 1, Bagua will create a padding tensor to
alignment: If `alignment > 1`, Bagua will create a padding tensor to
the bucket so that the total number of elements in the bucket is a multiple of
the given alignment.
"""
Expand Down Expand Up @@ -65,6 +65,30 @@ def __init__(
for tensor in self._all_tensors:
tensor._bagua_bucket = self

def flattened_tensor(self) -> torch.Tensor:
wangraying marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns a tensor contiguous in memory which contains the same data as `self` tensors and padding tensor (if exists).
If `self` tensors and padding tensor are already flattened, this function returns a tensor corresponding to their
underlying storage.
"""
if self.flatten:
return self.backend_tensor
wangraying marked this conversation as resolved.
Show resolved Hide resolved

total_size = 0
for tensor in self._all_tensors:
total_size += tensor.numel()

flatten_tensor = torch.zeros(total_size, dtype=self._all_tensors[0].dtype).to(
self._all_tensors[0].device
)

offset = 0
for tensor in self._all_tensors:
# copy data
flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1)
offset += tensor.numel()
return flatten_tensor

def _flatten_(self):
"""
Flatten inner tensors in place.
Expand All @@ -89,6 +113,9 @@ def _flatten_(self):
flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1)
tensor.bagua_set_storage(flatten_storage, offset)
offset += tensor.numel()

# set backend tensor
self.backend_tensor = flatten_tensor
# check
assert self.check_flatten()

Expand Down Expand Up @@ -143,11 +170,12 @@ def append_centralized_synchronous_op(
hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine
will communicate will each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
average (bool): If True, the gradients on each worker are averaged. Otherwise, they are summed.
scattergather (bool): If true, the communication between workers are done with scatter gather instead
average (bool): If ``True``, the gradients on each worker are averaged. Otherwise, they are summed.
scattergather (bool): If ``True``, the communication between workers is done with scatter gather instead
of allreduce. This is required for using compression.
compression: If not None, the tensors will be compressed for communication. Currently "MinMaxUInt8" is
compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is
supported.

Returns:
The bucket itself.
"""
def append_decentralized_synchronous_op(
    self,
    hierarchical: bool = True,
    peer_selection_mode: str = "all",
    communication_interval: int = 1,
    compression: Optional[str] = None,
    weight: Optional[torch.Tensor] = None,
    left_peer_weight: Optional[torch.Tensor] = None,
    right_peer_weight: Optional[torch.Tensor] = None,
) -> BaguaBucket:
    """
    Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers.

    Args:
        hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine
            will communicate with each other first. After that, machines do inter-node communication. This can
            boost performance when the inter-node communication cost is high.
        peer_selection_mode (str): Can be "all" or "shift_one" for full precision decentralized operation, while "ring" for
            low precision decentralized operation. "all" means all workers' weights are averaged in each communication step.
            "shift_one" means each worker selects a different peer to do weights average in each communication step.
            "ring" means all workers are connected into a ring, and each worker communicates with its neighbors.
        communication_interval (int): Number of iterations between two communication steps.
        compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is
            supported.
        weight (torch.Tensor): Local model of current worker, required for low precision decentralized operation.
        left_peer_weight (torch.Tensor): Model replica of current worker's connected left peer, required for low
            precision decentralized operation.
        right_peer_weight (torch.Tensor): Model replica of current worker's connected right peer, required for
            low precision decentralized operation.
    Returns:
        The bucket itself.
    """

    def _backend_tensor_or_none(tensor):
        # The backend operates on the raw backend tensors, not the Python
        # wrappers; optional weights stay None when not provided.
        return tensor._bagua_backend_tensor if tensor is not None else None

    # The only difference between the hierarchical and flat cases is which
    # communicators are handed to the backend, so select them up front
    # instead of duplicating the whole backend call.
    if hierarchical:
        inter_communicator = _get_global_state().get_internode_communicator()
        intra_communicator = _get_global_state().get_intranode_communicator()
    else:
        inter_communicator = _get_global_state().get_global_communicator()
        intra_communicator = None

    self.backend_bucket.append_decentralized_synchronous_op(
        inter_communicator,
        intra_communicator,
        hierarchical=hierarchical,
        peer_selection_mode=peer_selection_mode,
        communication_interval=communication_interval,
        compression=compression,
        weight=_backend_tensor_or_none(weight),
        left_peer_weight=_backend_tensor_or_none(left_peer_weight),
        right_peer_weight=_backend_tensor_or_none(right_peer_weight),
    )

    return self

def clear_ops(self) -> BaguaBucket:
Expand Down
3 changes: 2 additions & 1 deletion bagua/torch_api/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def algorithm_reset_hook(self, input):
self._bagua_init_algorithm()

def algorithm_forward_pre_hook(self, input):
    # Only run the algorithm's forward pre-hook while the module is in
    # training mode; evaluation-only forward passes should not trigger it.
    if self.training:
        self.bagua_algorithm.init_forward_pre_hook(self)(input)

def record_speed_metrics_event(self, _):
if not self._speed_metrics_switch_on:
Expand Down