Revert "[AOTI] Handle inplace output in ProxyExecutor (pytorch#137660)"
This reverts commit 573101a.

Reverted pytorch#137660 on behalf of https://github.com/desertfire due to failures in fbcode (see the comment thread on pytorch#137660).
pytorchmergebot committed Oct 11, 2024
1 parent c58e5c4 commit 0121d64
Showing 3 changed files with 22 additions and 64 deletions.
.ci/pytorch/test.sh (1 change: 0 additions & 1 deletion)

@@ -318,7 +318,6 @@ test_inductor_distributed() {
   python test/run_test.py -i inductor/test_aot_inductor.py -k test_non_default_cuda_device --verbose
   python test/run_test.py -i inductor/test_aot_inductor.py -k test_replicate_on_devices --verbose
   python test/run_test.py -i distributed/test_c10d_functional_native.py --verbose
-  TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/run_test.py -i distributed/test_c10d_functional_native.py --verbose
   python test/run_test.py -i distributed/_tensor/test_dtensor_compile.py --verbose
   python test/run_test.py -i distributed/tensor/parallel/test_micro_pipeline_tp.py --verbose
   python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_comm.py --verbose
torch/_inductor/codegen/cpp_wrapper_cpu.py (66 changes: 16 additions & 50 deletions)

@@ -2007,11 +2007,7 @@ def codegen_while_loop(self, while_loop):
         self.writeline("}")
 
     def generate_extern_kernel_args_decl_if_needed(
-        self,
-        op_overload,
-        raw_args,
-        output_args: Optional[List[str]] = None,
-        raw_outputs: Optional[List[ir.Buffer]] = None,
+        self, op_overload, raw_args, output_args
     ):
         arg_types = [x.real_type for x in op_overload._schema.arguments]
         return_types = [x.type for x in op_overload._schema.returns]
Expand Down Expand Up @@ -2099,14 +2095,13 @@ def fill_args(arg, arg_type):
else:
fill_args(arg, arg_type)

def fill_output_arg(arg, return_type, is_mutated_output: bool):
def fill_output_arg(arg, return_type):
if isinstance(return_type, torch.TensorType):
if not is_mutated_output:
self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
)
self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
)
self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
new_tensor_args.append(f"{arg}")
elif isinstance(return_type, torch.SymIntType):
raise NotImplementedError("NYI support for return type: SymInt")
@@ -2130,21 +2125,13 @@ def fill_output_arg(arg, return_type, is_mutated_output: bool):
                     f"return type {return_type} is not yet supported."
                 )
 
-        for output_arg, raw_output_arg in zip(output_args, raw_outputs):  # type: ignore[arg-type]
+        for output_arg in output_args:
             assert output_arg is not None, "Optional return types are not yet supported"
             if isinstance(output_arg, (list, tuple)):
                 for out in output_arg:
-                    fill_output_arg(
-                        out,
-                        torch.TensorType.get(),
-                        isinstance(raw_output_arg, ir.MutationOutput),
-                    )
+                    fill_output_arg(out, torch.TensorType.get())
             else:
-                fill_output_arg(
-                    output_arg,
-                    torch.TensorType.get(),
-                    isinstance(raw_output_arg, ir.MutationOutput),
-                )
+                fill_output_arg(output_arg, torch.TensorType.get())
 
         return new_tensor_args, new_int_args
 
@@ -2166,12 +2153,6 @@ def extract_output_name(out):
                 return None
             elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
                 return out.get_name()
-            elif isinstance(out, ir.MutationOutput):
-                mutated_buf_names = out.get_mutation_names()
-                assert (
-                    isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1
-                ), "Expect only one mutated buffer in MutationOutput"
-                return mutated_buf_names[0]
             elif isinstance(out, (list, tuple)):
                 return type(out)(extract_output_name(o) for o in out)
             else:
@@ -2196,7 +2177,6 @@ def extract_output_name(out):
                 op_overload,
                 raw_args,
                 output_args,
-                outputs,
             )
         else:
             return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
@@ -2210,7 +2190,6 @@ def extract_output_name(out):
                 op_overload,
                 raw_args,
                 output_args,
-                outputs,
             )
 
     def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope):
@@ -2354,7 +2333,6 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
         op_overload: Optional[torch._ops.OpOverload] = None,
         raw_args=None,
         output_args: Optional[List[str]] = None,
-        raw_outputs: Optional[List[ir.Buffer]] = None,
     ):
         if not config.abi_compatible:
             # Will update this to use an OSS version ProxyExecutor
Expand Down Expand Up @@ -2418,19 +2396,11 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
{output_arg} =
reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));"""

if raw_outputs:
declarations_before_scope = [
f"RAIIAtenTensorHandle {output_arg};"
for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type]
if output_arg is not None
and not isinstance(raw_output_arg, ir.MutationOutput)
]
else:
declarations_before_scope = [
f"RAIIAtenTensorHandle {output_arg};"
for output_arg in output_args # type: ignore[arg-type]
if output_arg is not None
]
declarations_before_scope = [
f"RAIIAtenTensorHandle {output_arg};"
for output_arg in output_args
if output_arg is not None
]
scope_gil_acquire = self.generate_scoped_gil_acquire(
declarations_before_scope, lines
)
@@ -2442,16 +2412,12 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor(
         op_overload,
         raw_args,  # contains both args and flatten kwargs
         output_args: Optional[List[str]] = None,
-        raw_outputs: Optional[List[ir.Buffer]] = None,
     ):
         (
             tensor_call_args,
             int_call_args,
         ) = self.generate_extern_kernel_args_decl_if_needed(
-            op_overload,
-            raw_args,
-            output_args,
-            raw_outputs,
+            op_overload, raw_args, output_args
         )
 
         tensor_call_args_str = ", ".join(tensor_call_args)
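Note: the restored fill_output_arg above unconditionally allocates each output buffer in the generated wrapper. Substituting an illustrative buffer name buf0 for {arg} (the name is hypothetical), the three writeline calls emit C++ of this shape:

    AtenTensorHandle buf0_handle; // output buffer
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);

The handle buf0 is then appended to new_tensor_args and passed to the proxy executor among the flattened tensor call arguments.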
torch/_inductor/ir.py (19 changes: 6 additions & 13 deletions)

@@ -6066,23 +6066,15 @@ def handle_single_output(return_type, output):
         target = self.op_overload
         returns = target._schema.returns  # type: ignore[union-attr]
         if len(returns) == 1:
-            # FIXME: there is a corner case here, i.e. all_reduce_coalesced_'s return value
-            # is a list of tensors, but self.mutation_outputs is already flatterned. A proper
-            # fix would require changing all the uses of self.mutation_outputs.
             return_type = returns[0].real_type
-            output_arguments = [
-                handle_single_output(
-                    return_type, [*self.outputs, *self.mutation_outputs]
-                )
-            ]
+            output_arguments = [handle_single_output(return_type, self.outputs)]
         else:
             # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])"
-            # Not generating output args for self.mutation_outputs
+            assert isinstance(self.outputs, tuple)
+            assert len(returns) == len(self.outputs)
             output_arguments = [
                 handle_single_output(return_schema.real_type, output)
-                for return_schema, output in zip(
-                    returns, [*self.outputs, *self.mutation_outputs]
-                )
+                for return_schema, output in zip(returns, self.outputs)
             ]
 
         node = ExternKernelNode(
Expand Down Expand Up @@ -6146,7 +6138,7 @@ def codegen(self, wrapper):
self.cpp_kernel_overload_name,
self.op_overload,
exported_args,
[*self.outputs, *self.mutation_outputs],
self.outputs,
)
else:
self.codegen_comment(wrapper)
Expand Down Expand Up @@ -6986,6 +6978,7 @@ def create_wait(cls, kernel, inp: TensorBox) -> None:
packed.mutation_outputs.append(
MutationOutput(NoneLayout(inp.get_device()), inp, packed)
)
packed.outputs = [packed]

def get_read_writes(self):
read_writes = super().get_read_writes()
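Note: with this revert, export_extern_kernel_node pairs schema returns only with self.outputs, and mutation outputs are no longer folded into the proxy-executor argument list. A minimal sketch of the restored pairing, using hypothetical stand-ins rather than the real inductor IR types:

    # "returns" and "outputs" are illustrative stand-ins, not real torch._inductor objects
    returns = ["Tensor", "Tensor"]   # e.g. an op schema "-> (Tensor, Tensor)"
    outputs = ("buf0", "buf1")       # one output buffer per schema return
    assert isinstance(outputs, tuple)
    assert len(returns) == len(outputs)  # the invariant asserted above
    output_arguments = [
        (return_type, out) for return_type, out in zip(returns, outputs)
    ]
    print(output_arguments)  # [('Tensor', 'buf0'), ('Tensor', 'buf1')]

Each (return type, buffer) pair is what handle_single_output receives in the real code.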
