feat(python): associate PyTorch Process Group with Bagua Process Group using cache #402

Merged · 9 commits · Dec 1, 2021
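This PR replaces the direct `self.bagua_pg = from_torch_group(self, stream)` attribute assignment in `bagua_patch` with a module-level `weakref.WeakKeyDictionary` cache keyed by the torch `ProcessGroup`, so the associated Bagua process group is looked up via a property and released automatically when the torch group goes away. A minimal usage sketch of the patched API, modeled on the new test below (it assumes a working distributed environment and that `world_size` is defined; this is an illustration, not part of the diff):

```python
import torch.distributed as c10d
import bagua.torch_api as bagua

bagua.init_process_group()                            # initialize Bagua's global state first
pg = c10d.new_group(ranks=list(range(world_size)))    # world_size is assumed to be defined
pg.bagua_patch()                                      # caches a Bagua process group for `pg`
bagua_pg = pg.bagua_pg                                # retrieved from the weak-key cache
```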
1 change: 1 addition & 0 deletions bagua/torch_api/__init__.py
@@ -60,4 +60,5 @@
from . import communication # noqa: F401
from . import algorithms # noqa: F401
from . import checkpoint # noqa: E402,F401
from . import data_parallel # noqa: E402,F401
from .model_parallel import moe # noqa: E402,F401
15 changes: 14 additions & 1 deletion bagua/torch_api/communication.py
@@ -23,6 +23,7 @@
import torch.distributed.distributed_c10d as c10d
from torch._C._distributed_c10d import ProcessGroup as TorchProcessGroup
import gorilla
import weakref

# fmt: off
__all__ = [
@@ -49,6 +50,9 @@
# Process group count for default naming
_group_count = 0

# Torch process group to bagua process group
_torch_to_bagua_pg_map = weakref.WeakKeyDictionary({})


# must be consistent with Aluminum ReductionOperator: https://github.com/BaguaSys/Aluminum/blob/master/include/aluminum/base.hpp
class ReduceOp(IntEnum):
@@ -68,7 +72,16 @@ class ReduceOp(IntEnum):
@gorilla.patches(TorchProcessGroup, filter=lambda name, obj: "bagua" in name)
class BaguaProcessGroupPatch:
def bagua_patch(self, stream: Optional[torch.cuda.Stream] = None):
self.bagua_pg = from_torch_group(self, stream)
global _torch_to_bagua_pg_map
if self not in _torch_to_bagua_pg_map:
_torch_to_bagua_pg_map[self] = from_torch_group(self, stream)

return self

@property
def bagua_pg(self):
assert self in _torch_to_bagua_pg_map, "cannot find associated Bagua process group in cache, BaguaProcessGroupPatch.bagua_patch(...) needs to be run first to initialize Bagua process group in cache."
return _torch_to_bagua_pg_map[self]

def bagua_get_global_communicator(self):
return get_communicator(self.bagua_pg.group_name, "global")
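Because `_torch_to_bagua_pg_map` is a `weakref.WeakKeyDictionary`, an entry is dropped once its torch `ProcessGroup` key is garbage-collected, which is what the new `test_bagua_pg` below checks after `destroy_process_group()`. A standalone sketch of that standard-library behavior (`FakeGroup` is only a stand-in for a weak-referenceable key, not a real torch class):

```python
import weakref

class FakeGroup:
    """Stand-in for a torch ProcessGroup; any weak-referenceable object works."""
    pass

cache = weakref.WeakKeyDictionary()
key = FakeGroup()
cache[key] = "bagua process group"   # placeholder value
assert len(cache) == 1

del key  # drop the last strong reference to the key
# On CPython the entry disappears immediately via reference counting;
# other interpreters may need an explicit gc.collect().
assert len(cache) == 0
```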
54 changes: 52 additions & 2 deletions tests/torch_api/test_process_group.py
@@ -1,6 +1,11 @@
import bagua.torch_api as bagua
import torch
import os
import unittest

import torch
import torch.distributed as c10d

import bagua.torch_api as bagua
from tests.internal.common_utils import find_free_port
from tests.internal.multi_process import MultiProcessTestCase, setup_bagua_env
from tests import skip_if_cuda_not_available

@@ -73,5 +78,50 @@ def test_from_torch_group(self):
self.run_test_locally(run_from_torch_group, nprocs, args={}, results=None)


from torch.testing._internal.common_distributed import ( # noqa: E402
MultiProcessTestCase,
skip_if_lt_x_gpu,
)


class ProcessGroupNCCLTest(MultiProcessTestCase):
def setUp(self):
super(ProcessGroupNCCLTest, self).setUp()
# NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
# that use NCCL_BLOCKING_WAIT will test it as expected.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
os.environ.update(
{
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(find_free_port(8000, 8100)),
"BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
}
)
self._spawn_processes()

@skip_if_lt_x_gpu(2)
def test_bagua_pg(self):
# Need to use NCCL_BLOCKING_WAIT and not ASYNC_ERROR_HANDLING,
# otherwise process will be taken down and we can't check for errors.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
os.environ["NCCL_BLOCKING_WAIT"] = "1"
os.environ.update(
{
"WORLD_SIZE": str(self.world_size),
"LOCAL_WORLD_SIZE": str(self.world_size),
"RANK": str(self.rank),
"LOCAL_RANK": str(self.rank),
}
)

bagua.init_process_group()
pg = c10d.new_group(ranks=list(range(0, self.world_size)))
pg.bagua_patch()
self.assertTrue(pg in bagua.communication._torch_to_bagua_pg_map)
del pg
c10d.destroy_process_group()
self.assertEqual(len(bagua.communication._torch_to_bagua_pg_map), 0)


if __name__ == "__main__":
unittest.main()