Add manual meta implementations to quantize_per_tensor.tensor and co (pytorch#89958)

When writing a meta function, you cannot call item() on a tensor: a meta tensor carries no real data, so the call fails. The error message in this case was not very good; see also pytorch#89959
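
To make the failure mode concrete, a hypothetical illustration (not code from the PR): calling item() on a meta tensor errors out, because there is no value to read back.

```python
import torch

# Meta tensors carry only shape/dtype/device metadata, no values.
t = torch.empty((), device="meta")
print(t.shape, t.dtype)  # torch.Size([]) torch.float32

# This fails (PyTorch raises an error about the Meta backend not supporting
# _local_scalar_dense), because item() needs concrete data to produce a Python scalar.
t.item()
```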

This PR takes a brute-force approach to resolving the problem: just manually define meta implementations for the naughty functions that call item(). However, this results in a lot of code duplication. The easiest way to avoid the situation is to rewrite the decomps so they don't call item() at all; it should not be difficult to operate on the tensors directly, since one-element (scalar) tensors broadcast.
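
A minimal sketch of that alternative (not the code in this PR), assuming the usual affine quantization formula and relying on the one-element scale/zero_point tensors broadcasting against the input:

```python
import torch

def quantize_per_tensor_tensor_no_item(input, scale, zero_point,
                                        quant_min, quant_max, dtype):
    # scale/zero_point stay tensors (numel() == 1) and broadcast against `input`,
    # so this decomposition also works under meta/fake tensors without item().
    q = torch.round(input / scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(dtype)
```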

I could only test this internally with `buck test @mode/opt -c python.package_style=inplace //executorch/backends/test:test_backends` together with D41555454. Test coverage needs to be improved; otherwise, don't blame us when we break you.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: pytorch#89958
Approved by: https://github.com/jerryzh168
ezyang authored and pytorchmergebot committed Dec 1, 2022
1 parent f1978b1 commit a747326
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions torch/ao/quantization/fx/_decomposed.py
@@ -83,6 +83,14 @@ def quantize_per_tensor_tensor(
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)

# Note: quant_min/quant_max/dtype are not used in the operator, but for now they are kept in
# the signature as metadata for the input Tensor; this might be useful for pattern
# matching in the future
@@ -156,6 +164,16 @@ def dequantize_per_tensor_tensor(
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
    if dtype in [torch.uint8, torch.int8, torch.int32]:
        return torch.empty_like(input, dtype=torch.float32)
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")


quantized_decomposed_lib.define(
"choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
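
For context, a hypothetical usage sketch (not part of the commit) of what the new Meta kernels enable, assuming the ops register under the quantized_decomposed namespace once this module is imported: shape/dtype propagation on meta tensors now works without calling item().

```python
import torch
import torch.ao.quantization.fx._decomposed  # registers the decomposed ops and their Meta kernels

x = torch.empty(4, device="meta")                              # float32 meta tensor, no data
scale = torch.empty(1, device="meta")                          # one-element scale tensor
zero_point = torch.empty(1, dtype=torch.int64, device="meta")  # one-element zero_point tensor

q = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
    x, scale, zero_point, 0, 255, torch.uint8)
print(q.shape, q.dtype, q.device)  # torch.Size([4]) torch.uint8 meta
```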
