[Pallas] Integrate FlashAttention with SPMD #6935

Merged: 9 commits into master on Apr 18, 2024

Conversation

alanwaketan (Collaborator)

Summary:
This pull request integrates FlashAttention with SPMD. It works by creating a manual sharding region for the kernel: all inputs are wrapped with enable_manual_sharding and all outputs with disable_manual_sharding.
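
A minimal sketch of that wrapping pattern, assuming enable_manual_sharding/disable_manual_sharding live in torch_xla.distributed.spmd, take a tensor, a partition spec, and a mesh, and return a sharded wrapper exposing global_tensor; _flash_attention_kernel and _flash_attention_manual_sharding are placeholder names, and exact signatures may differ from the merged code:

import torch_xla.distributed.spmd as xs

def _flash_attention_manual_sharding(q, k, v, causal, partition_spec, mesh):
  # Remember the global (unsharded) shape before entering the manual region.
  full_shape = q.shape

  # Enter the manual-sharding region: the compiler hands the Pallas kernel
  # only the local shard of each input.
  q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
  k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
  v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor

  # Placeholder for the actual FlashAttention Pallas kernel invocation.
  o = _flash_attention_kernel(q, k, v, causal)

  # Leave the manual-sharding region: re-annotate the output with the full
  # global shape so downstream ops see an SPMD-sharded tensor again.
  o = xs.disable_manual_sharding(o, partition_spec, full_shape, mesh=mesh).global_tensor
  return o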

Added a new test file because the original test file is not SPMD aware.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas_spmd.py

@alanwaketan alanwaketan self-assigned this Apr 17, 2024
@JackCaoG (Collaborator) left a comment

lgtm, minor comments

Comment on lines 14 to 19
if xr.device_type() == 'TPU':
  from torch_xla.experimental.custom_kernel import jax_import_guard
  # Grab the TPU for torch_xla before any jax import, so jax does not lock
  # the device first.
  jax_import_guard()
  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl

alanwaketan (Collaborator, Author)

Is python import global?

jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
Collaborator

nit, this should be called in the setup class since it is a one-time global config.

alanwaketan (Collaborator, Author)

I probably can call this in main as well. The setup class seems like overkill for this.
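
A hedged sketch of that alternative, with the one-time xr.use_spmd() call moved into the test module's entry point (the test class name below is illustrative):

import unittest

import torch_xla.runtime as xr


class PallasSpmdTest(unittest.TestCase):
  # SPMD-aware FlashAttention tests go here.
  ...


if __name__ == '__main__':
  # One-time global switch into SPMD mode for every test in this module.
  xr.use_spmd()
  unittest.main()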

@jonb377 (Collaborator) left a comment

Awesome stuff Jiewen!

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  def test_flash_attention_spmd_data_parallel(self):
    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
Collaborator

Does this impact the resulting kernel?

alanwaketan (Collaborator, Author)

Yea.

@@ -184,15 +185,29 @@ class FlashAttention(torch.autograd.Function):
   }
 
   @staticmethod
-  def forward(ctx, q, k, v, causal=False):
+  def forward(ctx, q, k, v, causal=False, sharding_spec=None, mesh=None):
Collaborator

nit: sharding_spec -> partition_spec?
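
For reference, a hedged usage sketch of the new keyword arguments from an SPMD caller, written with the suggested partition_spec name; the mesh construction, axis names, and tensor shapes below are illustrative assumptions rather than code from this PR:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()

# A 1D "data" mesh over all devices; shard q/k/v along the batch dimension.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))
partition_spec = ('data', None, None, None)  # [batch, num_heads, seq_len, head_dim]

q = torch.randn(4, 2, 128, 64).to(xm.xla_device())
k = torch.randn(4, 2, 128, 64).to(xm.xla_device())
v = torch.randn(4, 2, 128, 64).to(xm.xla_device())

o = flash_attention(q, k, v, causal=True, partition_spec=partition_spec, mesh=mesh)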

@alanwaketan (Collaborator, Author)

Thanks Jon and Jack for the reviews.

@alanwaketan alanwaketan merged commit 9f2b82d into master Apr 18, 2024
23 checks passed