From 0374fba1161d81c56129948a73f57738cf2e06b6 Mon Sep 17 00:00:00 2001 From: ritaw Date: Sun, 4 Jul 2021 14:49:19 +0800 Subject: [PATCH 01/23] init commit --- bagua/torch_api/algorithms/decentralized.py | 89 +++++++++++++++++++-- bagua/torch_api/bucket.py | 43 +++++++--- 2 files changed, 117 insertions(+), 15 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index 0588e379d..0ba678105 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 - +from bagua.torch_api.globals import _get_global_state from bagua.torch_api.bucket import BaguaBucket from bagua.torch_api.tensor import BaguaTensor from bagua.torch_api.distributed import BaguaModule from bagua.torch_api.algorithms import Algorithm +from collections import defaultdict from typing import List import torch @@ -11,8 +12,8 @@ class DecentralizedAlgorithm(Algorithm): def __init__( self, + hierarchical: bool = True, peer_selection_mode: str = "all", - compression: str = None, communication_interval: int = 1, ): """ @@ -21,16 +22,15 @@ def __init__( algorithm. Args: + hierarchical (bool): Enable hierarchical communication. peer_selection_mode (str): Can be "all" or "shift_one". "all" means all workers' weights are averaged in each communication step. "shift_one" means each worker selects a different peer to do weights average in each communication step. - compression (str): Not supported yet. communication_interval (int): Number of iterations between two communication steps. """ + self.hierarchical = hierarchical self.peer_selection_mode = peer_selection_mode - self.compression = compression self.communication_interval = communication_interval - self.tensor_groups = [] def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: parameters = bagua_module.bagua_build_params() @@ -68,7 +68,84 @@ def init_operations( ): bucket.clear_ops() bucket.append_decentralized_synchronous_op( - hierarchical=True, + hierarchical=self.hierarchical, peer_selection_mode=self.peer_selection_mode, communication_interval=self.communication_interval, ) + + +class LowPrecisionDecentralizedAlgorithm(Algorithm): + def __init__(self, hierarchical: bool = True): + self.hierarchical = hierarchical + + def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: + parameters = bagua_module.bagua_build_params() + self.tensors = [ + param.ensure_bagua_tensor(name) for name, param in parameters.__reversed__() + ] + return self.tensors + + def init_backward_hook(self, bagua_module: BaguaModule): + def hook(parameter_name, parameter): + pass + + return hook + + def init_post_backward_hook(self, bagua_module: BaguaModule): + def hook(): + for bucket in bagua_module.bagua_buckets: + for tensor in bucket.tensors: + tensor.bagua_mark_communication_ready() + + bagua_module._bagua_backend.wait_pending_comm_ops() + + return hook + + def init_post_optimizer_step_hook(self, bagua_module: BaguaModule): + def hook(optimizer: torch.optim.Optimizer): + bagua_module._bagua_backend.execute_post_optimizer_step_comm_ops() + bagua_module._bagua_backend.wait_pending_post_optimizer_step_comm_ops() + + return hook + + def _init_states(self, bucket: BaguaBucket): + total_numel_allocated = sum( + [ + tensor._bagua_backend_tensor.num_elements_allocated() + for tensor in bucket.tensors + ] + ) + + weight_tensor = torch.zeros( + total_numel_allocated, + dtype=bucket.tensors[0].dtype, + device=bucket.tensors[0].device, + ) + offset = 0 + for tensor in bucket.tensors: + numel = tensor._bagua_backend_tensor.num_elements() + numel_allocated = tensor._bagua_backend_tensor.num_elements_allocated() + + weight_tensor[offset : offset + numel] = tensor.reshape(-1) + offset += numel_allocated + + left_peer_weight_tensor = weight_tensor.detach().clone() + right_peer_weight_tensor = weight_tensor.detach().clone() + + bucket.set_state("weight", weight_tensor) + bucket.set_state("left_peer_weight", left_peer_weight_tensor) + bucket.set_state("right_peer_weight", right_peer_weight_tensor) + + def init_operations( + self, + bagua_module: BaguaModule, + bucket: BaguaBucket, + ): + self._init_states(bucket) + torch.cuda.synchronize() + bucket.clear_ops() + bucket.append_decentralized_synchronous_op( + hierarchical=self.hierarchical, + peer_selection_mode="ring", + compression="MinMaxUInt8", + ) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index 6036f2cb6..6a59f88a8 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import List, Callable, Optional +from collections import defaultdict import bagua_core as B import torch @@ -36,6 +37,10 @@ def __init__( """ The bucket's name. """ + self.states = defaultdict(dict) + """ + The states contained within the bucket. + """ self.padding_tensor = None if alignment > 1: @@ -125,6 +130,15 @@ def wrapped_pyop(name): self.backend_bucket.append_python_op(wrapper_function_factory(python_function)) return self + def set_state(self, name: str, tensor: torch.Tensor): + self.states[name] = tensor + self.backend_bucket.set_state( + name, tensor.to_bagua_tensor(name).bagua_backend_tensor() + ) + + def _backend_states(self): + return self.backend_bucket.states() + def append_centralized_synchronous_op( self, hierarchical: bool = False, @@ -176,6 +190,7 @@ def append_decentralized_synchronous_op( hierarchical: bool = True, peer_selection_mode: str = "all", communication_interval: int = 1, + compression: Optional[str] = None, ) -> BaguaBucket: """ Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers. @@ -194,14 +209,25 @@ def append_decentralized_synchronous_op( Returns: The bucket itself. """ - self.backend_bucket.append_decentralized_synchronous_op( - _get_global_state().get_internode_communicator(), - _get_global_state().get_intranode_communicator(), - hierarchical=hierarchical, - compression=None, - peer_selection_mode=peer_selection_mode, - communication_interval=communication_interval, - ) + if hierarchical: + self.backend_bucket.append_decentralized_synchronous_op( + _get_global_state().get_internode_communicator(), + _get_global_state().get_intranode_communicator(), + hierarchical=hierarchical, + peer_selection_mode=peer_selection_mode, + communication_interval=communication_interval, + compression=compression, + ) + else: + self.backend_bucket.append_decentralized_synchronous_op( + _get_global_state().get_global_communicator(), + None, + hierarchical=hierarchical, + peer_selection_mode=peer_selection_mode, + communication_interval=communication_interval, + compression=compression, + ) + return self def clear_ops(self) -> BaguaBucket: @@ -213,7 +239,6 @@ def clear_ops(self) -> BaguaBucket: def bytes(self) -> int: """Returns the total number of bytes occupied by the bucket. - Returns: int: number of bucket bytes """ From 5735a4f70fb2f3959aa7f9addb9891d5d5e7029e Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 5 Jul 2021 16:19:53 +0800 Subject: [PATCH 02/23] fix algo --- bagua/torch_api/algorithms/decentralized.py | 33 ++++++----------- bagua/torch_api/bucket.py | 41 +++++++++++++++++++-- bagua/torch_api/distributed.py | 3 +- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index 0ba678105..8621fc782 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -76,6 +76,14 @@ def init_operations( class LowPrecisionDecentralizedAlgorithm(Algorithm): def __init__(self, hierarchical: bool = True): + """ + Create an instance of the + `Difference Compression Decentralized `_ + algorithm. + + Args: + hierarchical (bool): Enable hierarchical communication. + """ self.hierarchical = hierarchical def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: @@ -109,30 +117,11 @@ def hook(optimizer: torch.optim.Optimizer): return hook def _init_states(self, bucket: BaguaBucket): - total_numel_allocated = sum( - [ - tensor._bagua_backend_tensor.num_elements_allocated() - for tensor in bucket.tensors - ] - ) - - weight_tensor = torch.zeros( - total_numel_allocated, - dtype=bucket.tensors[0].dtype, - device=bucket.tensors[0].device, - ) - offset = 0 - for tensor in bucket.tensors: - numel = tensor._bagua_backend_tensor.num_elements() - numel_allocated = tensor._bagua_backend_tensor.num_elements_allocated() - - weight_tensor[offset : offset + numel] = tensor.reshape(-1) - offset += numel_allocated + bucket_flattened_tensor = bucket.flattened_tensor() - left_peer_weight_tensor = weight_tensor.detach().clone() - right_peer_weight_tensor = weight_tensor.detach().clone() + left_peer_weight_tensor = bucket_flattened_tensor.detach().clone() + right_peer_weight_tensor = bucket_flattened_tensor.detach().clone() - bucket.set_state("weight", weight_tensor) bucket.set_state("left_peer_weight", left_peer_weight_tensor) bucket.set_state("right_peer_weight", right_peer_weight_tensor) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index 6a59f88a8..0fa39cba5 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -39,7 +39,7 @@ def __init__( """ self.states = defaultdict(dict) """ - The states contained within the bucket. + The state of the bucket as a :class:`dict`. """ self.padding_tensor = None @@ -70,6 +70,30 @@ def __init__( for tensor in self._all_tensors: tensor._bagua_bucket = self + def flattened_tensor(self): + """ + Returns a tensor contiguous in memory which contains the same data as `self` tensors and padding tensor (if exists). + If `self` tensors and padding tensor are already flattened, this function returns a tensor corresponding to their + underlying storage. + """ + if self.flatten: + return self.backend_tensor + + total_size = 0 + for tensor in self._all_tensors: + total_size += tensor.numel() + + flatten_tensor = torch.zeros(total_size, dtype=self._all_tensors[0].dtype).to( + self._all_tensors[0].device + ) + + offset = 0 + for tensor in self._all_tensors: + # copy data + flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1) + offset += tensor.numel() + return flatten_tensor + def _flatten_(self): """ Flatten inner tensors in place. @@ -94,6 +118,9 @@ def _flatten_(self): flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1) tensor.bagua_set_storage(flatten_storage, offset) offset += tensor.numel() + + # set backend tensor + self.backend_tensor = flatten_tensor # check assert self.check_flatten() @@ -131,14 +158,20 @@ def wrapped_pyop(name): return self def set_state(self, name: str, tensor: torch.Tensor): + """ + Set a state to the bucket. + + This operation will create a Bagua tensor from `tensor`. The original tensor is not changed. + + Args: + name: the key of the state + tensor: the value of the state + """ self.states[name] = tensor self.backend_bucket.set_state( name, tensor.to_bagua_tensor(name).bagua_backend_tensor() ) - def _backend_states(self): - return self.backend_bucket.states() - def append_centralized_synchronous_op( self, hierarchical: bool = False, diff --git a/bagua/torch_api/distributed.py b/bagua/torch_api/distributed.py index bd2ecae1e..7edc31afb 100644 --- a/bagua/torch_api/distributed.py +++ b/bagua/torch_api/distributed.py @@ -227,7 +227,8 @@ def algorithm_reset_hook(self, input): self._bagua_init_algorithm() def algorithm_forward_pre_hook(self, input): - self.bagua_algorithm.init_forward_pre_hook(self)(input) + if self.training: + self.bagua_algorithm.init_forward_pre_hook(self)(input) def record_speed_metrics_event(self, _): if not self._speed_metrics_switch_on: From d63c389c4f3f1969f8302deb50fd68f525aa0d05 Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 5 Jul 2021 16:32:32 +0800 Subject: [PATCH 03/23] fix style --- bagua/torch_api/algorithms/decentralized.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index 8621fc782..a0fc3aa8e 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 -from bagua.torch_api.globals import _get_global_state from bagua.torch_api.bucket import BaguaBucket from bagua.torch_api.tensor import BaguaTensor from bagua.torch_api.distributed import BaguaModule from bagua.torch_api.algorithms import Algorithm -from collections import defaultdict from typing import List import torch From f42f6e0873c0f523d45856cd2a59b559dc978d52 Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 5 Jul 2021 16:52:03 +0800 Subject: [PATCH 04/23] fix doc --- bagua/torch_api/bucket.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index 0fa39cba5..eeddb6b3e 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -23,9 +23,9 @@ def __init__( tensors: A list of Bagua tensors to be put in the bucket. name: The unique name of the bucket. - flatten: If True, flatten the input tensors so that they are + flatten: If ``True``, flatten the input tensors so that they are contiguous in memory. - alignment: If alignment > 1, Bagua will create a padding tensor to + alignment: If `alignment > 1`, Bagua will create a padding tensor to the bucket so that the total number of elements in the bucket divides the given alignment. """ @@ -164,8 +164,8 @@ def set_state(self, name: str, tensor: torch.Tensor): This operation will create a Bagua tensor from `tensor`. The original tensor is not changed. Args: - name: the key of the state - tensor: the value of the state + name: the key of the state. + tensor: the value of the state. """ self.states[name] = tensor self.backend_bucket.set_state( @@ -190,11 +190,12 @@ def append_centralized_synchronous_op( hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine will communicate will each other first. After that, machines do inter-node communication. This can boost performance when the inter-node communication cost is high. - average (bool): If True, the gradients on each worker are averaged. Otherwise, they are summed. - scattergather (bool): If true, the communication between workers are done with scatter gather instead + average (bool): If ``True``, the gradients on each worker are averaged. Otherwise, they are summed. + scattergather (bool): If ``True``, the communication between workers are done with scatter gather instead of allreduce. This is required for using compression. - compression: If not None, the tensors will be compressed for communication. Currently "MinMaxUInt8" is + compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is supported. + Returns: The bucket itself. """ @@ -235,10 +236,13 @@ def append_decentralized_synchronous_op( hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine will communicate will each other first. After that, machines do inter-node communication. This can boost performance when the inter-node communication cost is high. - peer_selection_mode (str): Can be "all" or "shift_one". "all" means all workers' + peer_selection_mode (str): Can be "all" or "shift_one" or "ring". "all" means all workers' weights are averaged in each communication step. "shift_one" means each worker selects a different peer to do weights average in each communication step. + "ring" means all workers are connected into a ring, and each worker communicate with its neighbors. communication_interval (int): Number of iterations between two communication steps. + compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is + supported. Returns: The bucket itself. """ @@ -272,6 +276,7 @@ def clear_ops(self) -> BaguaBucket: def bytes(self) -> int: """Returns the total number of bytes occupied by the bucket. + Returns: int: number of bucket bytes """ From 9ba20d92d43bb8b6c57e54736af4de4fb4fde430 Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 6 Jul 2021 20:02:30 +0800 Subject: [PATCH 05/23] update --- bagua/torch_api/algorithms/decentralized.py | 24 ++++++++++++------ bagua/torch_api/bucket.py | 28 ++++++++++++++++++--- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index a0fc3aa8e..a6b07c2f9 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -73,7 +73,7 @@ def init_operations( class LowPrecisionDecentralizedAlgorithm(Algorithm): - def __init__(self, hierarchical: bool = True): + def __init__(self, hierarchical: bool = True, communication_interval: int = 1): """ Create an instance of the `Difference Compression Decentralized `_ @@ -81,8 +81,10 @@ def __init__(self, hierarchical: bool = True): Args: hierarchical (bool): Enable hierarchical communication. + communication_interval (int): Number of iterations between two communication steps. """ self.hierarchical = hierarchical + self.communication_interval = communication_interval def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: parameters = bagua_module.bagua_build_params() @@ -99,27 +101,29 @@ def hook(parameter_name, parameter): def init_post_backward_hook(self, bagua_module: BaguaModule): def hook(): - for bucket in bagua_module.bagua_buckets: - for tensor in bucket.tensors: - tensor.bagua_mark_communication_ready() - - bagua_module._bagua_backend.wait_pending_comm_ops() + pass return hook def init_post_optimizer_step_hook(self, bagua_module: BaguaModule): def hook(optimizer: torch.optim.Optimizer): - bagua_module._bagua_backend.execute_post_optimizer_step_comm_ops() - bagua_module._bagua_backend.wait_pending_post_optimizer_step_comm_ops() + for bucket in bagua_module.bagua_buckets: + + for tensor in bucket.tensors: + tensor.bagua_mark_communication_ready() + + bagua_module._bagua_backend.wait_pending_comm_ops() return hook def _init_states(self, bucket: BaguaBucket): bucket_flattened_tensor = bucket.flattened_tensor() + weight_tensor = bucket_flattened_tensor.detach().clone() left_peer_weight_tensor = bucket_flattened_tensor.detach().clone() right_peer_weight_tensor = bucket_flattened_tensor.detach().clone() + bucket.set_state("weight", weight_tensor) bucket.set_state("left_peer_weight", left_peer_weight_tensor) bucket.set_state("right_peer_weight", right_peer_weight_tensor) @@ -134,5 +138,9 @@ def init_operations( bucket.append_decentralized_synchronous_op( hierarchical=self.hierarchical, peer_selection_mode="ring", + communication_interval=self.communication_interval, compression="MinMaxUInt8", + weight=bucket.states["weight"], + left_peer_weight=bucket.states["left_peer_weight"], + right_peer_weight=bucket.states["right_peer_weight"], ) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index eeddb6b3e..118e4d77a 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -167,10 +167,8 @@ def set_state(self, name: str, tensor: torch.Tensor): name: the key of the state. tensor: the value of the state. """ - self.states[name] = tensor - self.backend_bucket.set_state( - name, tensor.to_bagua_tensor(name).bagua_backend_tensor() - ) + + self.states[name] = tensor.to_bagua_tensor(name) def append_centralized_synchronous_op( self, @@ -225,6 +223,9 @@ def append_decentralized_synchronous_op( peer_selection_mode: str = "all", communication_interval: int = 1, compression: Optional[str] = None, + weight: Optional[torch.Tensor] = None, + left_peer_weight: Optional[torch.Tensor] = None, + right_peer_weight: Optional[torch.Tensor] = None, ) -> BaguaBucket: """ Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers. @@ -243,6 +244,11 @@ def append_decentralized_synchronous_op( communication_interval (int): Number of iterations between two communication steps. compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is supported. + weight (torch.Tensor): Local model of current worker, required for low precision decentralized algorithm. + left_peer_weight (torch.Tensor): Model replica of current worker's connected left peer, required for low + precision decentralized algorithm. + right_peer_weight (torch.Tensor): Model replica of current worker's connected right peer, required for + low precision decentralizd algorithm. Returns: The bucket itself. """ @@ -254,6 +260,13 @@ def append_decentralized_synchronous_op( peer_selection_mode=peer_selection_mode, communication_interval=communication_interval, compression=compression, + weight=weight._bagua_backend_tensor if weight is not None else None, + left_peer_weight=left_peer_weight._bagua_backend_tensor + if left_peer_weight is not None + else None, + right_peer_weight=right_peer_weight._bagua_backend_tensor + if right_peer_weight is not None + else None, ) else: self.backend_bucket.append_decentralized_synchronous_op( @@ -263,6 +276,13 @@ def append_decentralized_synchronous_op( peer_selection_mode=peer_selection_mode, communication_interval=communication_interval, compression=compression, + weight=weight._bagua_backend_tensor if weight is not None else None, + left_peer_weight=left_peer_weight._bagua_backend_tensor + if left_peer_weight is not None + else None, + right_peer_weight=right_peer_weight._bagua_backend_tensor + if right_peer_weight is not None + else None, ) return self From cf6c040637abbb7b0caee48135fbaf2adc0a7cf5 Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 6 Jul 2021 20:09:29 +0800 Subject: [PATCH 06/23] add return type for flatten_tensor --- bagua/torch_api/bucket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index 118e4d77a..367f51f1a 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -70,7 +70,7 @@ def __init__( for tensor in self._all_tensors: tensor._bagua_bucket = self - def flattened_tensor(self): + def flattened_tensor(self) -> torch.Tensor: """ Returns a tensor contiguous in memory which contains the same data as `self` tensors and padding tensor (if exists). If `self` tensors and padding tensor are already flattened, this function returns a tensor corresponding to their From eb15e21061546d1e93cc478ca296ce30bbc33807 Mon Sep 17 00:00:00 2001 From: ritaw Date: Wed, 7 Jul 2021 17:27:42 +0800 Subject: [PATCH 07/23] update --- bagua/torch_api/algorithms/decentralized.py | 29 ++++++++++++++------- bagua/torch_api/bucket.py | 29 +++++---------------- 2 files changed, 25 insertions(+), 33 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index a6b07c2f9..41a674890 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -85,6 +85,7 @@ def __init__(self, hierarchical: bool = True, communication_interval: int = 1): """ self.hierarchical = hierarchical self.communication_interval = communication_interval + self.optimizer_step_count = 0 def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: parameters = bagua_module.bagua_build_params() @@ -107,12 +108,16 @@ def hook(): def init_post_optimizer_step_hook(self, bagua_module: BaguaModule): def hook(optimizer: torch.optim.Optimizer): - for bucket in bagua_module.bagua_buckets: + self.optimizer_step_count += 1 - for tensor in bucket.tensors: - tensor.bagua_mark_communication_ready() + if self.optimizer_step_count == len(bagua_module.bagua_optimizers): + for bucket in bagua_module.bagua_buckets: - bagua_module._bagua_backend.wait_pending_comm_ops() + for tensor in bucket.tensors: + tensor.bagua_mark_communication_ready() + + bagua_module._bagua_backend.wait_pending_comm_ops() + self.optimizer_step_count = 0 return hook @@ -123,9 +128,13 @@ def _init_states(self, bucket: BaguaBucket): left_peer_weight_tensor = bucket_flattened_tensor.detach().clone() right_peer_weight_tensor = bucket_flattened_tensor.detach().clone() - bucket.set_state("weight", weight_tensor) - bucket.set_state("left_peer_weight", left_peer_weight_tensor) - bucket.set_state("right_peer_weight", right_peer_weight_tensor) + bucket._weight = weight_tensor.to_bagua_tensor("weight") + bucket._left_peer_weight = left_peer_weight_tensor.to_bagua_tensor( + "left_peer_weight" + ) + bucket._right_peer_weight = right_peer_weight_tensor.to_bagua_tensor( + "right_peer_weight" + ) def init_operations( self, @@ -140,7 +149,7 @@ def init_operations( peer_selection_mode="ring", communication_interval=self.communication_interval, compression="MinMaxUInt8", - weight=bucket.states["weight"], - left_peer_weight=bucket.states["left_peer_weight"], - right_peer_weight=bucket.states["right_peer_weight"], + weight=bucket._weight, + left_peer_weight=bucket._left_peer_weight, + right_peer_weight=bucket._right_peer_weight, ) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index 367f51f1a..53d4f4bee 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -37,10 +37,6 @@ def __init__( """ The bucket's name. """ - self.states = defaultdict(dict) - """ - The state of the bucket as a :class:`dict`. - """ self.padding_tensor = None if alignment > 1: @@ -157,19 +153,6 @@ def wrapped_pyop(name): self.backend_bucket.append_python_op(wrapper_function_factory(python_function)) return self - def set_state(self, name: str, tensor: torch.Tensor): - """ - Set a state to the bucket. - - This operation will create a Bagua tensor from `tensor`. The original tensor is not changed. - - Args: - name: the key of the state. - tensor: the value of the state. - """ - - self.states[name] = tensor.to_bagua_tensor(name) - def append_centralized_synchronous_op( self, hierarchical: bool = False, @@ -237,18 +220,18 @@ def append_decentralized_synchronous_op( hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine will communicate will each other first. After that, machines do inter-node communication. This can boost performance when the inter-node communication cost is high. - peer_selection_mode (str): Can be "all" or "shift_one" or "ring". "all" means all workers' - weights are averaged in each communication step. "shift_one" means each worker - selects a different peer to do weights average in each communication step. + peer_selection_mode (str): Can be "all" or "shift_one" for full precision decentralized operation, while "ring" for + low precision decentralized operation. "all" means all workers' weights are averaged in each communication step. + "shift_one" means each worker selects a different peer to do weights average in each communication step. "ring" means all workers are connected into a ring, and each worker communicate with its neighbors. communication_interval (int): Number of iterations between two communication steps. compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is supported. - weight (torch.Tensor): Local model of current worker, required for low precision decentralized algorithm. + weight (torch.Tensor): Local model of current worker, required for low precision decentralized operation. left_peer_weight (torch.Tensor): Model replica of current worker's connected left peer, required for low - precision decentralized algorithm. + precision decentralized operation. right_peer_weight (torch.Tensor): Model replica of current worker's connected right peer, required for - low precision decentralizd algorithm. + low precision decentralized operation. Returns: The bucket itself. """ From 99799c700d3fb6e1618c3c04dabb27cb05256abe Mon Sep 17 00:00:00 2001 From: ritaw Date: Wed, 7 Jul 2021 17:44:17 +0800 Subject: [PATCH 08/23] update post step hook --- bagua/torch_api/algorithms/decentralized.py | 15 +++++---------- bagua/torch_api/bucket.py | 1 - 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index 41a674890..6f5bc2cb8 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -85,7 +85,6 @@ def __init__(self, hierarchical: bool = True, communication_interval: int = 1): """ self.hierarchical = hierarchical self.communication_interval = communication_interval - self.optimizer_step_count = 0 def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: parameters = bagua_module.bagua_build_params() @@ -108,16 +107,12 @@ def hook(): def init_post_optimizer_step_hook(self, bagua_module: BaguaModule): def hook(optimizer: torch.optim.Optimizer): - self.optimizer_step_count += 1 + for group in optimizer.param_groups: + for param in group["params"]: + if param.is_bagua_tensor(): + param.bagua_mark_communication_ready() - if self.optimizer_step_count == len(bagua_module.bagua_optimizers): - for bucket in bagua_module.bagua_buckets: - - for tensor in bucket.tensors: - tensor.bagua_mark_communication_ready() - - bagua_module._bagua_backend.wait_pending_comm_ops() - self.optimizer_step_count = 0 + bagua_module._bagua_backend.wait_pending_comm_ops() return hook diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index 53d4f4bee..4d08653ea 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -2,7 +2,6 @@ from __future__ import annotations from typing import List, Callable, Optional -from collections import defaultdict import bagua_core as B import torch From 62158017ee3bea1c32eafc0a8e100e08ab5f8677 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 8 Jul 2021 13:26:09 +0800 Subject: [PATCH 09/23] . --- bagua/torch_api/algorithms/decentralized.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index 6f5bc2cb8..f9fd073bf 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -91,6 +91,19 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: self.tensors = [ param.ensure_bagua_tensor(name) for name, param in parameters.__reversed__() ] + optimizer_param_ids = [ + id(param) + for optimizer in bagua_module.bagua_optimizers + for group in optimizer.param_groups + for param in group["params"] + ] + + for name, param in parameters: + if id(param) not in optimizer_param_ids: + raise RuntimeError( + f"Parameter {name} is not used by the optimizer, need to include it " + "to your module attribute `_bagua_params_and_buffers_to_ignore` to ignore it." + ) return self.tensors def init_backward_hook(self, bagua_module: BaguaModule): From b0994af1973c6ad1015146986a3d9519d94ba767 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 8 Jul 2021 16:59:12 +0800 Subject: [PATCH 10/23] add test --- tests/internal/__init__.py | 0 tests/internal/common_utils.py | 10 ++ .../test_low_precision_decentralized.py | 120 ++++++++++++++++++ 3 files changed, 130 insertions(+) create mode 100644 tests/internal/__init__.py create mode 100644 tests/internal/common_utils.py create mode 100644 tests/torch_api/test_low_precision_decentralized.py diff --git a/tests/internal/__init__.py b/tests/internal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/internal/common_utils.py b/tests/internal/common_utils.py new file mode 100644 index 000000000..6cfbf65bd --- /dev/null +++ b/tests/internal/common_utils.py @@ -0,0 +1,10 @@ +import socket + + +def find_free_port(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("localhost", 0)) + sockname = sock.getsockname() + sock.close() + return sockname[1] diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py new file mode 100644 index 000000000..9e97038fb --- /dev/null +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from tests.internal.common_utils import find_free_port +import bagua.torch_api as bagua +import unittest +import torch.multiprocessing as mp +import os + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = nn.Linear(10, 50, bias=True) + self.fc3 = nn.Linear(50, 4, bias=False) + self.relu = nn.ReLU() + + # self.no_grad_param = nn.Parameter(torch.tensor([2.0, 2.0]), requires_grad=False) + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return F.softmax(x, dim=1) + + +def run_model(gpu, nprocs, hierarchical, communication_interval, results): + # initialize subprocess env + os.environ["RANK"] = str(gpu) + os.environ["LOCAL_RANK"] = str(gpu) + + # init bagua distributed process group + torch.cuda.set_device(bagua.get_local_rank()) + bagua.init_process_group() + + print( + f"initialize bagua training process, rank: {bagua.get_rank()}, world_size: {bagua.get_world_size()}" + ) + + # construct model and optimizer, etc. + model = Net().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + loss_fn = nn.MSELoss() + + # wrap model + model = model.with_bagua( + [optimizer], + bagua.algorithms.decentralized.LowPrecisionDecentralizedAlgorithm( + hierarchical=hierarchical, communication_interval=communication_interval + ), + ) + + for batch_idx in range(10): + data = torch.randn(4, 2).cuda() + target = torch.randn(4, 4).cuda() + + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + + loss.backward() + optimizer.step() + + ret = results[gpu] + for bucket in model.bagua_buckets: + ret.bucket_weight = torch.norm(bucket.flattened_tensor()) + ret.weight = torch.norm(bucket._weight) + ret.left_peer_weight = torch.norm(bucket._left_peer_weight) + ret.right_peer_weight = torch.norm(bucket._right_peer_weight) + + +class Result(object): + def __init__(self): + self.bucket_weight = 0 + self.weight = 0 + self.left_peer_weight = 0 + self.right_peer_weight = 0 + + +class TestLowPrecisionDecentralized(unittest.TestCase): + def run_test_locally(self, nprocs, hierarchical, communication_interval): + os.environ["WORLD_SIZE"] = str(nprocs) + os.environ["LOCAL_WORLD_SIZE"] = str(nprocs) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(find_free_port()) + os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port()) + + results = [Result() for _ in range(nprocs)] + mp.spawn( + run_model, + nprocs=nprocs, + args=(nprocs, hierarchical, communication_interval, results), + ) + + for rank in range(nprocs): + left_peer_rank = (rank + nprocs - 1) % nprocs + right_peer_rank = (rank + 1) % nprocs + + if hierarchical: + # all worker have the same bucket weight + self.assertTrue( + results[rank].bucket_weight == results[left_peer_rank].bucket_weight + ) + else: + self.assertTrue( + results[rank].weight == results[left_peer_rank].right_peer_weight + ) + self.assertTrue( + results[rank].weight == results[right_peer_rank].left_peer_weight + ) + + def test_algorithm(self): + self.run_test_locally(nprocs=4, hierarchical=False, communication_interval=1) + self.run_test_locally(nprocs=4, hierarchical=False, communication_interval=2) + self.run_test_locally(nprocs=4, hierarchical=True, communication_interval=1) + + +if __name__ == "__main__": + unittest.main() From e166f46f2365365fd40a17552515b46fbf278ebb Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 8 Jul 2021 17:58:02 +0800 Subject: [PATCH 11/23] update msg --- bagua/torch_api/algorithms/decentralized.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index f9fd073bf..3cb32226b 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -101,8 +101,9 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: for name, param in parameters: if id(param) not in optimizer_param_ids: raise RuntimeError( - f"Parameter {name} is not used by the optimizer, need to include it " - "to your module attribute `_bagua_params_and_buffers_to_ignore` to ignore it." + f"Module parameter {name} is not used by your optimizer(s), need to exclude it " + "by adding the parameter name to the `List` attribute `_bagua_params_and_buffers_to_ignore` " + "of your module." ) return self.tensors From a6eebec4884ce440c3a5e3f171d60c3007cb887e Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 8 Jul 2021 18:19:22 +0800 Subject: [PATCH 12/23] . --- bagua/torch_api/algorithms/decentralized.py | 3 ++- bagua/torch_api/bucket.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index f0f42f422..7a466826a 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -90,7 +90,8 @@ def __init__(self, hierarchical: bool = True, communication_interval: int = 1): def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: parameters = bagua_module.bagua_build_params() self.tensors = [ - param.ensure_bagua_tensor(name) for name, param in parameters.__reversed__() + param.ensure_bagua_tensor(name, bagua_module.bagua_module_name) + for name, param in parameters.__reversed__() ] optimizer_param_ids = [ id(param) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index da61593b9..8fe9cda46 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -242,8 +242,8 @@ def append_decentralized_synchronous_op( """ if hierarchical: self.backend_bucket.append_decentralized_synchronous_op( - _get_global_state().get_internode_communicator(), - _get_global_state().get_intranode_communicator(), + self._bagua_backend.internode_communicator, + self._bagua_backend.intranode_communicator, hierarchical=hierarchical, peer_selection_mode=peer_selection_mode, communication_interval=communication_interval, @@ -258,7 +258,7 @@ def append_decentralized_synchronous_op( ) else: self.backend_bucket.append_decentralized_synchronous_op( - _get_global_state().get_global_communicator(), + self._bagua_backend.global_communicator, None, hierarchical=hierarchical, peer_selection_mode=peer_selection_mode, From 9930a732ab55949645831a6631abbf79465e1726 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 8 Jul 2021 18:38:57 +0800 Subject: [PATCH 13/23] skip test when cuda not available --- tests/torch_api/test_low_precision_decentralized.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index 9e97038fb..e19c1c4f4 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -79,7 +79,12 @@ def __init__(self): class TestLowPrecisionDecentralized(unittest.TestCase): - def run_test_locally(self, nprocs, hierarchical, communication_interval): + def run_test_locally(self, hierarchical, communication_interval): + if not torch.cuda.is_available(): + print("skip tests since cuda is not available") + return + + nprocs = torch.cuda.device_count() os.environ["WORLD_SIZE"] = str(nprocs) os.environ["LOCAL_WORLD_SIZE"] = str(nprocs) os.environ["MASTER_ADDR"] = "127.0.0.1" @@ -111,9 +116,9 @@ def run_test_locally(self, nprocs, hierarchical, communication_interval): ) def test_algorithm(self): - self.run_test_locally(nprocs=4, hierarchical=False, communication_interval=1) - self.run_test_locally(nprocs=4, hierarchical=False, communication_interval=2) - self.run_test_locally(nprocs=4, hierarchical=True, communication_interval=1) + self.run_test_locally(hierarchical=False, communication_interval=1) + self.run_test_locally(hierarchical=False, communication_interval=2) + self.run_test_locally(hierarchical=True, communication_interval=1) if __name__ == "__main__": From 4453bcf950902d00899e03a15911ce6f7c36cbc7 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 9 Jul 2021 10:23:04 +0800 Subject: [PATCH 14/23] . --- tests/torch_api/test_low_precision_decentralized.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index e19c1c4f4..618d542f2 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -16,8 +16,6 @@ def __init__(self): self.fc3 = nn.Linear(50, 4, bias=False) self.relu = nn.ReLU() - # self.no_grad_param = nn.Parameter(torch.tensor([2.0, 2.0]), requires_grad=False) - def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) @@ -103,7 +101,7 @@ def run_test_locally(self, hierarchical, communication_interval): right_peer_rank = (rank + 1) % nprocs if hierarchical: - # all worker have the same bucket weight + # all workers have equal weights self.assertTrue( results[rank].bucket_weight == results[left_peer_rank].bucket_weight ) From 30301435ac8af6fbc654bfd418920df6ca6ffe20 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 9 Jul 2021 19:37:06 +0800 Subject: [PATCH 15/23] add tests, not ready --- tests/internal/compressor.py | 25 ++ .../test_low_precision_decentralized.py | 241 ++++++++++++++++-- 2 files changed, 248 insertions(+), 18 deletions(-) create mode 100644 tests/internal/compressor.py diff --git a/tests/internal/compressor.py b/tests/internal/compressor.py new file mode 100644 index 000000000..61f73a0a7 --- /dev/null +++ b/tests/internal/compressor.py @@ -0,0 +1,25 @@ +import torch + + +class MinMaxUInt8: + def __init__(self): + self.eps = 1e-7 + self.quantization_level = 255 + + def compress(self, tensor): + _max = torch.max(tensor) + _min = torch.min(tensor) + + scale = self.quantization_level / (_max - _min + self.eps) + upper_bound = torch.round(_max * scale) + lower_bound = upper_bound - self.quantization_level + + level = (tensor * scale).int() + level = torch.clamp(level, min=lower_bound) + return _min, _max, level - lower_bound + + def decompress(self, _min, _max, compressed): + scale = self.quantization_level / (_max - _min + self.eps) + upper_bound = torch.round(_max * scale) + lower_bound = upper_bound - self.quantization_level + return (compressed + lower_bound) / scale diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index 618d542f2..a662b04df 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -2,10 +2,12 @@ import torch.nn as nn import torch.nn.functional as F from tests.internal.common_utils import find_free_port +from tests.internal.compressor import MinMaxUInt8 import bagua.torch_api as bagua import unittest import torch.multiprocessing as mp import os +from bagua.torch_api.utils import apply_flattened_call, flatten class Net(nn.Module): @@ -23,11 +25,28 @@ def forward(self, x): return F.softmax(x, dim=1) -def run_model(gpu, nprocs, hierarchical, communication_interval, results): +def _init_env(gpu): + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.manual_seed(47) # initialize subprocess env os.environ["RANK"] = str(gpu) os.environ["LOCAL_RANK"] = str(gpu) + +def run_model(gpu, nprocs, hierarchical, communication_interval, results): + _init_env(gpu) + + # construct model and optimizer, etc. + model = Net() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + loss_fn = nn.MSELoss() + + ret = results[gpu] + ret.init_weight.copy_( + torch.norm(flatten([param.data for param in model.parameters()])) + ) + # init bagua distributed process group torch.cuda.set_device(bagua.get_local_rank()) bagua.init_process_group() @@ -36,10 +55,7 @@ def run_model(gpu, nprocs, hierarchical, communication_interval, results): f"initialize bagua training process, rank: {bagua.get_rank()}, world_size: {bagua.get_world_size()}" ) - # construct model and optimizer, etc. - model = Net().cuda() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - loss_fn = nn.MSELoss() + model.cuda() # wrap model model = model.with_bagua( @@ -49,7 +65,7 @@ def run_model(gpu, nprocs, hierarchical, communication_interval, results): ), ) - for batch_idx in range(10): + for _ in range(10): data = torch.randn(4, 2).cuda() target = torch.randn(4, 4).cuda() @@ -61,19 +77,159 @@ def run_model(gpu, nprocs, hierarchical, communication_interval, results): optimizer.step() ret = results[gpu] - for bucket in model.bagua_buckets: - ret.bucket_weight = torch.norm(bucket.flattened_tensor()) - ret.weight = torch.norm(bucket._weight) - ret.left_peer_weight = torch.norm(bucket._left_peer_weight) - ret.right_peer_weight = torch.norm(bucket._right_peer_weight) + bucket = model.bagua_buckets[0] + ret.bucket_weight.copy_(flatten([param.data for param in model.parameters()])) + ret.weight.copy_(torch.norm(bucket._weight)) + ret.left_peer_weight.copy_(torch.norm(bucket._left_peer_weight)) + ret.right_peer_weight.copy_(torch.norm(bucket._right_peer_weight)) + + +def run_torch_model(gpu, nprocs, hierarchical, communication_interval, results): + _init_env(gpu) + + # construct model and optimizer, etc. + model = Net() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + loss_fn = nn.MSELoss() + + ret = results[gpu] + ret.init_weight.copy_( + torch.norm(flatten([param.data for param in model.parameters()])) + ) + + # init torch distributed process group + store = torch.distributed.FileStore("/tmp/filestore", nprocs) + torch.distributed.init_process_group( + world_size=nprocs, rank=gpu, store=store, backend="gloo" + ) + + print( + f"initialize torch training process, rank: {bagua.get_rank()}, world_size: {bagua.get_world_size()}" + ) + + # wrap model + model = LowPrecDecentralizedAlgor( + model, optimizer, hierarchical, communication_interval + ) + + for _ in range(10): + data = torch.randn(4, 2) + target = torch.randn(4, 4) + + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + + loss.backward() + model.step() + + ret.bucket_weight.copy_(flatten([param.data for param in model.parameters()])) class Result(object): def __init__(self): - self.bucket_weight = 0 - self.weight = 0 - self.left_peer_weight = 0 - self.right_peer_weight = 0 + model = Net() + self.init_weight = torch.Tensor([0.0]).share_memory_() + self.bucket_weight = flatten( + [param.data for param in model.parameters()] + ).share_memory_() + self.weight = torch.Tensor([0.0]).share_memory_() + self.left_peer_weight = torch.Tensor([0.0]).share_memory_() + self.right_peer_weight = torch.Tensor([0.0]).share_memory_() + + +class LowPrecDecentralizedAlgor(nn.Module): + def __init__(self, module, optimizer, hierarchical, communication_interval): + super(LowPrecDecentralizedAlgor, self).__init__() + self.module = module + self.optimizer = optimizer + self.hierarchical = hierarchical + self.communication_interval = communication_interval + self.step_count = 0 + self.compressor = MinMaxUInt8() + + assert torch.distributed.is_initialized() + + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + weights = [param.data for param in self.module.parameters()] + apply_flattened_call(weights, lambda x: torch.distributed.broadcast(x, 0)) + + self.weight = flatten(weights) + self.left_peer_weight = self.weight.detach().clone() + self.right_peer_weight = self.weight.detach().clone() + + def forward(self, *inputs, **kwargs): + result = self.module(*inputs, **kwargs) + return result + + def step(self): + self.optimizer.step() + + def allreduce_fn(x): + torch.distributed.allreduce(x) + x /= self.world_size + + def communicate_with_peers(_buffer): + left_buffer = torch.zeros_like(_buffer, device=_buffer.device) + right_buffer = torch.zeros_like(_buffer, device=_buffer.device) + + left_peer_rank = (self.rank + self.world_size - 1) % self.world_size + right_peer_rank = (self.rank + 1) % self.world_size + + requests = [] + requests.append(torch.distributed.isend(_buffer, left_peer_rank)) + requests.append(torch.distributed.isend(_buffer, right_peer_rank)) + requests.append(torch.distributed.irecv(left_buffer, left_peer_rank)) + requests.append(torch.distributed.irecv(right_buffer, right_peer_rank)) + + for req in requests: + req.wait() + + return left_buffer, right_buffer + + def update_weight_fn(x): + diff = ( + x + + 1 / 3 * self.left_peer_weight + + 1 / 3 * self.right_peer_weight + - 5 / 3 * self.weight + ) + + _min, _max, compressed_buffer = self.compressor.compress(diff) + + left_compressed_buffer, right_compressed_buffer = communicate_with_peers( + compressed_buffer + ) + left_min, right_min = communicate_with_peers(_min) + left_max, right_max = communicate_with_peers(_max) + + left_decompressed = self.compressor.decompress( + left_min, left_max, left_compressed_buffer + ) + right_decompressed = self.compressor.decompress( + right_min, right_max, right_compressed_buffer + ) + + self.left_peer_weight += left_decompressed + self.right_peer_weight += right_decompressed + + decompressed = self.compressor.decompress(_min, _max, compressed_buffer) + x += decompressed + + if self.step_count % self.communication_interval == 0: + weights = [param.data for param in self.module.parameters()] + if self.hierarchical: + apply_flattened_call(weights, allreduce_fn) + apply_flattened_call( + weights, lambda x: torch.distributed.broadcast(x, 0) + ) + else: + apply_flattened_call(weights, update_weight_fn) + self.weight = flatten(weights) + + self.step_count += 1 class TestLowPrecisionDecentralized(unittest.TestCase): @@ -103,21 +259,70 @@ def run_test_locally(self, hierarchical, communication_interval): if hierarchical: # all workers have equal weights self.assertTrue( - results[rank].bucket_weight == results[left_peer_rank].bucket_weight + torch.equal( + results[rank].bucket_weight, + results[left_peer_rank].bucket_weight, + ) ) else: self.assertTrue( - results[rank].weight == results[left_peer_rank].right_peer_weight + results[rank].weight.item() + == results[left_peer_rank].right_peer_weight.item() ) self.assertTrue( - results[rank].weight == results[right_peer_rank].left_peer_weight + results[rank].weight.item() + == results[right_peer_rank].left_peer_weight.item() + ) + + def run_diff_locally(self, hierarchical, communication_interval): + if not torch.cuda.is_available(): + print("skip tests since cuda is not available") + return + + nprocs = torch.cuda.device_count() + os.environ["WORLD_SIZE"] = str(nprocs) + os.environ["LOCAL_WORLD_SIZE"] = str(nprocs) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(find_free_port()) + os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port()) + + torch_results = [Result() for _ in range(nprocs)] + mp.spawn( + run_torch_model, + nprocs=nprocs, + args=(nprocs, hierarchical, communication_interval, torch_results), + ) + + bagua_results = [Result() for _ in range(nprocs)] + mp.spawn( + run_model, + nprocs=nprocs, + args=(nprocs, hierarchical, communication_interval, bagua_results), + ) + + for rank in range(nprocs): + self.assertTrue( + bagua_results[rank].init_weight.item() + == torch_results[rank].init_weight.item() + ) + + ret = torch.all( + torch.isclose( + bagua_results[rank].bucket_weight, + torch_results[rank].bucket_weight, ) + ).item() + + self.assertTrue(ret) def test_algorithm(self): self.run_test_locally(hierarchical=False, communication_interval=1) self.run_test_locally(hierarchical=False, communication_interval=2) self.run_test_locally(hierarchical=True, communication_interval=1) + def test_compare(self): + self.run_diff_locally(hierarchical=False, communication_interval=1) + if __name__ == "__main__": unittest.main() From 58e0b9629ec9c2af13daec898cbdd9fde04e43bf Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 12 Jul 2021 21:15:00 +0800 Subject: [PATCH 16/23] update --- tests/internal/compressor.py | 35 +++- .../test_low_precision_decentralized.py | 194 +++++++++--------- 2 files changed, 126 insertions(+), 103 deletions(-) diff --git a/tests/internal/compressor.py b/tests/internal/compressor.py index 61f73a0a7..783c4601b 100644 --- a/tests/internal/compressor.py +++ b/tests/internal/compressor.py @@ -4,22 +4,39 @@ class MinMaxUInt8: def __init__(self): self.eps = 1e-7 - self.quantization_level = 255 + self.quantization_level = 255.0 - def compress(self, tensor): - _max = torch.max(tensor) + def compress(self, tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor): _min = torch.min(tensor) - + _max = torch.max(tensor) scale = self.quantization_level / (_max - _min + self.eps) upper_bound = torch.round(_max * scale) lower_bound = upper_bound - self.quantization_level - level = (tensor * scale).int() - level = torch.clamp(level, min=lower_bound) - return _min, _max, level - lower_bound + level = torch.round(tensor * scale) + level = torch.clamp(level, max=upper_bound) + + _minmax = torch.zeros(2, dtype=tensor.dtype, device=tensor.device) + _minmax[0] = _min + _minmax[1] = _max + return _minmax, (level - lower_bound).to(torch.uint8) + + def decompress( + self, _minmax: torch.Tensor, compressed: torch.Tensor + ) -> torch.Tensor: + _min = _minmax[0] + _max = _minmax[1] - def decompress(self, _min, _max, compressed): scale = self.quantization_level / (_max - _min + self.eps) upper_bound = torch.round(_max * scale) lower_bound = upper_bound - self.quantization_level - return (compressed + lower_bound) / scale + return (compressed.float() + lower_bound) / scale + + +if __name__ == "__main__": + x = torch.rand(100).cuda() + _minmax, compressed = MinMaxUInt8().compress(x) + decompressed = MinMaxUInt8().decompress(_minmax, compressed) + + diff = x - decompressed + print(f"{diff}, {torch.norm(diff)}") diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index a662b04df..6970bb0cf 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -3,11 +3,12 @@ import torch.nn.functional as F from tests.internal.common_utils import find_free_port from tests.internal.compressor import MinMaxUInt8 -import bagua.torch_api as bagua import unittest import torch.multiprocessing as mp import os from bagua.torch_api.utils import apply_flattened_call, flatten +from bagua.torch_api.communication import get_backend +import bagua.torch_api as bagua class Net(nn.Module): @@ -25,37 +26,27 @@ def forward(self, x): return F.softmax(x, dim=1) -def _init_env(gpu): +def _init_env(rank): + # set deterministic torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True - torch.manual_seed(47) + torch.manual_seed(rank) # initialize subprocess env - os.environ["RANK"] = str(gpu) - os.environ["LOCAL_RANK"] = str(gpu) - + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) -def run_model(gpu, nprocs, hierarchical, communication_interval, results): - _init_env(gpu) - - # construct model and optimizer, etc. - model = Net() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - loss_fn = nn.MSELoss() - ret = results[gpu] - ret.init_weight.copy_( - torch.norm(flatten([param.data for param in model.parameters()])) - ) +def run_model(rank, nprocs, hierarchical, communication_interval, results): + _init_env(rank) # init bagua distributed process group - torch.cuda.set_device(bagua.get_local_rank()) + torch.cuda.set_device(rank) bagua.init_process_group() - print( - f"initialize bagua training process, rank: {bagua.get_rank()}, world_size: {bagua.get_world_size()}" - ) - - model.cuda() + # construct model and optimizer, etc. + model = Net().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + loss_fn = nn.MSELoss() # wrap model model = model.with_bagua( @@ -65,7 +56,12 @@ def run_model(gpu, nprocs, hierarchical, communication_interval, results): ), ) - for _ in range(10): + ret = results[rank] + bucket = model.bagua_buckets[0] + + ret.init_weight.copy_(flatten([param.data for param in model.parameters()])) + + for epoch in range(10): data = torch.randn(4, 2).cuda() target = torch.randn(4, 4).cuda() @@ -76,45 +72,38 @@ def run_model(gpu, nprocs, hierarchical, communication_interval, results): loss.backward() optimizer.step() - ret = results[gpu] - bucket = model.bagua_buckets[0] + torch.cuda.synchronize() ret.bucket_weight.copy_(flatten([param.data for param in model.parameters()])) ret.weight.copy_(torch.norm(bucket._weight)) ret.left_peer_weight.copy_(torch.norm(bucket._left_peer_weight)) ret.right_peer_weight.copy_(torch.norm(bucket._right_peer_weight)) -def run_torch_model(gpu, nprocs, hierarchical, communication_interval, results): - _init_env(gpu) - - # construct model and optimizer, etc. - model = Net() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - loss_fn = nn.MSELoss() - - ret = results[gpu] - ret.init_weight.copy_( - torch.norm(flatten([param.data for param in model.parameters()])) - ) +def run_torch_model(rank, nprocs, hierarchical, communication_interval, results): + _init_env(rank) # init torch distributed process group store = torch.distributed.FileStore("/tmp/filestore", nprocs) torch.distributed.init_process_group( - world_size=nprocs, rank=gpu, store=store, backend="gloo" + world_size=nprocs, rank=rank, store=store, backend="gloo" ) - print( - f"initialize torch training process, rank: {bagua.get_rank()}, world_size: {bagua.get_world_size()}" - ) + # construct model and optimizer, etc. + model = Net().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + loss_fn = nn.MSELoss() # wrap model model = LowPrecDecentralizedAlgor( model, optimizer, hierarchical, communication_interval ) - for _ in range(10): - data = torch.randn(4, 2) - target = torch.randn(4, 4) + ret = results[rank] + ret.init_weight.copy_(flatten([param.data for param in model.parameters()])) + + for epoch in range(10): + data = torch.randn(4, 2).cuda() + target = torch.randn(4, 4).cuda() optimizer.zero_grad() output = model(data) @@ -129,13 +118,15 @@ def run_torch_model(gpu, nprocs, hierarchical, communication_interval, results): class Result(object): def __init__(self): model = Net() - self.init_weight = torch.Tensor([0.0]).share_memory_() + self.init_weight = flatten( + [torch.zeros_like(param.data) for param in model.parameters()] + ) self.bucket_weight = flatten( - [param.data for param in model.parameters()] - ).share_memory_() - self.weight = torch.Tensor([0.0]).share_memory_() - self.left_peer_weight = torch.Tensor([0.0]).share_memory_() - self.right_peer_weight = torch.Tensor([0.0]).share_memory_() + [torch.zeros_like(param.data) for param in model.parameters()] + ) + self.weight = torch.Tensor([0.0]) + self.left_peer_weight = torch.Tensor([0.0]) + self.right_peer_weight = torch.Tensor([0.0]) class LowPrecDecentralizedAlgor(nn.Module): @@ -145,21 +136,25 @@ def __init__(self, module, optimizer, hierarchical, communication_interval): self.optimizer = optimizer self.hierarchical = hierarchical self.communication_interval = communication_interval - self.step_count = 0 self.compressor = MinMaxUInt8() + self.step_count = 0 assert torch.distributed.is_initialized() self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() - weights = [param.data for param in self.module.parameters()] - apply_flattened_call(weights, lambda x: torch.distributed.broadcast(x, 0)) + # broadcast parameters + for param in self.module.parameters(): + torch.distributed.broadcast(param.data, src=0) - self.weight = flatten(weights) + self.weight = flatten(self._build_params()) self.left_peer_weight = self.weight.detach().clone() self.right_peer_weight = self.weight.detach().clone() + def _build_params(self): + return [param.data for param in list(self.module.parameters()).__reversed__()] + def forward(self, *inputs, **kwargs): result = self.module(*inputs, **kwargs) return result @@ -168,28 +163,34 @@ def step(self): self.optimizer.step() def allreduce_fn(x): - torch.distributed.allreduce(x) + torch.distributed.all_reduce(x) x /= self.world_size - def communicate_with_peers(_buffer): - left_buffer = torch.zeros_like(_buffer, device=_buffer.device) - right_buffer = torch.zeros_like(_buffer, device=_buffer.device) + def communicate_with_peers( + tensor: torch.Tensor, comm_size: int + ) -> (torch.Tensor, torch.Tensor): + if comm_size == 1: + return tensor, tensor + + tensor = tensor.cpu() + left_tensor = torch.zeros_like(tensor) + right_tensor = torch.zeros_like(tensor) - left_peer_rank = (self.rank + self.world_size - 1) % self.world_size - right_peer_rank = (self.rank + 1) % self.world_size + left_peer_rank = (self.rank + self.world_size - 1) % comm_size + right_peer_rank = (self.rank + 1) % comm_size requests = [] - requests.append(torch.distributed.isend(_buffer, left_peer_rank)) - requests.append(torch.distributed.isend(_buffer, right_peer_rank)) - requests.append(torch.distributed.irecv(left_buffer, left_peer_rank)) - requests.append(torch.distributed.irecv(right_buffer, right_peer_rank)) + requests.append(torch.distributed.isend(tensor, left_peer_rank)) + requests.append(torch.distributed.isend(tensor, right_peer_rank)) + requests.append(torch.distributed.irecv(left_tensor, left_peer_rank)) + requests.append(torch.distributed.irecv(right_tensor, right_peer_rank)) for req in requests: req.wait() - return left_buffer, right_buffer + return left_tensor.cuda(), right_tensor.cuda() - def update_weight_fn(x): + def update_weight_fn(x, comm_size): diff = ( x + 1 / 3 * self.left_peer_weight @@ -197,37 +198,37 @@ def update_weight_fn(x): - 5 / 3 * self.weight ) - _min, _max, compressed_buffer = self.compressor.compress(diff) - - left_compressed_buffer, right_compressed_buffer = communicate_with_peers( - compressed_buffer + minmax, compressed = self.compressor.compress(diff) + left_compressed, right_compressed = communicate_with_peers( + compressed, comm_size ) - left_min, right_min = communicate_with_peers(_min) - left_max, right_max = communicate_with_peers(_max) + left_minmax, right_minmax = communicate_with_peers(minmax, comm_size) - left_decompressed = self.compressor.decompress( - left_min, left_max, left_compressed_buffer + self.left_peer_weight += self.compressor.decompress( + left_minmax, left_compressed ) - right_decompressed = self.compressor.decompress( - right_min, right_max, right_compressed_buffer + self.right_peer_weight += self.compressor.decompress( + right_minmax, right_compressed ) - self.left_peer_weight += left_decompressed - self.right_peer_weight += right_decompressed + diff = self.compressor.decompress(minmax, compressed) + x.copy_(self.weight + diff) - decompressed = self.compressor.decompress(_min, _max, compressed_buffer) - x += decompressed + self.weight.copy_(x) if self.step_count % self.communication_interval == 0: - weights = [param.data for param in self.module.parameters()] + weights = self._build_params() if self.hierarchical: apply_flattened_call(weights, allreduce_fn) + if self.rank == 0: + apply_flattened_call(weights, lambda x: update_weight_fn(x, 1)) apply_flattened_call( weights, lambda x: torch.distributed.broadcast(x, 0) ) else: - apply_flattened_call(weights, update_weight_fn) - self.weight = flatten(weights) + apply_flattened_call( + weights, lambda x: update_weight_fn(x, self.world_size) + ) self.step_count += 1 @@ -302,20 +303,25 @@ def run_diff_locally(self, hierarchical, communication_interval): for rank in range(nprocs): self.assertTrue( - bagua_results[rank].init_weight.item() - == torch_results[rank].init_weight.item() + torch.all( + torch.isclose( + bagua_results[rank].init_weight, + torch_results[rank].init_weight, + ) + ).item() ) - ret = torch.all( - torch.isclose( - bagua_results[rank].bucket_weight, - torch_results[rank].bucket_weight, - ) - ).item() - - self.assertTrue(ret) + self.assertTrue( + torch.all( + torch.isclose( + bagua_results[rank].bucket_weight, + torch_results[rank].bucket_weight, + ) + ).item() + ) def test_algorithm(self): + return self.run_test_locally(hierarchical=False, communication_interval=1) self.run_test_locally(hierarchical=False, communication_interval=2) self.run_test_locally(hierarchical=True, communication_interval=1) From 9096ae3e4a57ec5847dcdc2a2d7e5a376beb99b3 Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 12 Jul 2021 21:26:15 +0800 Subject: [PATCH 17/23] format --- .../test_low_precision_decentralized.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index 6970bb0cf..d0b3deeba 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -7,7 +7,6 @@ import torch.multiprocessing as mp import os from bagua.torch_api.utils import apply_flattened_call, flatten -from bagua.torch_api.communication import get_backend import bagua.torch_api as bagua @@ -191,14 +190,11 @@ def communicate_with_peers( return left_tensor.cuda(), right_tensor.cuda() def update_weight_fn(x, comm_size): - diff = ( - x - + 1 / 3 * self.left_peer_weight - + 1 / 3 * self.right_peer_weight - - 5 / 3 * self.weight - ) + x += 1 / 3 * self.left_peer_weight + x += 1 / 3 * self.right_peer_weight + x -= 5 / 3 * self.weight - minmax, compressed = self.compressor.compress(diff) + minmax, compressed = self.compressor.compress(x) left_compressed, right_compressed = communicate_with_peers( compressed, comm_size ) @@ -267,12 +263,10 @@ def run_test_locally(self, hierarchical, communication_interval): ) else: self.assertTrue( - results[rank].weight.item() - == results[left_peer_rank].right_peer_weight.item() + results[rank].weight.item() == results[left_peer_rank].right_peer_weight.item() ) self.assertTrue( - results[rank].weight.item() - == results[right_peer_rank].left_peer_weight.item() + results[rank].weight.item() == results[right_peer_rank].left_peer_weight.item() ) def run_diff_locally(self, hierarchical, communication_interval): @@ -327,7 +321,9 @@ def test_algorithm(self): self.run_test_locally(hierarchical=True, communication_interval=1) def test_compare(self): + # self.run_diff_locally(hierarchical=True, communication_interval=1) self.run_diff_locally(hierarchical=False, communication_interval=1) + self.run_diff_locally(hierarchical=False, communication_interval=2) if __name__ == "__main__": From d8d7e7db6d661e35cd768d074828398902862e11 Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Mon, 12 Jul 2021 21:27:28 +0800 Subject: [PATCH 18/23] Update tests/torch_api/test_low_precision_decentralized.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- tests/torch_api/test_low_precision_decentralized.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index d0b3deeba..1e5eb7d6f 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -263,7 +263,8 @@ def run_test_locally(self, hierarchical, communication_interval): ) else: self.assertTrue( - results[rank].weight.item() == results[left_peer_rank].right_peer_weight.item() + results[rank].weight.item() + == results[left_peer_rank].right_peer_weight.item() ) self.assertTrue( results[rank].weight.item() == results[right_peer_rank].left_peer_weight.item() From 1af9b8c6d030b8d11e99ef876b33ff75cd3395f6 Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Mon, 12 Jul 2021 21:27:34 +0800 Subject: [PATCH 19/23] Update tests/torch_api/test_low_precision_decentralized.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- tests/torch_api/test_low_precision_decentralized.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index 1e5eb7d6f..55bbb9338 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -267,7 +267,8 @@ def run_test_locally(self, hierarchical, communication_interval): == results[left_peer_rank].right_peer_weight.item() ) self.assertTrue( - results[rank].weight.item() == results[right_peer_rank].left_peer_weight.item() + results[rank].weight.item() + == results[right_peer_rank].left_peer_weight.item() ) def run_diff_locally(self, hierarchical, communication_interval): From d861e6af170ca2320c186d1977f30657c4be08a3 Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 12 Jul 2021 21:30:42 +0800 Subject: [PATCH 20/23] ul;: --- tests/torch_api/test_low_precision_decentralized.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index 55bbb9338..3548929b7 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -263,12 +263,14 @@ def run_test_locally(self, hierarchical, communication_interval): ) else: self.assertTrue( - results[rank].weight.item() - == results[left_peer_rank].right_peer_weight.item() + torch.equal( + results[rank].weight, results[left_peer_rank].right_peer_weight + ) ) self.assertTrue( - results[rank].weight.item() - == results[right_peer_rank].left_peer_weight.item() + torch.equal( + results[rank].weight, results[right_peer_rank].left_peer_weight + ) ) def run_diff_locally(self, hierarchical, communication_interval): From d85ab77a633cf60b5e556af3a1170efd51f76189 Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 13 Jul 2021 17:15:01 +0800 Subject: [PATCH 21/23] finally clear the diff --- .../test_low_precision_decentralized.py | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index 3548929b7..06fb20e56 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -71,20 +71,22 @@ def run_model(rank, nprocs, hierarchical, communication_interval, results): loss.backward() optimizer.step() - torch.cuda.synchronize() ret.bucket_weight.copy_(flatten([param.data for param in model.parameters()])) ret.weight.copy_(torch.norm(bucket._weight)) ret.left_peer_weight.copy_(torch.norm(bucket._left_peer_weight)) ret.right_peer_weight.copy_(torch.norm(bucket._right_peer_weight)) -def run_torch_model(rank, nprocs, hierarchical, communication_interval, results): +def run_torch_model( + rank, nprocs, hierarchical, communication_interval, results, backend +): _init_env(rank) # init torch distributed process group + torch.cuda.set_device(rank) store = torch.distributed.FileStore("/tmp/filestore", nprocs) torch.distributed.init_process_group( - world_size=nprocs, rank=rank, store=store, backend="gloo" + world_size=nprocs, rank=rank, store=store, backend=backend ) # construct model and optimizer, etc. @@ -147,9 +149,9 @@ def __init__(self, module, optimizer, hierarchical, communication_interval): for param in self.module.parameters(): torch.distributed.broadcast(param.data, src=0) - self.weight = flatten(self._build_params()) - self.left_peer_weight = self.weight.detach().clone() - self.right_peer_weight = self.weight.detach().clone() + self.weight = flatten(self._build_params()).cuda() + self.left_peer_weight = self.weight.detach().clone().cuda() + self.right_peer_weight = self.weight.detach().clone().cuda() def _build_params(self): return [param.data for param in list(self.module.parameters()).__reversed__()] @@ -161,10 +163,6 @@ def forward(self, *inputs, **kwargs): def step(self): self.optimizer.step() - def allreduce_fn(x): - torch.distributed.all_reduce(x) - x /= self.world_size - def communicate_with_peers( tensor: torch.Tensor, comm_size: int ) -> (torch.Tensor, torch.Tensor): @@ -209,18 +207,20 @@ def update_weight_fn(x, comm_size): diff = self.compressor.decompress(minmax, compressed) x.copy_(self.weight + diff) - self.weight.copy_(x) + def hierarchical_update_weight_fn(x): + torch.distributed.reduce(x, dst=0) + if self.rank == 0: + x /= self.world_size + update_weight_fn(x, comm_size=1) + + torch.distributed.broadcast(x, 0) + if self.step_count % self.communication_interval == 0: weights = self._build_params() if self.hierarchical: - apply_flattened_call(weights, allreduce_fn) - if self.rank == 0: - apply_flattened_call(weights, lambda x: update_weight_fn(x, 1)) - apply_flattened_call( - weights, lambda x: torch.distributed.broadcast(x, 0) - ) + apply_flattened_call(weights, hierarchical_update_weight_fn) else: apply_flattened_call( weights, lambda x: update_weight_fn(x, self.world_size) @@ -273,7 +273,7 @@ def run_test_locally(self, hierarchical, communication_interval): ) ) - def run_diff_locally(self, hierarchical, communication_interval): + def run_diff_locally(self, hierarchical, communication_interval, backend): if not torch.cuda.is_available(): print("skip tests since cuda is not available") return @@ -289,7 +289,7 @@ def run_diff_locally(self, hierarchical, communication_interval): mp.spawn( run_torch_model, nprocs=nprocs, - args=(nprocs, hierarchical, communication_interval, torch_results), + args=(nprocs, hierarchical, communication_interval, torch_results, backend), ) bagua_results = [Result() for _ in range(nprocs)] @@ -319,15 +319,20 @@ def run_diff_locally(self, hierarchical, communication_interval): ) def test_algorithm(self): - return self.run_test_locally(hierarchical=False, communication_interval=1) self.run_test_locally(hierarchical=False, communication_interval=2) self.run_test_locally(hierarchical=True, communication_interval=1) def test_compare(self): - # self.run_diff_locally(hierarchical=True, communication_interval=1) - self.run_diff_locally(hierarchical=False, communication_interval=1) - self.run_diff_locally(hierarchical=False, communication_interval=2) + self.run_diff_locally( + hierarchical=False, communication_interval=1, backend="gloo" + ) + self.run_diff_locally( + hierarchical=False, communication_interval=2, backend="gloo" + ) + self.run_diff_locally( + hierarchical=True, communication_interval=1, backend="nccl" + ) if __name__ == "__main__": From 93ea6c6522dbc5a5e92cff0f6aa5e2b93993ab65 Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 13 Jul 2021 17:31:03 +0800 Subject: [PATCH 22/23] update --- bagua/torch_api/algorithms/decentralized.py | 8 +++----- bagua/torch_api/bucket.py | 19 ++++++++----------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index 7a466826a..7de3500a3 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -133,11 +133,9 @@ def hook(optimizer: torch.optim.Optimizer): return hook def _init_states(self, bucket: BaguaBucket): - bucket_flattened_tensor = bucket.flattened_tensor() - - weight_tensor = bucket_flattened_tensor.detach().clone() - left_peer_weight_tensor = bucket_flattened_tensor.detach().clone() - right_peer_weight_tensor = bucket_flattened_tensor.detach().clone() + weight_tensor = bucket.flattened_tensor() + left_peer_weight_tensor = bucket.flattened_tensor() + right_peer_weight_tensor = bucket.flattened_tensor() bucket._weight = weight_tensor.to_bagua_tensor("weight") bucket._left_peer_weight = left_peer_weight_tensor.to_bagua_tensor( diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index 8fe9cda46..cba1c60d5 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -71,14 +71,10 @@ def __init__( for tensor in self._all_tensors: tensor._bagua_bucket = self - def flattened_tensor(self) -> torch.Tensor: + def flattened_tensor(self) -> BaguaTensor: """ Returns a tensor contiguous in memory which contains the same data as `self` tensors and padding tensor (if exists). - If `self` tensors and padding tensor are already flattened, this function returns a tensor corresponding to their - underlying storage. """ - if self.flatten: - return self.backend_tensor total_size = 0 for tensor in self._all_tensors: @@ -225,18 +221,19 @@ def append_decentralized_synchronous_op( hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine will communicate will each other first. After that, machines do inter-node communication. This can boost performance when the inter-node communication cost is high. - peer_selection_mode (str): Can be "all" or "shift_one" for full precision decentralized operation, while "ring" for + peer_selection_mode (str): Can be "all" or "shift_one" for full precision decentralized operation, "ring" for low precision decentralized operation. "all" means all workers' weights are averaged in each communication step. "shift_one" means each worker selects a different peer to do weights average in each communication step. "ring" means all workers are connected into a ring, and each worker communicate with its neighbors. communication_interval (int): Number of iterations between two communication steps. compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is supported. - weight (torch.Tensor): Local model of current worker, required for low precision decentralized operation. - left_peer_weight (torch.Tensor): Model replica of current worker's connected left peer, required for low - precision decentralized operation. - right_peer_weight (torch.Tensor): Model replica of current worker's connected right peer, required for - low precision decentralized operation. + weight (torch.Tensor): Local model of current worker, a flattened tensor containing the same data as the local model + weights of current worker, required for low precision decentralized operation. + left_peer_weight (torch.Tensor): Model replica of current worker's connected left peer, a flattened tensor containing + the same data as model weights of left peer, required for low precision decentralized operation. + right_peer_weight (torch.Tensor): Model replica of current worker's connected right peer, similarly as `left_peer_weight`, + required for low precision decentralized operation. Returns: The bucket itself. """ From e2c857d0fe2798fc456e56661e4185e0285b454f Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 13 Jul 2021 18:57:32 +0800 Subject: [PATCH 23/23] update --- bagua/torch_api/bucket.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index cba1c60d5..7388c14f2 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -207,9 +207,9 @@ def append_decentralized_synchronous_op( peer_selection_mode: str = "all", communication_interval: int = 1, compression: Optional[str] = None, - weight: Optional[torch.Tensor] = None, - left_peer_weight: Optional[torch.Tensor] = None, - right_peer_weight: Optional[torch.Tensor] = None, + weight: Optional[BaguaTensor] = None, + left_peer_weight: Optional[BaguaTensor] = None, + right_peer_weight: Optional[BaguaTensor] = None, ) -> BaguaBucket: """ Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers. @@ -228,11 +228,11 @@ def append_decentralized_synchronous_op( communication_interval (int): Number of iterations between two communication steps. compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is supported. - weight (torch.Tensor): Local model of current worker, a flattened tensor containing the same data as the local model + weight (BaguaTensor): Local model of current worker, a flattened tensor containing the same data as the local model weights of current worker, required for low precision decentralized operation. - left_peer_weight (torch.Tensor): Model replica of current worker's connected left peer, a flattened tensor containing + left_peer_weight (BaguaTensor): Model replica of current worker's connected left peer, a flattened tensor containing the same data as model weights of left peer, required for low precision decentralized operation. - right_peer_weight (torch.Tensor): Model replica of current worker's connected right peer, similarly as `left_peer_weight`, + right_peer_weight (BaguaTensor): Model replica of current worker's connected right peer, similarly as `left_peer_weight`, required for low precision decentralized operation. Returns: The bucket itself.