-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support block granularity for QuantizeLinear and DequantizeLinear #3412
Changes from 7 commits
3a85f7c
0807594
cd15870
c3d6f67
0d9e209
8df2097
73178d7
39b628d
cf36e5c
7d6b9b7
f0c12a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* The MIT License (MIT) | ||
* | ||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in | ||
* all copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
* THE SOFTWARE. | ||
*/ | ||
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_QUANTIZE_DEQUANTIZE_LINEAR_HPP | ||
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_QUANTIZE_DEQUANTIZE_LINEAR_HPP | ||
|
||
#include <migraphx/onnx/op_parser.hpp> | ||
#include <migraphx/instruction.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
namespace onnx { | ||
|
||
void transform_quantize_dequantize_linear_inputs(const onnx_parser::node_info& info, | ||
const std::string& op_name, | ||
int block_size, | ||
int axis, | ||
std::vector<instruction_ref>& args); | ||
|
||
} // namespace onnx | ||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
/* | ||
* The MIT License (MIT) | ||
* | ||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. | ||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
|
@@ -27,6 +27,7 @@ | |
#include <migraphx/make_op.hpp> | ||
#include <migraphx/tune_axis.hpp> | ||
#include <migraphx/common.hpp> | ||
#include <migraphx/onnx/quantize_dequantize_linear.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
@@ -37,47 +38,74 @@ | |
std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; } | ||
|
||
instruction_ref parse(const op_desc& opd, | ||
const onnx_parser& /*parser*/, | ||
const onnx_parser& parser, | ||
const onnx_parser::node_info& info, | ||
const std::vector<instruction_ref>& args) const | ||
std::vector<instruction_ref>& args) const | ||
{ | ||
if(args.size() < 2 or args.size() > 3) | ||
{ | ||
MIGRAPHX_THROW("QuantizeLinear: must have either 2 or 3 inputs, " + | ||
std::to_string(args.size()) + " input(s) provided"); | ||
} | ||
|
||
// Starting with version 19 ONNX introduced the constraint that x and y_scale types must be | ||
// the same | ||
if(parser.opset_version >= 19 and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As a matter of general approach, if common_type (below) can be safely derived even for version prior to 19, is it okay to not flag errors for type mismatch -- i.e. by looking at Opset version? This is just for my understanding -- I am not suggesting a code change here. Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have to flag it because the onnx spec states that it's a constraint for versions 19 and up. |
||
args[0]->get_shape().type() != args[1]->get_shape().type()) | ||
{ | ||
MIGRAPHX_THROW("QuantizeLinear: x and y_scale must be of same type"); | ||
} | ||
|
||
if(args.size() == 3 and args[1]->get_shape().lens() != args[2]->get_shape().lens()) | ||
{ | ||
MIGRAPHX_THROW( | ||
"QuantizeLinear: y_scale and y_zero_point shapes must be equal. Provided y_scale " | ||
"shape: " + | ||
to_string_range(args[1]->get_shape().lens()) + | ||
", provided y_zero_point shape: " + to_string_range(args[2]->get_shape().lens())); | ||
} | ||
|
||
int axis = 1; | ||
if(contains(info.attributes, "axis")) | ||
axis = info.attributes.at("axis").i(); | ||
|
||
auto input_lens = args[0]->get_shape().lens(); | ||
auto n_dim = input_lens.size(); | ||
int block_size = 0; | ||
if(contains(info.attributes, "block_size")) | ||
block_size = info.attributes.at("block_size").i(); | ||
|
||
instruction_ref y_scale = args[1]; | ||
if(args[1]->get_shape().elements() != 1) | ||
std::optional<migraphx::shape::type_t> output_type; | ||
if(contains(info.attributes, "output_dtype")) | ||
{ | ||
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); | ||
y_scale = info.add_instruction( | ||
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]); | ||
output_type = get_type(info.attributes.at("output_dtype").i()); | ||
} | ||
|
||
auto common_args = add_common_args(*info.mod, {args[0], y_scale}); | ||
|
||
if(args.size() == 3) | ||
if(output_type.has_value() and args.size() == 3 and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Style: Please do the exception processing in one clause, on line 59 above. |
||
*output_type != args[2]->get_shape().type()) | ||
{ | ||
auto y_zero_point = args[2]; | ||
if(y_zero_point->get_shape().elements() != 1) | ||
{ | ||
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); | ||
y_zero_point = info.add_instruction( | ||
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), | ||
y_zero_point); | ||
} | ||
else | ||
{ | ||
y_zero_point = info.add_instruction( | ||
make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point); | ||
} | ||
MIGRAPHX_THROW( | ||
"QuantizeLinear: output_type and y_zero_point type must match. output_type: " + | ||
to_string(*output_type) + | ||
+", y_zero_point type: " + to_string(args[2]->get_shape().type())); | ||
} | ||
|
||
transform_quantize_dequantize_linear_inputs(info, opd.op_name, block_size, axis, args); | ||
|
||
common_args.push_back(y_zero_point); | ||
if(parser.opset_version < 19) | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are only two types supported for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought about adding these as well, but decided against mainly because I haven't noticed that it's common practice to have the type constraints checked in parser code, although I might be wrong here. |
||
auto common_type = common_shape({args[0]->get_shape(), args[1]->get_shape()}).type(); | ||
std::transform(args.begin(), args.begin() + 2, args.begin(), [&](auto ins) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just trying to understand here: Why is it args.begin() + 2. And not args.end(). Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prior to version 19, the first two inputs(x and y_scales) can have different float types, so a conversion to common type is needed to make the mgx operator work. The optional third input will have a type of int8 or uint8, and we want to leave it that way. |
||
if(ins->get_shape().type() != common_type) | ||
ins = info.add_instruction(make_op("convert", {{"target_type", common_type}}), | ||
ins); | ||
return ins; | ||
}); | ||
} | ||
|
||
return info.add_instruction(make_op("quantizelinear"), common_args); | ||
if(output_type.has_value()) | ||
return info.add_instruction(make_op("quantizelinear", {{"out_type", *output_type}}), | ||
args); | ||
else | ||
return info.add_instruction(make_op("quantizelinear"), args); | ||
} | ||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/* | ||
* The MIT License (MIT) | ||
* | ||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in | ||
* all copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
* THE SOFTWARE. | ||
*/ | ||
|
||
#include <migraphx/onnx/quantize_dequantize_linear.hpp> | ||
#include <migraphx/ranges.hpp> | ||
#include <migraphx/make_op.hpp> | ||
#include <migraphx/tune_axis.hpp> | ||
#include <migraphx/common.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
namespace onnx { | ||
|
||
void transform_quantize_dequantize_linear_inputs(const onnx_parser::node_info& info, | ||
const std::string& op_name, | ||
int block_size, | ||
int axis, | ||
std::vector<instruction_ref>& args) | ||
CharlieL7 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
const auto x = args.at(0); | ||
const auto x_lens = x->get_shape().lens(); | ||
const auto x_rank = x_lens.size(); | ||
|
||
instruction_ref y_scale = args.at(1); | ||
const auto y_scale_lens = y_scale->get_shape().lens(); | ||
const auto y_scale_rank = y_scale_lens.size(); | ||
|
||
// Per-tensor (per-layer) granularity | ||
if(y_scale->get_shape().elements() == 1) | ||
{ | ||
std::transform(args.begin() + 1, args.end(), args.begin() + 1, [&](auto ins) { | ||
return info.add_instruction(make_op("multibroadcast", {{"out_lens", x_lens}}), ins); | ||
}); | ||
} | ||
// Per-axis granularity | ||
else if(y_scale_rank == 1) | ||
{ | ||
axis = tune_axis(x_rank, axis, op_name); | ||
if(x_lens[axis] != y_scale_lens[0]) | ||
{ | ||
MIGRAPHX_THROW(op_name + ": For per axis granularity the length of y_scale (actual: " + | ||
to_string(y_scale_lens[0]) + ") must be equal to size of x on axis " + | ||
to_string(axis) + "(actual: " + to_string(x_lens[axis]) + ")"); | ||
} | ||
|
||
std::transform(args.begin() + 1, args.end(), args.begin() + 1, [&](auto ins) { | ||
return info.add_instruction( | ||
make_op("broadcast", {{"axis", axis}, {"out_lens", x_lens}}), ins); | ||
}); | ||
} | ||
// Blocked granularity | ||
else | ||
{ | ||
axis = tune_axis(x_rank, axis, op_name); | ||
if(block_size == 0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Our quark generated graph doesn't use an explicit There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Therefore, please remove this exception clause. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The ONNX spec states it is an optional attribute, with a default value of 0: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't have the quark-generated graph compiling with your current code. I can change this code later. Thanks. |
||
{ | ||
MIGRAPHX_THROW(op_name + ": Invalid blocksize(0)"); | ||
} | ||
|
||
if(x_rank != y_scale_rank) | ||
{ | ||
MIGRAPHX_THROW(op_name + ": x(rank: " + to_string(x_rank) + | ||
") and y_scale(rank: " + to_string(y_scale_rank) + | ||
") must be of same rank for block granularity"); | ||
} | ||
|
||
for(auto i = 0u; i < x_lens.size(); ++i) | ||
{ | ||
if(x_lens[i] != y_scale_lens[i] and i != axis) | ||
{ | ||
MIGRAPHX_THROW(op_name + ": x(shape: " + to_string_range(x_lens) + | ||
") and y_scale(shape: " + to_string_range(y_scale_lens) + | ||
") shapes may only differ along provided axis(" + to_string(axis) + | ||
")"); | ||
} | ||
} | ||
|
||
// Given x shape (D0, ..., Di, ..., Dn), y_scale shape (S0, ... Si, ...Sn) and | ||
// axis=i, the accepted range is [ceil(Di/Si), ceil(Di/(Si-1))-1] | ||
float di = x_lens[axis]; | ||
float si = y_scale_lens[axis]; | ||
int block_size_min = std::ceil(di / si); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sample code that can be added below if exception above is removed -- for
|
||
int block_size_max = std::ceil(di / (si - 1)) - 1; | ||
if(block_size < block_size_min or block_size > block_size_max) | ||
MIGRAPHX_THROW(op_name + ": Block size(actual: " + to_string(block_size) + | ||
") must be within range [" + to_string(block_size_min) + ", " + | ||
to_string(block_size_max) + "]"); | ||
|
||
std::transform(args.begin() + 1, args.end(), args.begin() + 1, [&](auto ins) { | ||
if(block_size == 1) | ||
return ins; | ||
|
||
ins = info.add_instruction(make_op("unsqueeze", {{"axes", {axis + 1}}}), ins); | ||
|
||
auto bc_lens = ins->get_shape().lens(); | ||
bc_lens[axis + 1] = block_size; | ||
ins = info.add_instruction(make_op("multibroadcast", {{"out_lens", bc_lens}}), ins); | ||
|
||
auto reshape_lens = x_lens; | ||
reshape_lens[axis] = ins->get_shape().lens()[axis] * block_size; | ||
ins = info.add_instruction(make_op("reshape", {{"dims", reshape_lens}}), ins); | ||
|
||
// Detect runt block | ||
if(x_lens[axis] < reshape_lens[axis]) | ||
{ | ||
ins = info.add_instruction( | ||
make_op("slice", {{"axes", {axis}}, {"starts", {0}}, {"ends", {x_lens[axis]}}}), | ||
ins); | ||
} | ||
|
||
return ins; | ||
}); | ||
} | ||
} | ||
|
||
} // namespace onnx | ||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: "DequantizeLinear: y_scale and y_zero_point shape mismatch."