feat: use bagua_module_name to identify different modules #438

Merged (3 commits) on Dec 21, 2021
26 changes: 14 additions & 12 deletions bagua/torch_api/data_parallel/bagua_distributed.py
@@ -25,7 +25,6 @@


 class BaguaDistributedDataParallel:
-
     def __init__(
         self,
         module: Module,
@@ -37,12 +36,7 @@ def __init__(
         find_unused_parameters: bool = False,
     ) -> None:
         self.module = module
-        if bagua_module_name is None:
-            self.bagua_module_name = "{}_{}".format(
-                self.__class__.__name__, id(module)
-            )
-        else:
-            self.bagua_module_name = bagua_module_name
+        self.bagua_module_name = bagua_module_name

         self.bagua_optimizers = optimizers
         self.bagua_algorithm = algorithm.reify(process_group)
@@ -75,8 +69,8 @@ def __init__(
         self._bagua_autotune_completed = False

         class BaguaDistributedDataParallelStates:
-            """Empty class whose instances are used for keeping track of BaguaDistributedDataParallel's internal states.
-            """
+            """Empty class whose instances are used for keeping track of BaguaDistributedDataParallel's internal states."""
+
             pass

         if hasattr(self.module, "_bagua_states"):
@@ -180,7 +174,10 @@ def bagua_build_params(self) -> List[Tuple[str, torch.nn.Parameter]]:
         ]

         if self.find_unused_parameters and len(self.autograd_graph_params) != 0:
-            modules_and_parameters = filter(lambda it: it[1][0] in self.autograd_graph_params, modules_and_parameters)
+            modules_and_parameters = filter(
+                lambda it: it[1][0] in self.autograd_graph_params,
+                modules_and_parameters,
+            )

         # Deduplicate any parameters that might be shared across child modules.
         memo = set()
@@ -425,7 +422,9 @@ def _bagua_reset_algorithm_buckets(self):
         self._bagua_cleanup_algorithm()
         raw_buckets = self._bagua_autotune_get_buckets()
         self.bagua_buckets.extend(
-            self.bagua_algorithm.tensors_to_buckets(raw_buckets, self.gradient_as_bucket_view)
+            self.bagua_algorithm.tensors_to_buckets(
+                raw_buckets, self.gradient_as_bucket_view
+            )
         )

         for name, param in self.module.named_parameters():
@@ -448,7 +447,10 @@ def real_post_backward_hook(*unused):
                )

            if self.find_unused_parameters:
-                if set(self.autograd_graph_params.keys()) != self.params_in_use:
+                if (
+                    set(self.autograd_graph_params.keys())
+                    != self.params_in_use
+                ):
                    self.rebuild_buckets()
                    self.delay_allreduce()

37 changes: 21 additions & 16 deletions bagua/torch_api/data_parallel/distributed.py
@@ -60,7 +60,9 @@ def will_sync_module_buffers(self):
         raise NotImplementedError


-def to_bagua_process_group(process_group: Union[TorchProcessGroup, BaguaProcessGroup, None] = None):
+def to_bagua_process_group(
+    process_group: Union[TorchProcessGroup, BaguaProcessGroup, None] = None
+):
     """Convert a PyTorch process group to a Bagua process group.

     Args:
@@ -77,7 +79,10 @@ def to_bagua_process_group(process_group: Union[TorchProcessGroup, BaguaProcessGroup, None] = None):

     if process_group is None:
         return _get_default_group()
-    elif type(process_group) in [TorchProcessGroup, torch._C._distributed_c10d.ProcessGroupNCCL]:
+    elif type(process_group) in [
+        TorchProcessGroup,
+        torch.distributed.ProcessGroupNCCL,
+    ]:
         return process_group.bagua_patch().bagua_pg  # pytype: disable=attribute-error
     elif type(process_group) is BaguaProcessGroup:
         return process_group
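For context, a hedged sketch of how the converter above is meant to be called (not part of this PR; it assumes torch.distributed is already initialized and that Bagua has patched the torch process group so that `.bagua_patch()` is available):

```python
import torch.distributed as dist

from bagua.torch_api.data_parallel.distributed import to_bagua_process_group

# None falls back to Bagua's default process group.
default_pg = to_bagua_process_group(None)

# A regular PyTorch process group is converted through its patched
# bagua_patch().bagua_pg attribute, as in the elif branch above.
torch_pg = dist.new_group(ranks=[0, 1])
bagua_pg = to_bagua_process_group(torch_pg)
```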
@@ -106,8 +111,7 @@ def __init__(
         optimizers: List[torch.optim.Optimizer] = [],
         algorithm: "bagua.torch_api.algorithms.Algorithm" = GradientAllReduceAlgorithm(),
     ) -> None:
-        """Bagua internal use function. Construction use :class:`DistributedDataParallel`.
-        """
+        """Bagua internal use function. Construction use :class:`DistributedDataParallel`."""
         super(DistributedDataParallel_V1_9_0, self).__init__()
         assert any((p.requires_grad for p in module.parameters())), (
             "DistributedDataParallel is not needed when a module "
@@ -136,11 +140,19 @@ def __init__(
         self.broadcast_buffers = broadcast_buffers
         self.find_unused_parameters = find_unused_parameters

+        if not hasattr(module, "_bagua_module_name"):
+            module._bagua_module_name = "{}_{}".format(
+                self.__class__.__name__, id(module)
+            )
+
         self.inner = BaguaDistributedDataParallel(
-            self.module, optimizers, algorithm,
+            self.module,
+            optimizers,
+            algorithm,
             process_group=to_bagua_process_group(process_group),
             gradient_as_bucket_view=gradient_as_bucket_view,
             find_unused_parameters=find_unused_parameters,
+            bagua_module_name=module.bagua_module_name,
         )

     @property
@@ -152,8 +164,7 @@ def require_backward_grad_sync(self):

     @property
     def parameters_to_ignore(self):
-        """Parameters that will be ignored in DDP.
-        """
+        """Parameters that will be ignored in DDP."""
         return self.inner.parameters_to_ignore

     def forward(self, *inputs, **kwargs):
@@ -191,13 +202,6 @@ def bagua_algorithm(self):
         """
         return self.inner.bagua_algorithm

-    @property
-    def bagua_module_name(self):
-        """
-        The module's name. Bagua uses the module name to distinguish different modules.
-        """
-        return self.inner.bagua_module_name
-
     @property
     def bagua_optimizers(self):
         """
@@ -226,7 +230,7 @@ def DistributedDataParallel(
     gradient_as_bucket_view: bool = True,
     # The followings are parameters for Bagua
     optimizers: List[torch.optim.Optimizer] = [],
-    algorithm: "bagua.torch_api.algorithms.Algorithm" = GradientAllReduceAlgorithm()
+    algorithm: "bagua.torch_api.algorithms.Algorithm" = GradientAllReduceAlgorithm(),
 ) -> Union[TorchDistributedDataParallel, DistributedDataParallel_V1_9_0]:
     r"""
     This function provides a `PyTorch DDP <https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/parallel/distributed.py#L125>`_ compatible
Expand Down Expand Up @@ -326,7 +330,8 @@ def DistributedDataParallel(
" have not been supported yet. Bagua has automatically "
"fallback to upstream PyTorch DistributedDataParallel "
"implementation. If this is unexpected, please submit "
"an issue to https://github.com/BaguaSys/bagua. Thanks.")
"an issue to https://github.com/BaguaSys/bagua. Thanks."
)
return TorchDistributedDataParallel(
module=module,
device_ids=device_ids,
16 changes: 11 additions & 5 deletions bagua/torch_api/distributed.py
@@ -99,24 +99,30 @@ def with_bagua( # pytype: disable=module-attr
         if process_group is None:
             process_group = _get_default_group()

-        bagua_module_name = None
-        if hasattr(self, "bagua_ddp"):
-            bagua_module_name = self.bagua_ddp.bagua_module_name
Review comment (Member):
In my understanding, .with_bagua and BaguaDistributedDataParallel cannot be used at the same time, so we may not need a separate .bagua_module_name to identify different modules.

Reply (Member Author):
I want to make it possible for our users to tell us whether two module instances are the same or not, instead of having this determined simply by their ids.

In fact, this comes from a dev requirement from Lightning: a torch.nn.Module can be wrapped multiple times, but the wrappers share the same model parameters.

Reply (Member, @shjwudp, Dec 20, 2021):
> I want to make it possible for our users to tell us whether two module instances are the same or not, instead of having this determined simply by their ids.

Users only need to print the module. My core point is that .with_bagua is deprecated; we should not use it anymore.

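To make the Lightning scenario above concrete, here is a hedged sketch (not from the PR: the model, the name "my_encoder", and the setup are made up, and it assumes a Bagua process group has already been initialized under a distributed launcher). Because the name now lives on the torch.nn.Module itself, the caller decides which wrappers count as the same module:

```python
import torch

from bagua.torch_api.data_parallel import DistributedDataParallel

# Assumes bagua.torch_api.init_process_group() has already been called.
model = torch.nn.Linear(16, 4).cuda()

# Optionally pin the name up front; otherwise the first wrapper generates one
# and stores it on the module as module._bagua_module_name.
model._bagua_module_name = "my_encoder"

ddp_a = DistributedDataParallel(model)  # e.g. a first wrapping by a framework
ddp_b = DistributedDataParallel(model)  # a later re-wrapping of the same module

# Both wrappers read the same name from the module instead of deriving one
# from the wrapper instance.
assert ddp_a.inner.bagua_module_name == "my_encoder"
assert ddp_b.inner.bagua_module_name == "my_encoder"
```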
+        if not hasattr(self, "_bagua_module_name"):
+            self._bagua_module_name = "{}_{}".format(self.__class__.__name__, id(self))

         self.bagua_ddp = BaguaDistributedDataParallel(
             self,
             optimizers=optimizers,
             algorithm=algorithm,
             process_group=process_group,
-            bagua_module_name=bagua_module_name,
+            bagua_module_name=self.bagua_module_name,
             gradient_as_bucket_view=do_flatten,
         )

         return self

     @property
     def bagua_module_name(self):
-        return self.bagua_ddp.bagua_module_name
+        """
+        The module's name. Bagua uses the module name to distinguish different modules.
+        """
+        return self._bagua_module_name

+    @bagua_module_name.setter
Review comment (Member):
Why not use .bagua_module_name directly?

Reply (Member Author, @wangraying, Dec 20, 2021):
Yes, we can use bagua_module_name directly; I just think it looks clearer.

(A minimal standalone sketch of this property/setter pair follows the end of the diff below.)

+    def bagua_module_name(self, name: str):
+        self._bagua_module_name = name
+
     @property
     def bagua_algorithm(self):
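Following up on the review exchange about the setter: a minimal, self-contained sketch (illustration only, not Bagua's actual class) of what the property/setter pair added above amounts to. Assigning through the property and writing `_bagua_module_name` directly end up in the same place; the property is simply the clearer interface:

```python
class NamedModule:
    """Illustration of the bagua_module_name property/setter pattern."""

    @property
    def bagua_module_name(self):
        # Bagua uses the module name to distinguish different modules.
        return self._bagua_module_name

    @bagua_module_name.setter
    def bagua_module_name(self, name: str):
        self._bagua_module_name = name


m = NamedModule()
m.bagua_module_name = "my_encoder"             # goes through the setter ...
assert m._bagua_module_name == "my_encoder"    # ... and lands in the raw attribute
assert m.bagua_module_name == "my_encoder"     # read back through the property
```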