From a747326423ed4731996769e3b8eb73eecbdee2d4 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 30 Nov 2022 22:06:08 -0500 Subject: [PATCH] Add manual meta implementations to quantize_per_tensor.tensor and co (#89958) When you are writing a meta function, you cannot call item() on the tensor because there is no real data on the tensor and it will fail. The error message was not very good in this case, see also https://github.com/pytorch/pytorch/issues/89959 This PR takes a brute force approach to resolving the problem: just manually define meta implementations for the naughty functions that are calling item(). However, this results in a lot of code duplication. The easiest way to avoid this situation is to rewrite the decomps so they don't call item. It should not be that difficult to use direct tensors on your operations, as scalar tensors can broadcast too. I could only test this with `buck test @mode/opt -c python.package_style=inplace //executorch/backends/test:test_backends` in internal with D41555454. Test coverage needs to be improved, otherwise don't blame us when we break you. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89958 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/fx/_decomposed.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index 0e020a15a826d..a6f5ad7a3d0b9 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -83,6 +83,14 @@ def quantize_per_tensor_tensor( assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta") +def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): + assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in # the signature as metadata for the input Tensor, this might be useful for pattern # matching in the future @@ -156,6 +164,16 @@ def dequantize_per_tensor_tensor( assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta") +def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): + assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" + assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" + if dtype in [torch.uint8, torch.int8, torch.int32]: + return torch.empty_like(input, dtype=torch.float32) + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + quantized_decomposed_lib.define( "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "