diff --git a/bagua/torch_api/data_parallel/bagua_distributed.py b/bagua/torch_api/data_parallel/bagua_distributed.py
index f58483ff9..a2932eeae 100644
--- a/bagua/torch_api/data_parallel/bagua_distributed.py
+++ b/bagua/torch_api/data_parallel/bagua_distributed.py
@@ -25,7 +25,6 @@
 
 
 class BaguaDistributedDataParallel:
-
     def __init__(
         self,
         module: Module,
@@ -37,12 +36,7 @@ def __init__(
         find_unused_parameters: bool = False,
     ) -> None:
         self.module = module
-        if bagua_module_name is None:
-            self.bagua_module_name = "{}_{}".format(
-                self.__class__.__name__, id(module)
-            )
-        else:
-            self.bagua_module_name = bagua_module_name
+        self.bagua_module_name = bagua_module_name
 
         self.bagua_optimizers = optimizers
         self.bagua_algorithm = algorithm.reify(process_group)
@@ -75,8 +69,8 @@ def __init__(
         self._bagua_autotune_completed = False
 
         class BaguaDistributedDataParallelStates:
-            """Empty class whose instances are used for keeping track of BaguaDistributedDataParallel's internal states.
-            """
+            """Empty class whose instances are used for keeping track of BaguaDistributedDataParallel's internal states."""
+
             pass
 
         if hasattr(self.module, "_bagua_states"):
@@ -180,7 +174,10 @@ def bagua_build_params(self) -> List[Tuple[str, torch.nn.Parameter]]:
         ]
 
         if self.find_unused_parameters and len(self.autograd_graph_params) != 0:
-            modules_and_parameters = filter(lambda it: it[1][0] in self.autograd_graph_params, modules_and_parameters)
+            modules_and_parameters = filter(
+                lambda it: it[1][0] in self.autograd_graph_params,
+                modules_and_parameters,
+            )
 
         # Deduplicate any parameters that might be shared across child modules.
         memo = set()
@@ -425,7 +422,9 @@ def _bagua_reset_algorithm_buckets(self):
         self._bagua_cleanup_algorithm()
         raw_buckets = self._bagua_autotune_get_buckets()
         self.bagua_buckets.extend(
-            self.bagua_algorithm.tensors_to_buckets(raw_buckets, self.gradient_as_bucket_view)
+            self.bagua_algorithm.tensors_to_buckets(
+                raw_buckets, self.gradient_as_bucket_view
+            )
         )
 
         for name, param in self.module.named_parameters():
@@ -448,7 +447,10 @@ def real_post_backward_hook(*unused):
                             )
 
                         if self.find_unused_parameters:
-                            if set(self.autograd_graph_params.keys()) != self.params_in_use:
+                            if (
+                                set(self.autograd_graph_params.keys())
+                                != self.params_in_use
+                            ):
                                 self.rebuild_buckets()
                             self.delay_allreduce()
 
diff --git a/bagua/torch_api/data_parallel/distributed.py b/bagua/torch_api/data_parallel/distributed.py
index 462b586cb..21049e770 100644
--- a/bagua/torch_api/data_parallel/distributed.py
+++ b/bagua/torch_api/data_parallel/distributed.py
@@ -60,7 +60,9 @@ def will_sync_module_buffers(self):
         raise NotImplementedError
 
 
