diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 7aa8d0c4ec7a9..3392f17b1f5a9 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -318,7 +318,6 @@ test_inductor_distributed() {
   python test/run_test.py -i inductor/test_aot_inductor.py -k test_non_default_cuda_device --verbose
   python test/run_test.py -i inductor/test_aot_inductor.py -k test_replicate_on_devices --verbose
   python test/run_test.py -i distributed/test_c10d_functional_native.py --verbose
-  TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/run_test.py -i distributed/test_c10d_functional_native.py --verbose
   python test/run_test.py -i distributed/_tensor/test_dtensor_compile.py --verbose
   python test/run_test.py -i distributed/tensor/parallel/test_micro_pipeline_tp.py --verbose
   python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_comm.py --verbose
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py
index 468929774fdb6..fccd6f8765d52 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -2007,11 +2007,7 @@ def codegen_while_loop(self, while_loop):
         self.writeline("}")
 
     def generate_extern_kernel_args_decl_if_needed(
-        self,
-        op_overload,
-        raw_args,
-        output_args: Optional[List[str]] = None,
-        raw_outputs: Optional[List[ir.Buffer]] = None,
+        self, op_overload, raw_args, output_args
     ):
         arg_types = [x.real_type for x in op_overload._schema.arguments]
         return_types = [x.type for x in op_overload._schema.returns]
@@ -2099,14 +2095,13 @@ def fill_args(arg, arg_type):
                 else:
                     fill_args(arg, arg_type)
 
-        def fill_output_arg(arg, return_type, is_mutated_output: bool):
+        def fill_output_arg(arg, return_type):
             if isinstance(return_type, torch.TensorType):
-                if not is_mutated_output:
-                    self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
-                    self.writeline(
-                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
-                    )
-                    self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
+                self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
+                )
+                self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
                 new_tensor_args.append(f"{arg}")
             elif isinstance(return_type, torch.SymIntType):
                 raise NotImplementedError("NYI support for return type: SymInt")
@@ -2130,21 +2125,13 @@ def fill_output_arg(arg, return_type, is_mutated_output: bool):
                     f"return type {return_type} is not yet supported."
                 )
 
-        for output_arg, raw_output_arg in zip(output_args, raw_outputs):  # type: ignore[arg-type]
+        for output_arg in output_args:
             assert output_arg is not None, "Optional return types are not yet supported"
             if isinstance(output_arg, (list, tuple)):
                 for out in output_arg:
-                    fill_output_arg(
-                        out,
-                        torch.TensorType.get(),
-                        isinstance(raw_output_arg, ir.MutationOutput),
-                    )
+                    fill_output_arg(out, torch.TensorType.get())
             else:
-                fill_output_arg(
-                    output_arg,
-                    torch.TensorType.get(),
-                    isinstance(raw_output_arg, ir.MutationOutput),
-                )
+                fill_output_arg(output_arg, torch.TensorType.get())
 
         return new_tensor_args, new_int_args
 
@@ -2166,12 +2153,6 @@ def extract_output_name(out):
                 return None
             elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
                 return out.get_name()
-            elif isinstance(out, ir.MutationOutput):
-                mutated_buf_names = out.get_mutation_names()
-                assert (
-                    isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1
-                ), "Expect only one mutated buffer in MutationOutput"
-                return mutated_buf_names[0]
             elif isinstance(out, (list, tuple)):
                 return type(out)(extract_output_name(o) for o in out)
             else:
@@ -2196,7 +2177,6 @@ def extract_output_name(out):
                 op_overload,
                 raw_args,
                 output_args,
-                outputs,
             )
         else:
             return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
@@ -2210,7 +2190,6 @@ def extract_output_name(out):
                 op_overload,
                 raw_args,
                 output_args,
-                outputs,
             )
 
     def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope):
@@ -2354,7 +2333,6 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
         op_overload: Optional[torch._ops.OpOverload] = None,
         raw_args=None,
         output_args: Optional[List[str]] = None,
-        raw_outputs: Optional[List[ir.Buffer]] = None,
    ):
        if not config.abi_compatible:
            # Will update this to use an OSS version ProxyExecutor
@@ -2418,19 +2396,11 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
    {output_arg} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));"""
 
-        if raw_outputs:
-            declarations_before_scope = [
-                f"RAIIAtenTensorHandle {output_arg};"
-                for output_arg, raw_output_arg in zip(output_args, raw_outputs)  # type: ignore[arg-type]
-                if output_arg is not None
-                and not isinstance(raw_output_arg, ir.MutationOutput)
-            ]
-        else:
-            declarations_before_scope = [
-                f"RAIIAtenTensorHandle {output_arg};"
-                for output_arg in output_args  # type: ignore[arg-type]
-                if output_arg is not None
-            ]
+        declarations_before_scope = [
+            f"RAIIAtenTensorHandle {output_arg};"
+            for output_arg in output_args
+            if output_arg is not None
+        ]
         scope_gil_acquire = self.generate_scoped_gil_acquire(
             declarations_before_scope, lines
         )
@@ -2442,16 +2412,12 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor(
         op_overload,
         raw_args,  # contains both args and flatten kwargs
         output_args: Optional[List[str]] = None,
-        raw_outputs: Optional[List[ir.Buffer]] = None,
     ):
         (
             tensor_call_args,
             int_call_args,
         ) = self.generate_extern_kernel_args_decl_if_needed(
-            op_overload,
-            raw_args,
-            output_args,
-            raw_outputs,
+            op_overload, raw_args, output_args
         )
 
         tensor_call_args_str = ", ".join(tensor_call_args)
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 7c2ca2c192119..5ded8b8ead919 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -6066,23 +6066,15 @@ def handle_single_output(return_type, output):
         target = self.op_overload
         returns = target._schema.returns  # type: ignore[union-attr]
         if len(returns) == 1:
-            # FIXME: there is a corner case here, i.e. all_reduce_coalesced_'s return value
-            # is a list of tensors, but self.mutation_outputs is already flatterned. A proper
-            # fix would require changing all the uses of self.mutation_outputs.
             return_type = returns[0].real_type
-            output_arguments = [
-                handle_single_output(
-                    return_type, [*self.outputs, *self.mutation_outputs]
-                )
-            ]
+            output_arguments = [handle_single_output(return_type, self.outputs)]
         else:
             # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])"
-            # Not generating output args for self.mutation_outputs
+            assert isinstance(self.outputs, tuple)
+            assert len(returns) == len(self.outputs)
             output_arguments = [
                 handle_single_output(return_schema.real_type, output)
-                for return_schema, output in zip(
-                    returns, [*self.outputs, *self.mutation_outputs]
-                )
+                for return_schema, output in zip(returns, self.outputs)
             ]
 
         node = ExternKernelNode(
@@ -6146,7 +6138,7 @@ def codegen(self, wrapper):
                 self.cpp_kernel_overload_name,
                 self.op_overload,
                 exported_args,
-                [*self.outputs, *self.mutation_outputs],
+                self.outputs,
             )
         else:
             self.codegen_comment(wrapper)
@@ -6986,6 +6978,7 @@ def create_wait(cls, kernel, inp: TensorBox) -> None:
         packed.mutation_outputs.append(
             MutationOutput(NoneLayout(inp.get_device()), inp, packed)
         )
+        packed.outputs = [packed]
 
     def get_read_writes(self):
         read_writes = super().get_read_writes()
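For reference, the writeline calls kept in fill_output_arg above emit AOTI shim glue of roughly the following shape for each output tensor of an extern kernel call. This is a sketch only: the buffer name buf3 is hypothetical and the surrounding generated wrapper function is omitted; the three emitted lines themselves come directly from the f-strings in the diff.

    AtenTensorHandle buf3_handle; // output buffer
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&buf3_handle));
    RAIIAtenTensorHandle buf3(buf3_handle);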