Check for fused kernel before inplace update (pytorch#137042)
Summary:
Given an op with a pair (output buffer, input buffer), we consider marking the output buffer for an in-place update of the input buffer. However, if the op that defines the input buffer and the current op are going to be fused into the same kernel, we don't want to mark the output buffer for in-place reuse. This change checks that criterion and skips the in-place update when it applies.
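Here is a minimal, self-contained sketch of the criterion this change adds. The `Buffer` class and the name-to-fused-group map below are simplified stand-ins, not Inductor's real scheduler types (`SchedulerBuffer`, `scheduler.name_to_fused_node`):

```python
# Simplified sketch of the "skip in-place when producer and consumer fuse"
# criterion. Buffer and name_to_fused_group are hypothetical stand-ins for
# Inductor's scheduler structures.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Buffer:
    name: str
    defining_op: str  # name of the op that writes this buffer


def can_inplace(current_op: str, input_buf: Buffer,
                name_to_fused_group: Dict[str, List[str]]) -> bool:
    # Names of all ops that will be fused into the same kernel as current_op.
    fused_nodes = set(name_to_fused_group[current_op])
    # If the producer of input_buf lands in that same fused kernel, there is
    # no need to reuse its storage in place, so skip the in-place update.
    if input_buf.defining_op in fused_nodes:
        return False
    return True


# Example: a LayerNorm-like graph where the statistics op ("op0") and the
# normalization op ("op1") fuse into one kernel.
buf0 = Buffer(name="buf0", defining_op="op0")
groups = {"op0": ["op0", "op1"], "op1": ["op0", "op1"]}
print(can_inplace("op1", buf0, groups))  # False: do not emit in_out_ptr
```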

Test Plan:
New unit test "test_layer_norm_should_not_inplace" compiles a LayerNorm and checks that the generated code contains no "in_out_ptr" arguments; a standalone sketch of the same check follows below.
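For reference, a rough standalone version of that check, adapted from the new unit test (running it as a plain script outside the test harness is an assumption here; the generated code may be Triton or C++ depending on the device):

```python
# Standalone sketch adapted from test_layer_norm_should_not_inplace.
import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_triton_code

D = 16


def fn(x):
    return nn.LayerNorm([D], dtype=torch.float16)(x)


inp = torch.rand(D, dtype=torch.float16)
fn_opt = torch.compile(fn)
code = run_and_get_triton_code(fn_opt, inp)
# With this fix, the generated kernel takes no in-place (in_out_ptr) arguments.
assert "in_out_ptr" not in code
# The compiled result should still match eager execution.
torch.testing.assert_close(fn_opt(inp), fn(inp))
```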

Fixes pytorch#120217

Here's a diagram of the issue:
![Inline+Fusion](https://github.com/user-attachments/assets/c03308d8-fdbf-40a0-a46d-964ece5f9e6d)

Pull Request resolved: pytorch#137042
Approved by: https://github.com/eellison
exclamaforte authored and pytorchmergebot committed Oct 2, 2024
1 parent a3f3773 commit 36fb342
Showing 2 changed files with 23 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -12466,6 +12466,20 @@ def fn(x):
        _, (code,) = run_and_get_code(torch.compile(fn), inp)
        FileCheck().check("copy_").check_same("True").run(code)

    @config.patch(inplace_buffers=True)
    def test_layer_norm_should_not_inplace(self):
        # https://github.com/pytorch/pytorch/issues/120217
        D = 16

        def fn(x):
            return nn.LayerNorm([D], dtype=torch.float16)(x)

        inps = [torch.rand(D, dtype=torch.float16)]
        fn_opt = torch.compile(fn)
        code = run_and_get_triton_code(fn_opt, *inps)
        self.assertTrue("in_out_ptr" not in code)
        self.assertEqual(fn_opt(*inps), fn(*inps))

class RNNTest(TestCase):
    device_type = GPU_TYPE

9 changes: 9 additions & 0 deletions torch/_inductor/scheduler.py
@@ -396,6 +396,10 @@ def decide_inplace_update(self) -> None:
            and hasattr(V.kernel, "args")
        ):
            return
        fused_nodes = {
            node.get_name()
            for node in self.scheduler.name_to_fused_node[self.get_name()].get_nodes()
        }

        ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name)

@@ -419,6 +423,11 @@ def decide_inplace_update(self) -> None:
                and V.graph.wrapper_code.can_reuse(input_buf, self)
                and not isinstance(input_buf.defining_op, NopKernelSchedulerNode)
            ):
                # If the writers of input_buf are in the same FusedSchedulerNode
                # as the current op, then there is no need to inplace.
                if input_buf.defining_op.get_name() in fused_nodes:
                    continue

                assert input_buf.users is not None
                remaining_uses = [
                    x
