Revert "[Dynamo] add flex attention mode test (pytorch#137121)"
This reverts commit 144665d.

Reverted pytorch#137121 on behalf of https://github.com/malfet due to Need to revert to be able to revert pytorch#136910 ([comment](pytorch#137121 (comment)))
pytorchmergebot committed Oct 8, 2024
1 parent 11192ce commit cc10ef4
Showing 1 changed file with 0 additions and 18 deletions.
18 changes: 0 additions & 18 deletions test/dynamo/test_modes.py
@@ -12,7 +12,6 @@
     _push_on_torch_function_stack,
 )
 from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
-from torch.testing._internal.triton_utils import requires_cuda
 from torch.utils._device import DeviceContext
 from torch.utils._python_dispatch import TorchDispatchMode

@@ -582,23 +581,6 @@ def run_checks(setups_and_oplists, skips, ref_map):
         run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP)
         run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP)
-
-    @requires_cuda
-    def test_flex_attention(self):
-        import torch
-        from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-
-        torch.set_default_device("cuda")
-
-        flex_attention = torch.compile(flex_attention, dynamic=False)
-
-        prefix_lengths = torch.arange(8)
-
-        def prefix_lm(b, h, q, kv):
-            return prefix_lengths[b] >= kv
-
-        # This runs in fullgraph already
-        mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)


 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
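For context, a minimal standalone sketch of the usage the deleted test covered: compiling flex_attention and building a prefix-LM block mask with create_block_mask. The imports, mask_mod, and create_block_mask call mirror the deleted test; the tensor shapes, dtype, and the final attention call are assumptions added purely for illustration, and a CUDA device is assumed (the test was decorated with @requires_cuda).

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch.set_default_device("cuda")  # the deleted test was CUDA-only

compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

B, H, S, D = 8, 4, 512, 64  # assumed batch, heads, sequence length, head dim
prefix_lengths = torch.arange(B)

def prefix_lm(b, h, q_idx, kv_idx):
    # Each batch element may attend only to keys within its prefix length.
    return prefix_lengths[b] >= kv_idx

# Per the deleted test's comment, this already runs in fullgraph.
mask = create_block_mask(prefix_lm, B, None, S, S, _compile=True)

# Assumed follow-up (not part of the test): run the compiled kernel with the mask.
q, k, v = (torch.randn(B, H, S, D, dtype=torch.float16) for _ in range(3))
out = compiled_flex_attention(q, k, v, block_mask=mask)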
