Monkeypatch Device Mesh ND Slicing #3302

Merged: mvpatel2000 merged 3 commits into mosaicml:dev from mvpatel2000/device-mesh-patch on May 18, 2024

@mvpatel2000 mvpatel2000 commented May 17, 2024

What does this PR do?

Monkeypatches DeviceMesh ND slicing. This lets us slice a parent mesh into child meshes that have more than one dimension.

This PR patches in the change from pytorch/pytorch#119752, which will not be released until PyTorch 2.4.
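For readers unfamiliar with the pattern, here is a minimal sketch of a version-gated monkeypatch: the replacement is installed only when the running library is older than the release that ships the fix. Everything in it is hypothetical and for illustration only (`version_tuple`, `patch_if_older`, `fake_lib`, `nd_getitem`); it is not Composer's `patch_pytorch` and does not touch torch.

```python
from types import SimpleNamespace


def version_tuple(version: str) -> tuple:
    """Parse the leading 'major.minor' of a string such as '2.3.0+cu121'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def patch_if_older(module, attr, replacement, installed: str, fixed_in: str) -> bool:
    """Overwrite module.attr only if `installed` predates `fixed_in`."""
    if version_tuple(installed) < version_tuple(fixed_in):
        setattr(module, attr, replacement)
        return True
    return False


# Hypothetical stand-in for a library object whose method needs backporting.
fake_lib = SimpleNamespace(getitem=lambda names: "1D slicing only")


def nd_getitem(names):
    """Hypothetical backported implementation."""
    return f"child mesh over {names}"


# The installed version (2.3) predates the fix (2.4), so the patch is applied.
patched = patch_if_older(fake_lib, "getitem", nd_getitem, "2.3.0", "2.4.0")
```

Gating on the version keeps the patch from overwriting the upstream implementation once a release containing the fix is installed.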

Toy script that works with this PR and did not work before it:

from torch.distributed._tensor.device_mesh import init_device_mesh

from composer.utils import dist
from composer.trainer.mosaic_fsdp import patch_pytorch

dist.initialize_dist('gpu', timeout=10.0)

# Apply Composer's PyTorch monkeypatches (this PR adds ND slicing).
patch_pytorch()

child_mesh_dim_names = ("PP", "DP")
mesh_dim_names = ("PP", "DP", "TP")
# 3D mesh over 8 GPUs (2 x 2 x 2).
mesh = init_device_mesh(
    'cuda', (2, 2, 2), mesh_dim_names=mesh_dim_names
)
# Slice out a 2D child mesh by dimension names.
child_mesh = mesh[child_mesh_dim_names]
print(child_mesh)
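
To illustrate what the slice in the script selects, here is a pure-Python sketch that needs neither torch nor GPUs. It assumes ranks are laid out in row-major order over the mesh shape and that a child mesh holds the ranks sharing the calling rank's coordinates on every dropped dimension. The helper `child_mesh_ranks` is hypothetical and is not part of torch or Composer.

```python
from itertools import product


def child_mesh_ranks(shape, dim_names, child_dim_names, rank):
    """Return (child_shape, flat_ranks) for the child mesh containing `rank`."""
    # Row-major strides, e.g. (4, 2, 1) for shape (2, 2, 2).
    strides = []
    stride = 1
    for size in reversed(shape):
        strides.append(stride)
        stride *= size
    strides.reverse()

    # Coordinates of `rank` in the parent mesh.
    coords = [(rank // st) % size for st, size in zip(strides, shape)]

    # Axes kept by the slice; all other axes stay fixed at `coords`.
    keep = [dim_names.index(name) for name in child_dim_names]
    child_shape = tuple(shape[axis] for axis in keep)

    flat_ranks = []
    for idx in product(*(range(n) for n in child_shape)):
        point = list(coords)
        for axis, value in zip(keep, idx):
            point[axis] = value
        flat_ranks.append(sum(c * st for c, st in zip(point, strides)))
    return child_shape, flat_ranks


names = ("PP", "DP", "TP")
print(child_mesh_ranks((2, 2, 2), names, ("PP", "DP"), rank=0))
# -> ((2, 2), [0, 2, 4, 6])
```

Under these assumptions, the eight ranks split into two 2x2 child meshes, one per TP coordinate: ranks 0, 2, 4, 6 and ranks 1, 3, 5, 7.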

@mvpatel2000 mvpatel2000 requested review from bigning, snarayan21 and dakinggg and removed request for bigning May 17, 2024 20:45

@snarayan21 snarayan21 left a comment

could you indicate which lines are different from torch 2.3 that we're patching? Or, since the torch PR seems to have made quite extensive changes, a comment with a link to the PR in torch would be useful. other than that, lgtm

@mvpatel2000

> could you indicate which lines are different from torch 2.3 that we're patching? Or, since the torch PR seems to have made quite extensive changes, a comment with a link to the PR in torch would be useful. other than that, lgtm

The PR is already linked in the description. This change copy-pastes the code from that PR to patch torch, so the changes are identical to the linked PR.

@snarayan21 snarayan21 left a comment

lgtm!

@mvpatel2000 mvpatel2000 merged commit bb1b7cf into mosaicml:dev May 18, 2024
15 checks passed
@mvpatel2000 mvpatel2000 deleted the mvpatel2000/device-mesh-patch branch May 18, 2024 03:49