From c7638da55874b2974e29b89d17ab577fd0cb3360 Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Tue, 1 Oct 2024 19:45:28 +0000 Subject: [PATCH] Lowerings: remove restriction on TensorBox keyword arguments (#136055) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136055 Approved by: https://github.com/eellison --- torch/_inductor/lowering.py | 79 +++++++++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 20 deletions(-) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index a4676c02448ec..9e48e803885fe 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -258,41 +258,75 @@ def in_namespace(op, namespace): return False -def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): - indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] - if (type_promotion_kind or convert_input_to_bool) and indices: +def transform_args( + args: List[Any], + kwargs: Dict[str, Any], + broadcast: bool, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool: bool, +) -> Tuple[List[Any], Dict[str, Any]]: + args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)] + # check that there's something to transform + if not args_indices and not kwargs_indices: + return args, kwargs + + if type_promotion_kind or convert_input_to_bool: if convert_input_to_bool: dtype = torch.bool else: - # FIXME that's a crude approximation for promoting args + # FIXME this is a crude approximation for promoting args promoting_args = [ a for a in args - if isinstance(a, (Number, sympy.Basic)) - or getattr(a, "dtype", None) is not None + if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype") ] + # only consider tensor kwargs for promotion, for now + promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype")) dtype = get_promoted_dtype( - 
*promoting_args, type_promotion_kind=type_promotion_kind + *promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type] ) + device = ( + args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]] + ).get_device() + # sometimes args are an immutable list so we can't mutate them def promote(arg): if isinstance(arg, TensorBox): return to_dtype(arg, dtype) elif isinstance(arg, ir.Constant): - return ir.Constant(arg.value, dtype, args[indices[0]].get_device()) + return ir.Constant(arg.value, dtype, device) else: return arg args = [promote(a) for a in args] - if broadcast and indices: - for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + kwargs = {k: promote(v) for k, v in kwargs.items()} + + if broadcast: + broadcasted = broadcast_tensors( + *list( + itertools.chain( + (args[i] for i in args_indices), + (kwargs[k] for k in kwargs_indices), + ) + ) + ) + size = list(broadcasted[0].get_size()) + + for i, x in zip(args_indices, broadcasted[: len(args_indices)]): args[i] = x + for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]): + kwargs[k] = x + for i in range(len(args)): if isinstance(args[i], ir.Constant): - args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + args[i] = ExpandView.create(args[i], size) + for k in kwargs: + if isinstance(kwargs[k], ir.Constant): + kwargs[k] = ExpandView.create(kwargs[k], size) - return args + return args, kwargs def _register_foreach_lowering(aten_fn, decomp_fn): @@ -321,7 +355,11 @@ def wrapped(*args, **kwargs): def _register_lowering( - aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool + aten_fn, + decomp_fn, + broadcast, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool, ): """ Add a lowering to lowerings dict @@ -336,25 +374,24 @@ def _register_lowering( @functools.wraps(decomp_fn) def wrapped(*args, **kwargs): - args: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] = list(args) + 
args: List[Any] = list(args) + kwargs: Dict[str, Any] = dict(kwargs) unpacked = False # TODO maybe we need to use pytrees here if len(args) == 1 and isinstance(args[0], (list, tuple)): unpacked = True - args = args[0] + args = list(args[0]) - # kwargs tensors not supported yet unless it's a fallback op if not all( (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn ): - assert not any(isinstance(x, TensorBox) for x in kwargs.values()) # explicitly assert for "out=" ops for better error messages assert not any( x == "out" for x in kwargs.keys() ), "out= ops aren't yet supported" - args = transform_args( - args, broadcast, type_promotion_kind, convert_input_to_bool + args, kwargs = transform_args( + args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool ) if unpacked: @@ -374,7 +411,9 @@ def wrapped(*args, **kwargs): def register_lowering( aten_fn, broadcast=False, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + type_promotion_kind: Optional[ + ELEMENTWISE_TYPE_PROMOTION_KIND + ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, convert_input_to_bool=False, ): """