
Support block granularity for QuantizeLinear and DequantizeLinear #3412

Merged: 11 commits merged into develop on Sep 28, 2024

Conversation

music-dino (Collaborator) commented:

Add support for block-level granularity in QuantizeLinear and DequantizeLinear.

y_scale and y_zero_point are transformed to match the shape of x by applying an unsqueeze->broadcast->reshape chain of transformations.
If the final block of x is smaller than the given block_size, the transformed y_scale and y_zero_point are sliced to remove the excess elements (see the sketch below).

Resolves migraphx-benchmark#192
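To make the transformation chain above concrete, here is a minimal standalone sketch of the shape arithmetic (the shapes, axis, and block size are invented for illustration; this is not the parser code itself):

```cpp
#include <iostream>
#include <vector>

int main()
{
    // Hypothetical example: x has shape {6, 4}, blocked along axis 0 with
    // block_size = 4, so y_scale has shape {2, 4} (ceil(6 / 4) = 2).
    std::vector<int> x_lens{6, 4};
    std::vector<int> scale_lens{2, 4};
    int axis       = 0;
    int block_size = 4;

    // unsqueeze: insert a dim of 1 after the block axis -> {2, 1, 4}
    std::vector<int> unsqueezed = scale_lens;
    unsqueezed.insert(unsqueezed.begin() + axis + 1, 1);

    // broadcast: expand the new dim to block_size -> {2, 4, 4}
    std::vector<int> broadcast = unsqueezed;
    broadcast[axis + 1] = block_size;

    // reshape: collapse the two block dims -> {8, 4}
    std::vector<int> reshaped = scale_lens;
    reshaped[axis] *= block_size;

    // slice: the final block of x is partial (6 < 8), so trim the scales
    // back to x's extent along the axis -> {6, 4}
    std::vector<int> sliced = reshaped;
    sliced[axis] = x_lens[axis];

    for(const auto& lens : {unsqueezed, broadcast, reshaped, sliced})
    {
        for(int d : lens)
            std::cout << d << ' ';
        std::cout << '\n';
    }
}
```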

@music-dino added the "Onnx Operators" label (Adding or modifying an Onnx Operator in the MIGraphX codebase) on Sep 4, 2024
codecov bot commented Sep 4, 2024

Codecov Report

Attention: Patch coverage is 96.59091% with 3 lines in your changes missing coverage. Please review.

Project coverage is 92.02%. Comparing base (e4eb481) to head (f0c12a4).
Report is 1 commit behind head on develop.

Files with missing lines | Patch % | Lines
src/onnx/parse_quantizelinear.cpp 88.88% 3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3412      +/-   ##
===========================================
- Coverage    92.02%   92.02%   -0.01%     
===========================================
  Files          508      509       +1     
  Lines        20948    21005      +57     
===========================================
+ Hits         19278    19330      +52     
- Misses        1670     1675       +5     
Flag Coverage Δ
92.02% <96.59%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown.



common_args.push_back(y_zero_point);
if(parser.opset_version < 19)
{
Contributor:

There are only two types supported for T1 before version 19. I appreciate your thoroughness in following up on those details, but it isn't clear that this operator should then support input type x as either float or int32; later versions should additionally support bfloat16 and float16. Thanks.

Collaborator (Author):

I thought about adding these as well, but decided against it, mainly because checking type constraints in parser code doesn't seem to be common practice, although I might be wrong here.


// Starting with version 19 ONNX introduced the constraint that x and y_scale types must be
// the same
if(parser.opset_version >= 19 and
Contributor:

As a matter of general approach: if common_type (below) can be safely derived even for versions prior to 19, is it okay not to flag errors for type mismatch, i.e. by looking at the opset version? This is just for my understanding; I am not suggesting a code change here. Thanks.

Collaborator (Author):

We have to flag it because the ONNX spec states it's a constraint for versions 19 and up.
The common type derivation and conversion could be done in all cases, without a version check, but that would do the extra work of computing the common type and looping over the arguments for opset versions 19+ to no avail, since we already know the types are the same in that case.

else
{
axis = tune_axis(x_rank, axis, op_name);
if(block_size == 0)
Contributor:

Our Quark-generated graph doesn't use an explicit block_size, so the assumption that it is 0 needs to be tweaked a little. This parameter is optional, so we should handle it not being supplied: rather than assuming 0, its final value should be computed as block_size_min. On the other hand, if block_size is supplied with a model and isn't 0, then we should check it against the lower and upper bounds. Thanks.

Contributor:

Therefore, please remove this exception clause.

Collaborator (Author):

The ONNX spec states it is an optional attribute with a default value of 0:
https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html#attributes

Contributor:

We don't have the Quark-generated graph compiling with your current code. I can change this code later. Thanks.

// axis=i, the accepted range is [ceil(Di/Si), ceil(Di/(Si-1))-1]
float di = x_lens[axis];
float si = y_scale_lens[axis];
int block_size_min = std::ceil(di / si);
Contributor:

Sample code that could be added below, if the exception above is removed, to handle block_size == 0:

if(block_size == 0) block_size = block_size_min;
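Expanding on that suggestion, a minimal sketch of how the default and the bounds check might fit together (di, si, and block_size come from the snippet above; the upper bound follows the commented formula, and this is one reading of the suggestion, not the merged code):

```cpp
#include <cmath>

// Accepted range for axis i: [ceil(Di/Si), ceil(Di/(Si-1)) - 1]
int block_size_min = static_cast<int>(std::ceil(di / si));
// Guard the Si == 1 case, where the upper-bound formula would divide by zero.
int block_size_max =
    si > 1 ? static_cast<int>(std::ceil(di / (si - 1))) - 1 : static_cast<int>(di);

// block_size is optional (default 0); treat an absent value as
// block_size_min rather than an error.
if(block_size == 0)
    block_size = block_size_min;
else if(block_size < block_size_min or block_size > block_size_max)
    MIGRAPHX_THROW("QuantizeLinear: invalid block size");
```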

@causten added the "high priority" label (A PR with high priority for review and merging) on Sep 11, 2024
@migraphx-bot (Collaborator) commented:

Test | Batch | Rate new (73178d) | Rate old (7c2fdf) | Diff | Compare
torchvision-resnet50 64 3,250.28 3,249.19 0.03%
torchvision-resnet50_fp16 64 6,993.72 6,993.27 0.01%
torchvision-densenet121 32 2,434.70 2,434.31 0.02%
torchvision-densenet121_fp16 32 4,064.64 4,095.02 -0.74%
torchvision-inceptionv3 32 1,635.58 1,635.79 -0.01%
torchvision-inceptionv3_fp16 32 2,739.23 2,740.83 -0.06%
cadene-inceptionv4 16 776.31 776.76 -0.06%
cadene-resnext64x4 16 808.33 808.72 -0.05%
slim-mobilenet 64 7,455.05 7,455.28 -0.00%
slim-nasnetalarge 64 208.24 208.38 -0.07%
slim-resnet50v2 64 3,433.61 3,435.08 -0.04%
bert-mrpc-onnx 8 1,150.40 1,150.34 0.01%
bert-mrpc-tf 1 312.56 314.36 -0.57%
pytorch-examples-wlang-gru 1 418.10 418.46 -0.08%
pytorch-examples-wlang-lstm 1 382.21 499.68 -23.51% 🔴
torchvision-resnet50_1 1 780.24 772.72 0.97%
cadene-dpn92_1 1 437.74 397.74 10.06% 🔆
cadene-resnext101_1 1 381.54 383.61 -0.54%
onnx-taau-downsample 1 344.56 344.76 -0.06%
dlrm-criteoterabyte 1 35.05 35.10 -0.15%
dlrm-criteoterabyte_fp16 1 58.13 58.12 0.01%
agentmodel 1 8,198.74 7,932.67 3.35% 🔆
unet_fp16 2 58.11 57.85 0.44%
resnet50v1_fp16 1 941.69 935.68 0.64%
resnet50v1_int8 1 934.18 949.99 -1.66%
bert_base_cased_fp16 64 1,153.57 1,153.06 0.04%
bert_large_uncased_fp16 32 355.73 355.77 -0.01%
bert_large_fp16 1 211.68 210.32 0.64%
distilgpt2_fp16 16 2,160.36 2,161.65 -0.06%
yolov5s 1 537.57 534.27 0.62%
tinyllama 1 43.41 43.40 0.03%
vicuna-fastchat 1 176.57 170.43 3.60% 🔆
whisper-tiny-encoder 1 418.00 418.17 -0.04%
whisper-tiny-decoder 1 433.60 426.09 1.76%

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator) commented:


✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
✅ agentmodel: PASSED: MIGraphX meets tolerance
✅ unet: PASSED: MIGraphX meets tolerance
✅ resnet50v1: PASSED: MIGraphX meets tolerance
✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
✅ bert_large: PASSED: MIGraphX meets tolerance
✅ yolov5s: PASSED: MIGraphX meets tolerance
✅ tinyllama: PASSED: MIGraphX meets tolerance
✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

{
x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
MIGRAPHX_THROW("DequantizeLinear: y_scale and y_zero_point shapes must be equal. "
Contributor @lakhinderwalia commented Sep 24, 2024:

Nit: "DequantizeLinear: y_scale and y_zero_point shape mismatch."

Contributor @lakhinderwalia left a review:

Left you some very minor comments. They are optional.
Approved.

if(parser.opset_version < 19)
{
auto common_type = common_shape({args[0]->get_shape(), args[1]->get_shape()}).type();
std::transform(args.begin(), args.begin() + 2, args.begin(), [&](auto ins) {
Contributor:

Just trying to understand here: why is it args.begin() + 2 and not args.end()? Thanks.

Collaborator (Author):

Prior to version 19, the first two inputs (x and y_scale) can have different float types, so a conversion to a common type is needed to make the MIGraphX operator work. The optional third input will have a type of int8 or uint8, and we want to leave it that way.
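For readability, here is that pre-opset-19 path as an annotated sketch (common_shape and the surrounding parser context are taken as given from the snippet above; this is a paraphrase, not the exact merged code):

```cpp
if(parser.opset_version < 19)
{
    // x (args[0]) and y_scale (args[1]) may be different float types,
    // so convert both to their common type.
    auto common_type = common_shape({args[0]->get_shape(), args[1]->get_shape()}).type();
    std::transform(args.begin(), args.begin() + 2, args.begin(), [&](auto ins) {
        if(ins->get_shape().type() == common_type)
            return ins;
        return info.add_instruction(make_op("convert", {{"target_type", common_type}}), ins);
    });
    // args[2], the optional zero point, stays int8/uint8 and is deliberately skipped.
}
```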

auto common_args = add_common_args(*info.mod, {args[0], y_scale});

if(args.size() == 3)
if(output_type.has_value() and args.size() == 3 and
Contributor:

Style: Please do the exception processing in one clause, on line 59 above.

Collaborator @CharlieL7 left a review:

Please add some onnx parse tests that show the expected MIGraphX IR output from these changes.

src/onnx/quantize_dequantize_linear.cpp (review comment outdated and resolved)
music-dino (Collaborator, Author) replied:

> Please add some onnx parse tests that show the expected MIGraphX IR output from these changes.

I've added a couple of parse tests.
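For readers unfamiliar with those tests: MIGraphX parse tests generally build the expected IR by hand and compare it against the parsed model. A schematic example in that style (the file name, shapes, and exact instruction sequence here are invented for illustration; they are not the tests added in this PR):

```cpp
TEST_CASE(quantizelinear_blocked_test)
{
    // Expected IR: x {6, 4} quantized along axis 0 with block_size = 4,
    // scales {2, 4} expanded via unsqueeze -> multibroadcast -> reshape -> slice.
    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto x     = mm->add_parameter("x", {migraphx::shape::float_type, {6, 4}});
    auto scale = mm->add_parameter("scale", {migraphx::shape::float_type, {2, 4}});

    auto s = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), scale);
    s = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 4}}}), s);
    s = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8, 4}}}), s);
    s = mm->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {6}}}), s);
    mm->add_instruction(migraphx::make_op("quantizelinear"), x, s);

    // Parse the corresponding .onnx file and check the IR matches.
    auto prog = migraphx::parse_onnx("quantizelinear_blocked_test.onnx");
    EXPECT(p == prog);
}
```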

@music-dino music-dino requested a review from a team as a code owner September 27, 2024 13:41
@causten causten merged commit 74bc6be into develop Sep 28, 2024
47 of 48 checks passed
@causten causten deleted the quantizelinear_blocked branch September 28, 2024 04:03
Labels: high priority (A PR with high priority for review and merging), Onnx Operators (Adding or modifying an Onnx Operator in the MIGraphX codebase)
Development

Successfully merging this pull request may close these issues.

Support block granularity for QuantizeLinear and DequantizeLinear
5 participants