-def to_bagua_process_group(process_group: Union[TorchProcessGroup, BaguaProcessGroup, None] = None):
+def to_bagua_process_group(
+    process_group: Union[TorchProcessGroup, BaguaProcessGroup, None] = None
+):
     """Convert a PyTorch process group to a Bagua process group.
 
     Args:
@@ -77,7 +79,10 @@ def to_bagua_process_group(process_group: Union[TorchProcessGroup, BaguaProcessG
 
     if process_group is None:
         return _get_default_group()
-    elif type(process_group) in [TorchProcessGroup, torch._C._distributed_c10d.ProcessGroupNCCL]:
+    elif type(process_group) in [
+        TorchProcessGroup,
+        torch.distributed.ProcessGroupNCCL,
+    ]:
         return process_group.bagua_patch().bagua_pg  # pytype: disable=attribute-error
     elif type(process_group) is BaguaProcessGroup:
         return process_group
@@ -106,8 +111,7 @@ def __init__(
         optimizers: List[torch.optim.Optimizer] = [],
         algorithm: "bagua.torch_api.algorithms.Algorithm" = GradientAllReduceAlgorithm(),
     ) -> None:
-        """Bagua internal use function. Construction use :class:`DistributedDataParallel`.
-        """
+        """Bagua internal use function. Construction use :class:`DistributedDataParallel`."""
         super(DistributedDataParallel_V1_9_0, self).__init__()
         assert any((p.requires_grad for p in module.parameters())), (
             "DistributedDataParallel is not needed when a module "
@@ -136,11 +140,19 @@ def __init__(
         self.broadcast_buffers = broadcast_buffers
         self.find_unused_parameters = find_unused_parameters
 
+        if not hasattr(module, "_bagua_module_name"):
+            module._bagua_module_name = "{}_{}".format(
+                self.__class__.__name__, id(module)
+            )
+
         self.inner = BaguaDistributedDataParallel(
-            self.module, optimizers, algorithm,
+            self.module,
+            optimizers,
+            algorithm,
             process_group=to_bagua_process_group(process_group),
             gradient_as_bucket_view=gradient_as_bucket_view,
             find_unused_parameters=find_unused_parameters,
+            bagua_module_name=module.bagua_module_name,
         )
 
     @property
@@ -152,8 +164,7 @@ def require_backward_grad_sync(self):
 
     @property
     def parameters_to_ignore(self):
-        """Parameters that will be ignored in DDP.
-        """
+        """Parameters that will be ignored in DDP."""
         return self.inner.parameters_to_ignore
 
     def forward(self, *inputs, **kwargs):
@@ -191,13 +202,6 @@ def bagua_algorithm(self):
         """
         return self.inner.bagua_algorithm
 
-    @property
-    def bagua_module_name(self):
-        """
-        The module's name. Bagua uses the module name to distinguish different modules.
-        """
-        return self.inner.bagua_module_name
-
     @property
     def bagua_optimizers(self):
         """
@@ -226,7 +230,7 @@ def DistributedDataParallel(
     gradient_as_bucket_view: bool = True,
     # The followings are parameters for Bagua
    optimizers: List[torch.optim.Optimizer] = [],
-    algorithm: "bagua.torch_api.algorithms.Algorithm" = GradientAllReduceAlgorithm()
+    algorithm: "bagua.torch_api.algorithms.Algorithm" = GradientAllReduceAlgorithm(),
 ) -> Union[TorchDistributedDataParallel, DistributedDataParallel_V1_9_0]:
     r"""
     This function provides a `PyTorch DDP `_ compatible
@@ -326,7 +330,8 @@ def DistributedDataParallel(
             " have not been supported yet. Bagua has automatically "
             "fallback to upstream PyTorch DistributedDataParallel "
             "implementation. If this is unexpected, please submit "
-            "an issue to https://github.com/BaguaSys/bagua. Thanks.")
+            "an issue to https://github.com/BaguaSys/bagua. Thanks."
+        )
         return TorchDistributedDataParallel(
             module=module,
             device_ids=device_ids,
diff --git a/bagua/torch_api/distributed.py b/bagua/torch_api/distributed.py
index af4d29d19..93576e197 100644
--- a/bagua/torch_api/distributed.py
+++ b/bagua/torch_api/distributed.py
@@ -99,16 +99,15 @@ def with_bagua(  # pytype: disable=module-attr
         if process_group is None:
             process_group = _get_default_group()
 
-        bagua_module_name = None
-        if hasattr(self, "bagua_ddp"):
-            bagua_module_name = self.bagua_ddp.bagua_module_name
+        if not hasattr(self, "_bagua_module_name"):
+            self._bagua_module_name = "{}_{}".format(self.__class__.__name__, id(self))
 
         self.bagua_ddp = BaguaDistributedDataParallel(
             self,
             optimizers=optimizers,
             algorithm=algorithm,
             process_group=process_group,
-            bagua_module_name=bagua_module_name,
+            bagua_module_name=self.bagua_module_name,
             gradient_as_bucket_view=do_flatten,
         )
 
@@ -116,7 +115,14 @@ def with_bagua(  # pytype: disable=module-attr
 
     @property
     def bagua_module_name(self):
-        return self.bagua_ddp.bagua_module_name
+        """
+        The module's name. Bagua uses the module name to distinguish different modules.
+        """
+        return self._bagua_module_name
+
+    @bagua_module_name.setter
+    def bagua_module_name(self, name: str):
+        self._bagua_module_name = name
 
     @property
     def bagua_algorithm(self):