Lowerings: remove restriction on TensorBox keyword arguments (pytorch#136055)

Pull Request resolved: pytorch#136055
Approved by: https://github.com/eellison
benjaminglass1 authored and pytorchmergebot committed Oct 2, 2024
1 parent 63d6908 commit c7638da
Showing 1 changed file with 59 additions and 20 deletions.
79 changes: 59 additions & 20 deletions torch/_inductor/lowering.py
@@ -258,41 +258,75 @@ def in_namespace(op, namespace):
     return False
 
 
-def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool):
-    indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
-    if (type_promotion_kind or convert_input_to_bool) and indices:
+def transform_args(
+    args: List[Any],
+    kwargs: Dict[str, Any],
+    broadcast: bool,
+    type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
+    convert_input_to_bool: bool,
+) -> Tuple[List[Any], Dict[str, Any]]:
+    args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
+    kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)]
+    # check that there's something to transform
+    if not args_indices and not kwargs_indices:
+        return args, kwargs
+
+    if type_promotion_kind or convert_input_to_bool:
         if convert_input_to_bool:
             dtype = torch.bool
         else:
-            # FIXME that's a crude approximation for promoting args
+            # FIXME this is a crude approximation for promoting args
             promoting_args = [
                 a
                 for a in args
-                if isinstance(a, (Number, sympy.Basic))
-                or getattr(a, "dtype", None) is not None
+                if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype")
             ]
+            # only consider tensor kwargs for promotion, for now
+            promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
             dtype = get_promoted_dtype(
-                *promoting_args, type_promotion_kind=type_promotion_kind
+                *promoting_args, type_promotion_kind=type_promotion_kind  # type: ignore[arg-type]
             )
 
+        device = (
+            args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]]
+        ).get_device()
+
         # sometimes args are an immutable list so we can't mutate them
         def promote(arg):
             if isinstance(arg, TensorBox):
                 return to_dtype(arg, dtype)
             elif isinstance(arg, ir.Constant):
-                return ir.Constant(arg.value, dtype, args[indices[0]].get_device())
+                return ir.Constant(arg.value, dtype, device)
             else:
                 return arg
 
         args = [promote(a) for a in args]
-    if broadcast and indices:
-        for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
+        kwargs = {k: promote(v) for k, v in kwargs.items()}
+
+    if broadcast:
+        broadcasted = broadcast_tensors(
+            *list(
+                itertools.chain(
+                    (args[i] for i in args_indices),
+                    (kwargs[k] for k in kwargs_indices),
+                )
+            )
+        )
+        size = list(broadcasted[0].get_size())
+
+        for i, x in zip(args_indices, broadcasted[: len(args_indices)]):
             args[i] = x
+        for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]):
+            kwargs[k] = x
+
         for i in range(len(args)):
             if isinstance(args[i], ir.Constant):
-                args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size()))
+                args[i] = ExpandView.create(args[i], size)
+        for k in kwargs:
+            if isinstance(kwargs[k], ir.Constant):
+                kwargs[k] = ExpandView.create(kwargs[k], size)
 
-    return args
+    return args, kwargs
 
 
 def _register_foreach_lowering(aten_fn, decomp_fn):
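For intuition, here is a runnable pure-PyTorch sketch of the semantics the new transform_args gives keyword arguments. It is an approximation, not the inductor code: plain torch.Tensor stands in for TensorBox, torch.promote_types for get_promoted_dtype, and torch.broadcast_tensors for the IR-level broadcast_tensors; the name transform_args_sketch is illustrative.

    import functools
    import itertools
    from typing import Any, Dict, List, Tuple

    import torch


    def transform_args_sketch(
        args: List[Any], kwargs: Dict[str, Any]
    ) -> Tuple[List[Any], Dict[str, Any]]:
        # Locate tensor-valued positional and keyword arguments, as the new
        # args_indices/kwargs_indices do for TensorBox.
        args_indices = [i for i, x in enumerate(args) if isinstance(x, torch.Tensor)]
        kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, torch.Tensor)]
        if not args_indices and not kwargs_indices:
            return args, kwargs  # nothing to transform

        # Promote every tensor (positional and keyword) to one common dtype.
        tensors = [args[i] for i in args_indices] + [kwargs[k] for k in kwargs_indices]
        dtype = functools.reduce(torch.promote_types, (t.dtype for t in tensors))
        args = [a.to(dtype) if isinstance(a, torch.Tensor) else a for a in args]
        kwargs = {
            k: v.to(dtype) if isinstance(v, torch.Tensor) else v
            for k, v in kwargs.items()
        }

        # Broadcast args and kwargs tensors together in one call, then split the
        # flat result back by position, the same zip-splitting the diff performs.
        broadcasted = torch.broadcast_tensors(
            *itertools.chain(
                (args[i] for i in args_indices), (kwargs[k] for k in kwargs_indices)
            )
        )
        for i, x in zip(args_indices, broadcasted[: len(args_indices)]):
            args[i] = x
        for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]):
            kwargs[k] = x
        return args, kwargs


    # A float64 keyword tensor now participates in promotion and broadcasting:
    a, kw = transform_args_sketch(
        [torch.ones(3, 1)], {"other": torch.zeros(4, dtype=torch.float64)}
    )
    assert a[0].shape == (3, 4) and a[0].dtype == torch.float64
    assert kw["other"].shape == (3, 4)

Broadcasting positional and keyword tensors in a single call yields one common size, which the diff then reuses to expand ir.Constant arguments in both containers.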
@@ -321,7 +355,11 @@ def wrapped(*args, **kwargs):
 
 
 def _register_lowering(
-    aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool
+    aten_fn,
+    decomp_fn,
+    broadcast,
+    type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
+    convert_input_to_bool,
 ):
     """
     Add a lowering to lowerings dict
@@ -336,25 +374,24 @@ def _register_lowering(
 
     @functools.wraps(decomp_fn)
     def wrapped(*args, **kwargs):
-        args: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] = list(args)
+        args: List[Any] = list(args)
+        kwargs: Dict[str, Any] = dict(kwargs)
         unpacked = False
         # TODO maybe we need to use pytrees here
         if len(args) == 1 and isinstance(args[0], (list, tuple)):
             unpacked = True
-            args = args[0]
+            args = list(args[0])
 
-        # kwargs tensors not supported yet unless it's a fallback op
         if not all(
             (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
         ):
-            assert not any(isinstance(x, TensorBox) for x in kwargs.values())
             # explicitly assert for "out=" ops for better error messages
             assert not any(
                 x == "out" for x in kwargs.keys()
             ), "out= ops aren't yet supported"
 
-        args = transform_args(
-            args, broadcast, type_promotion_kind, convert_input_to_bool
+        args, kwargs = transform_args(
+            args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool
        )
 
        if unpacked:
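The net effect of this hunk can be mocked in a few lines: inputs are copied into mutable containers, out= is rejected up front, and both args and kwargs flow through the transform before reaching the decomposition. A hedged sketch, reusing transform_args_sketch and torch from above (wrapped_sketch is illustrative; the real code also exempts fallback and _c10d_functional ops from the out= assertion):

    from typing import Any, Dict, List


    def wrapped_sketch(decomp_fn, *args, **kwargs):
        args: List[Any] = list(args)  # copy: the transform mutates in place
        kwargs: Dict[str, Any] = dict(kwargs)
        assert "out" not in kwargs, "out= ops aren't yet supported"
        args, kwargs = transform_args_sketch(args, kwargs)
        return decomp_fn(*args, **kwargs)


    # A tensor passed by keyword now reaches the decomposition promoted and
    # broadcast, instead of tripping the old TensorBox-in-kwargs assertion:
    out = wrapped_sketch(torch.add, torch.ones(3, 1), other=torch.zeros(4))
    assert out.shape == (3, 4)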
@@ -374,7 +411,9 @@ def wrapped(*args, **kwargs):
 def register_lowering(
     aten_fn,
     broadcast=False,
-    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    type_promotion_kind: Optional[
+        ELEMENTWISE_TYPE_PROMOTION_KIND
+    ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     convert_input_to_bool=False,
 ):
     """
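The final hunk only widens the annotation: None was already meaningful at runtime (with type_promotion_kind=None and convert_input_to_bool=False, transform_args skips the promotion branch entirely), and it now type-checks. A minimal, hedged illustration; aten.clone is an arbitrary existing op chosen for the example:

    import torch
    from torch._inductor.lowering import register_lowering

    # register_lowering is a decorator factory: calling it only builds the
    # decorator, so nothing is registered until it is applied to a
    # decomposition function.
    decorate = register_lowering(
        torch.ops.aten.clone.default,
        type_promotion_kind=None,  # Optional[...] now makes this explicit
    )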
