[Pallas] Integrate FlashAttention with SPMD #6935
Conversation
lgtm, minor comments
test/test_pallas_spmd.py
Outdated
if xr.device_type() == 'TPU':
  from torch_xla.experimental.custom_kernel import jax_import_guard
  jax_import_guard()
  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl
nit, you can put this part in the setup class similar to https://github.com/pytorch/xla/blob/master/test/spmd/test_xla_sharding_base.py#L31-L35
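For reference, a minimal sketch of that setup-class pattern, assuming a placeholder class name and mirroring test_xla_sharding_base.py; the exact structure is an assumption, not taken from this PR:

```python
# Sketch only: hoisting the TPU-guarded JAX imports into setUpClass,
# mirroring test_xla_sharding_base.py. Class name and layout are assumptions.
import unittest

import torch_xla.runtime as xr


class PallasSpmdTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    if xr.device_type() == 'TPU':
      # jax_import_guard() must run before importing JAX on TPU.
      from torch_xla.experimental.custom_kernel import jax_import_guard
      jax_import_guard()
      import jax
      jax.config.update('jax_default_matmul_precision',
                        jax.lax.Precision.HIGHEST)
```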
Is the Python import global?
test/test_pallas_spmd.py
Outdated
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
nit, this should be called in the setup class since it is a one-time global config.
I can probably call this in main as well; the setup class seems like overkill for this. Something like the sketch below.
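A minimal sketch of that approach, assuming the exit-code handling used by other torch_xla tests:

```python
# Sketch only: enabling SPMD once from the test entry point instead of a
# setup class. The result/exit-code pattern is an assumption.
import sys
import unittest

import torch_xla.runtime as xr

if __name__ == '__main__':
  xr.use_spmd()
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
```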
Awesome stuff Jiewen!
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
def test_flash_attention_spmd_data_parallel(self):
  jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
Does this impact the resulting kernel?
Yea.
@@ -184,15 +185,29 @@ class FlashAttention(torch.autograd.Function):
   }

   @staticmethod
-  def forward(ctx, q, k, v, causal=False):
+  def forward(ctx, q, k, v, causal=False, sharding_spec=None, mesh=None):
nit: sharding_spec -> partition_spec?
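For illustration, a hedged usage sketch assuming the suggested partition_spec name and a functional wrapper that forwards these arguments; the mesh construction and partition spec below are assumptions for illustration, not taken from this PR:

```python
# Sketch only: data-parallel FlashAttention under SPMD. The Mesh construction,
# partition spec, and wrapper signature are assumptions for illustration.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
# Shard the batch dimension across all devices; the remaining mesh axes are size 1.
mesh = xs.Mesh(np.arange(n_devices), (n_devices, 1, 1, 1))

q = torch.randn(4, 2, 128, 4).to(xm.xla_device())
k = torch.randn(4, 2, 128, 4).to(xm.xla_device())
v = torch.randn(4, 2, 128, 4).to(xm.xla_device())

# partition_spec maps each tensor dim to a mesh axis (batch -> axis 0).
o = flash_attention(q, k, v, partition_spec=(0, 1, 2, 3), mesh=mesh)
```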
Thanks Jon and Jack for the reviews.
Summary:
This pull request integrates FlashAttention with SPMD. It works by creating a manual sharding region for the kernel: all inputs are wrapped with enable_manual_sharding and all outputs with disable_manual_sharding (see the sketch below).
Added a new test file because the original test file is not SPMD-aware.
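A minimal sketch of the wrapping pattern described above; the helper signatures and the per-shard kernel call are assumptions, not the PR's exact implementation:

```python
# Sketch only: enter a manual sharding region for the Pallas kernel, then
# stitch the output back. Helper signatures and kernel call are assumptions.
from torch_xla.distributed.spmd import (enable_manual_sharding,
                                        disable_manual_sharding)


def flash_attention_spmd(q, k, v, partition_spec, mesh):
  full_shape = q.shape
  # Inputs: switch to manual sharding so each device sees only its local shard.
  q = enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
  k = enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
  v = enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor

  o = run_pallas_flash_attention(q, k, v)  # hypothetical per-shard kernel call

  # Output: leave the manual sharding region and restore the full global shape.
  return disable_manual_sharding(o, partition_spec, full_shape, mesh=mesh).global_tensor
```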
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas_spmd.py