Skip to content

Commit

Permalink
Revert "[Dynamo] Handle extracted unbound tensor methods (pytorch#137227
Browse files Browse the repository at this point in the history
)"

This reverts commit 14eabd6.

Reverted pytorch#137227 on behalf of https://github.com/malfet due to Need to revert to be able to revert pytorch#136910 ([comment](pytorch#137227 (comment)))
  • Loading branch information
pytorchmergebot committed Oct 8, 2024
1 parent c88c0e6 commit 76c5bdd
Show file tree
Hide file tree
Showing 15 changed files with 5 additions and 24 deletions.
10 changes: 0 additions & 10 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,16 +938,6 @@ def test_tensor_is_complex(x):
else:
return x - 1

@make_test
def test_tensor_size(x):
fn = torch.Tensor.size
return fn(x + 1)

@make_test
def test_tensor_dim(x):
fn = torch.Tensor.dim
return fn(x + 1)

@make_test
def test_tensor_is_inference(x):
if x.is_inference():
Expand Down
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def method_item(self, *args, **kwargs):
self._warn_capture_scalar_outputs()
unimplemented("Tensor.item")

def method___getitem__(self, *args, **kwargs):
def method_getitem(self, *args, **kwargs):
from ..symbolic_convert import InstructionTranslator
from .builder import wrap_fx_proxy

Expand Down
17 changes: 4 additions & 13 deletions torch/_dynamo/variables/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ def handle_set_default_device(

return ConstantVariable.create(None)

@register(torch._C.TensorBase.__getitem__)
def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "getitem", args[1:], kwargs)

return handlers

def call_function(
Expand Down Expand Up @@ -900,9 +904,6 @@ def call_function(
),
)

if self.is_tensor_method():
return self.call_tensor_method(tx, args, kwargs)

special_handler = self._get_handlers().get(self.value)
if special_handler:
result = special_handler(self, tx, *args, **kwargs)
Expand Down Expand Up @@ -1175,16 +1176,6 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad
)
return result

def call_tensor_method(self, tx, args, kwargs):
return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs)

def is_tensor_method(self):
return (
inspect.ismethoddescriptor(self.get_function())
and hasattr(self.get_function(), "__objclass__")
and self.get_function().__objclass__ == torch._C.TensorBase
)

def torch_function_override_enabled(self, tx, args, kwargs):
return (
self.get_function() in get_overridable_functions()
Expand Down

0 comments on commit 76c5bdd

Please sign in to comment.