From 76c5bdd2cc476a32e3f93950579f2de9230c9998 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 8 Oct 2024 17:12:41 +0000 Subject: [PATCH] Revert "[Dynamo] Handle extracted unbound tensor methods (#137227)" This reverts commit 14eabd69152e31d059444310979625542db2aece. Reverted https://github.com/pytorch/pytorch/pull/137227 on behalf of https://github.com/malfet due to Need to revert to be able to revert https://github.com/pytorch/pytorch/pull/136910 ([comment](https://github.com/pytorch/pytorch/pull/137227#issuecomment-2400406384)) --- test/dynamo/test_functions.py | 10 ---------- .../TestAsArrayCPU.test_copy_list_cpu_bfloat16 | 0 .../TestAsArrayCPU.test_copy_list_cpu_bool | 0 ...TestAsArrayCPU.test_copy_list_cpu_complex128 | 0 .../TestAsArrayCPU.test_copy_list_cpu_complex64 | 0 .../TestAsArrayCPU.test_copy_list_cpu_float16 | 0 .../TestAsArrayCPU.test_copy_list_cpu_float32 | 0 .../TestAsArrayCPU.test_copy_list_cpu_float64 | 0 .../TestAsArrayCPU.test_copy_list_cpu_int16 | 0 .../TestAsArrayCPU.test_copy_list_cpu_int32 | 0 .../TestAsArrayCPU.test_copy_list_cpu_int64 | 0 .../TestAsArrayCPU.test_copy_list_cpu_int8 | 0 .../TestAsArrayCPU.test_copy_list_cpu_uint8 | 0 torch/_dynamo/variables/tensor.py | 2 +- torch/_dynamo/variables/torch.py | 17 ++++------------- 15 files changed, 5 insertions(+), 24 deletions(-) create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bfloat16 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bool create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex128 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex64 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float16 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float32 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int16 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int32 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int64 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int8 create mode 100644 test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_uint8 diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 694e8b5d23502..3d5d8e6928c86 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -938,16 +938,6 @@ def test_tensor_is_complex(x): else: return x - 1 - @make_test - def test_tensor_size(x): - fn = torch.Tensor.size - return fn(x + 1) - - @make_test - def test_tensor_dim(x): - fn = torch.Tensor.dim - return fn(x + 1) - @make_test def test_tensor_is_inference(x): if x.is_inference(): diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bfloat16 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bfloat16 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bool b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bool new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex128 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex128 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex64 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float16 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float16 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float32 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float64 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int16 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int16 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int32 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int64 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int8 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_uint8 b/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_uint8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 7b7d8010f0a61..514a712c89165 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -772,7 +772,7 @@ def method_item(self, *args, **kwargs): self._warn_capture_scalar_outputs() unimplemented("Tensor.item") - def method___getitem__(self, *args, **kwargs): + def method_getitem(self, *args, **kwargs): from ..symbolic_convert import InstructionTranslator from .builder import wrap_fx_proxy diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c3b5e8165e69e..77d8d2fcf8c10 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -871,6 +871,10 @@ def handle_set_default_device( return ConstantVariable.create(None) + @register(torch._C.TensorBase.__getitem__) + def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs): + return args[0].call_method(tx, "getitem", args[1:], kwargs) + return handlers def call_function( @@ -900,9 +904,6 @@ def call_function( ), ) - if self.is_tensor_method(): - return self.call_tensor_method(tx, args, kwargs) - special_handler = self._get_handlers().get(self.value) if special_handler: result = special_handler(self, tx, *args, **kwargs) @@ -1175,16 +1176,6 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad ) return result - def call_tensor_method(self, tx, args, kwargs): - return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs) - - def is_tensor_method(self): - return ( - inspect.ismethoddescriptor(self.get_function()) - and hasattr(self.get_function(), "__objclass__") - and self.get_function().__objclass__ == torch._C.TensorBase - ) - def torch_function_override_enabled(self, tx, args, kwargs): return ( self.get_function() in get_overridable_functions()