From b57b1e4e0bdfa50ad9c72548450e21c20b7b8863 Mon Sep 17 00:00:00 2001
From: Lakhinder Walia <139581206+lakhinderwalia@users.noreply.github.com>
Date: Thu, 26 Sep 2024 13:51:56 -0700
Subject: [PATCH] int4: disable const_folding for unpack_int4 (#3322)

---
 src/CMakeLists.txt                            |   1 +
 src/driver/main.cpp                           |   6 +
 src/include/migraphx/quantization.hpp         |   2 +
 src/include/migraphx/quantize_int4.hpp        |  50 ++++++
 src/include/migraphx/shape.hpp                |   2 +
 src/propagate_constant.cpp                    |   6 +-
 src/quantization.cpp                          |   6 +
 src/quantize_int4.cpp                         | 108 +++++++++++
 src/shape.cpp                                 |   8 +
 src/simplify_qdq.cpp                          |  42 ++++-
 src/targets/gpu/mlir.cpp                      |  30 +++-
 test/gpu/mlir.cpp                             |  53 ++++++
 test/int4_test.cpp                            | 115 ++++++++++++
 test/onnx/gen_onnx.py                         | 167 ++++++++++++++++++
 ...t4_const_identity_block_sz_1_qdq_test.onnx | Bin 0 -> 387 bytes
 ...t4_const_identity_block_sz_2_qdq_test.onnx | Bin 0 -> 329 bytes
 test/onnx/int4_const_identity_qdq_test.onnx   | Bin 0 -> 365 bytes
 17 files changed, 588 insertions(+), 8 deletions(-)
 create mode 100644 src/include/migraphx/quantize_int4.hpp
 create mode 100644 src/quantize_int4.cpp
 create mode 100644 test/int4_test.cpp
 create mode 100755 test/onnx/int4_const_identity_block_sz_1_qdq_test.onnx
 create mode 100755 test/onnx/int4_const_identity_block_sz_2_qdq_test.onnx
 create mode 100755 test/onnx/int4_const_identity_qdq_test.onnx

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 02e3cd0e6eb..1ffe2eb9f4e 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -90,6 +90,7 @@ add_library(migraphx
     promote_literals.cpp
     quantization.cpp
     quantize_fp16.cpp
+    quantize_int4.cpp
    quantize_8bits.cpp
     reduce_dims.cpp
     register_op.cpp

diff --git a/src/driver/main.cpp b/src/driver/main.cpp
index 5402ef9b7dd..6c8b1359824 100644
--- a/src/driver/main.cpp
+++ b/src/driver/main.cpp
@@ -479,6 +479,7 @@ struct compiler
     bool to_fp16 = false;
     bool to_fp8  = false;
     bool to_int8 = false;
+    bool to_int4 = false;
     std::vector<std::string> fill0;
     std::vector<std::string> fill1;
@@ -502,6 +503,7 @@ struct compiler
         ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
         ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
         ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
+        ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
     }

     auto params(const program& p)
@@ -556,6 +558,10 @@ struct compiler
         {
             quantize_fp8(p, t, {host_params(p)});
         }
+        if(to_int4)
+        {
+            quantize_int4_weights(p);
+        }
         p.compile(t, co);
         l.save(p);
         return p;

diff --git a/src/include/migraphx/quantization.hpp b/src/include/migraphx/quantization.hpp
index c01662a1580..f9727ef06db 100644
--- a/src/include/migraphx/quantization.hpp
+++ b/src/include/migraphx/quantization.hpp
@@ -49,6 +49,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog,
 MIGRAPHX_EXPORT void
 quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration);

+MIGRAPHX_EXPORT void quantize_int4_weights(program& prog);
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
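For orientation, the new driver flag simply forwards to the new quantize_int4_weights API. A minimal C++ sketch of the same flow (the model file name and the "ref" target choice are placeholders, not part of this patch):

    #include <migraphx/onnx.hpp>
    #include <migraphx/quantization.hpp>
    #include <migraphx/register_target.hpp>

    int main()
    {
        // Load any ONNX model whose dot/convolution weights are constants.
        migraphx::program p = migraphx::parse_onnx("model.onnx");
        // Same entry point the --int4-weights flag calls in the driver.
        migraphx::quantize_int4_weights(p);
        // Compile as usual; the int4 QDQ chain is simplified during compile.
        p.compile(migraphx::make_target("ref"));
        return 0;
    }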
diff --git a/src/include/migraphx/quantize_int4.hpp b/src/include/migraphx/quantize_int4.hpp
new file mode 100644
index 00000000000..586ad3f72ed
--- /dev/null
+++ b/src/include/migraphx/quantize_int4.hpp
@@ -0,0 +1,50 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT4_HPP
+#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT4_HPP
+
+#include <string>
+#include <vector>
+#include <migraphx/config.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+struct program;
+struct module;
+
+/**
+ * quantize a program to int4
+ */
+struct MIGRAPHX_EXPORT quantize_int4_pass
+{
+    std::vector<std::string> ins_names;
+    std::string name() const { return "quantize_int4"; }
+    void apply(module& m) const;
+};
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+
+#endif

diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp
index 290656f003d..4a5952d5eac 100644
--- a/src/include/migraphx/shape.hpp
+++ b/src/include/migraphx/shape.hpp
@@ -151,6 +151,8 @@ struct MIGRAPHX_EXPORT shape
     static bool is_integral(type_t t);
     static bool is_compatible(const shape& actual, const shape& expected);

+    static bool is_unsigned(type_t t);
+
     shape();
     shape(type_t t);
     shape(type_t t, std::vector<std::size_t> l);

diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp
index 3a6856d9450..7b9d548f16b 100644
--- a/src/propagate_constant.cpp
+++ b/src/propagate_constant.cpp
@@ -38,8 +38,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)

 bool skip_propagate(instruction_ref ins)
 {
-    if(ins->name() == "contiguous")
+    if(ins->name() == "contiguous" or ins->name() == "dequantizelinear")
         return skip_propagate(ins->inputs().front());
+    if(ins->name() == "unpack_int4")
+        return true;
     auto&& s = ins->get_shape();
     if(s.broadcasted() and not s.scalar())
         return true;
@@ -48,7 +50,7 @@ bool skip_propagate(instruction_ref ins)
     return false;
 }

-bool is_const_ins(instruction_ref ins, std::unordered_set<std::string> skip_ops)
+bool is_const_ins(instruction_ref ins, const std::unordered_set<std::string>& skip_ops)
 {
     return ins->can_eval() and not skip_propagate(ins) and
            skip_ops.find(ins->name()) == skip_ops.end();
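The propagate_constant change above is the core of this patch: constants still fold through a dequantizelinear chain, but an unpack_int4 is never folded, so int4 weights stay packed (two values per byte) in the compiled program. A small sketch of that behavior, using the same pass APIs as the tests later in this patch (the shapes are made up):

    #include <migraphx/dead_code_elimination.hpp>
    #include <migraphx/generate.hpp>
    #include <migraphx/make_op.hpp>
    #include <migraphx/module.hpp>
    #include <migraphx/pass_manager.hpp>
    #include <migraphx/propagate_constant.hpp>

    int main()
    {
        migraphx::module m;
        // int8 storage holding int4 values, packed two per byte
        auto w    = m.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {4, 4}}));
        auto pk   = m.add_instruction(migraphx::make_op("pack_int4"), w);
        auto unpk = m.add_instruction(migraphx::make_op("unpack_int4"), pk);
        m.add_return({unpk});
        migraphx::run_passes(m, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
        // The literal and pack_int4 fold into one packed constant;
        // skip_propagate() leaves the unpack_int4 in the graph.
        return 0;
    }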
diff --git a/src/quantization.cpp b/src/quantization.cpp
index 81974515a9c..bbe29e258e4 100644
--- a/src/quantization.cpp
+++ b/src/quantization.cpp
@@ -26,6 +26,7 @@
 #include <migraphx/quantization.hpp>
 #include <migraphx/quantize_fp16.hpp>
 #include <migraphx/quantize_8bits.hpp>
+#include <migraphx/quantize_int4.hpp>
 #include <migraphx/normalize_ops.hpp>
 #include <migraphx/optimize_module.hpp>
 #include <migraphx/pass_manager.hpp>
@@ -165,6 +166,11 @@ void quantize_int8(program& prog,
     quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
 }

+void quantize_int4_weights(program& prog)
+{
+    run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}});
+}
+
 void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
 {
     std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "

diff --git a/src/quantize_int4.cpp b/src/quantize_int4.cpp
new file mode 100644
index 00000000000..12b59d088e6
--- /dev/null
+++ b/src/quantize_int4.cpp
@@ -0,0 +1,108 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include <migraphx/quantize_int4.hpp>
+#include <migraphx/float_equal.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/iterator_for.hpp>
+#include <migraphx/literal.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/module.hpp>
+#include <migraphx/ranges.hpp>
+#include <algorithm>
+#include <cmath>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+static void int4_quantize_module(module& m)
+{
+    std::vector<std::string> int4_instrs{"dot", "convolution"};
+
+    for(auto ins : iterator_for(m))
+    {
+        if(not(contains(int4_instrs, ins->name())))
+            continue;
+
+        if(ins->inputs().empty())
+            continue;
+
+        auto s = ins->get_shape();
+
+        auto mod_inputs = ins->module_inputs();
+
+        // Convert each of the inputs that are fp32 or fp16 to int4
+        auto inputs = ins->inputs();
+        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto inp) {
+            auto sh = inp->get_shape();
+            if(sh.broadcasted())
+                return inp;
+            auto input_type = sh.type();
+            if(input_type != shape::float_type and input_type != shape::half_type)
+                return inp;
+            auto lens = sh.lens();
+            if(lens[lens.size() - 1] % 2)
+                return inp; // even sized dimensions to pack
+
+            if(not inp->can_eval())
+                return inp;
+
+            std::vector<float> val;
+            inp->eval().visit([&](auto in_data) { val.assign(in_data.begin(), in_data.end()); });
+
+            auto [min, max] = std::minmax_element(val.begin(), val.end());
+            *min            = *min > 0 ? 0 : *min;
+            *max            = *max < 0 ? 0 : *max;
+            float fscale4   = (*max - *min) / 15; // uint4 range is [0, 15]
+            int zp4         = float_equal(fscale4, 0) ? 0 : std::round(-*min / fscale4);
+
+            auto scale = m.add_literal(literal({s.type()}, {fscale4}));
+            scale =
+                m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
+            auto zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
+            zp = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), zp);
+            auto q_in = m.insert_instruction(ins, make_op("quantizelinear"), inp, scale, zp);
+
+            auto pk   = m.insert_instruction(ins, make_op("pack_int4", {{"axis", -1}}), q_in);
+            auto unpk = m.insert_instruction(ins, make_op("unpack_int4", {{"axis", -1}}), pk);
+
+            auto dq_scale = m.add_literal(literal({s.type()}, {fscale4}));
+            dq_scale      = m.insert_instruction(
+                ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
+
+            auto dq_zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
+            dq_zp =
+                m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_zp);
+
+            return m.insert_instruction(ins, make_op("dequantizelinear"), unpk, dq_scale, dq_zp);
+        });
+
+        auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        m.replace_instruction(ins, converted_ins);
+    }
+}
+
+void quantize_int4_pass::apply(module& m) const { int4_quantize_module(m); }
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
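The scale/zero-point arithmetic above is plain asymmetric quantization over the uint4 range. A worked standalone sketch of the same formula (the sample values are illustrative only):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> w{-3.0f, -4.0f, -5.0f, 2.0f};
        auto [min, max] = std::minmax_element(w.begin(), w.end());
        float lo        = std::min(*min, 0.0f); // range must include 0
        float hi        = std::max(*max, 0.0f);
        float scale     = (hi - lo) / 15;                           // (2 - (-5)) / 15 ~= 0.4667
        int zp          = scale == 0 ? 0 : std::round(-lo / scale); // round(5 / 0.4667) = 11
        // q = round(w / scale) + zp, clamped to [0, 15]; e.g. -5 -> 0 and 2 -> 15
        std::printf("scale=%f zp=%d\n", scale, zp);
        return 0;
    }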
diff --git a/src/shape.cpp b/src/shape.cpp
index 657a131be70..c9899403548 100644
--- a/src/shape.cpp
+++ b/src/shape.cpp
@@ -260,6 +260,7 @@ std::string shape::cpp_type(shape::type_t t)
     }
     MIGRAPHX_THROW("Invalid type");
 }
+
 bool shape::is_integral(shape::type_t t)
 {
     bool result = false;
@@ -291,6 +292,13 @@ bool shape::is_compatible(const shape& actual, const shape& expected)
         });
 }

+bool shape::is_unsigned(shape::type_t t)
+{
+    bool result = false;
+    visit(t, [&](auto as) { result = as.is_unsigned(); });
+    return result;
+}
+
 shape::shape() : impl(shape_impl::default_shape()) {}

 shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}

diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp
index 378d87a73b6..3fb781a0f8d 100644
--- a/src/simplify_qdq.cpp
+++ b/src/simplify_qdq.cpp
@@ -152,7 +152,7 @@ struct match_find_quantizable_ops
             return;

         // Propagate q1 and q2 through any broadcasts and transposes before qop
-        auto qop_args = qop->inputs();
+        auto qop_args      = qop->inputs();
         bool is_fp16_model = false;
         if(dq1->get_shape().type() != qop->get_shape().type() and
            qop->get_shape().type() == migraphx::shape::half_type)
@@ -379,6 +379,24 @@ bool is_same_scale_zero(instruction_ref a, instruction_ref b)
     return is_same_value(a->inputs().at(2), b->inputs().at(2));
 }

+// When an unpack_int4 instruction is present, its original input must be int4/uint4.
+// Therefore check for an unpack_int4 operator -- while skipping shape-related ops.
+bool is_any_input_int4(instruction_ref a)
+{
+    static std::set<std::string> ign = {"unsqueeze",
+                                        "broadcast",
+                                        "multibroadcast",
+                                        "contiguous",
+                                        "transpose",
+                                        "reshape",
+                                        "convert"};
+    return std::any_of(a->inputs().begin(), a->inputs().end(), [](auto i) {
+        while(ign.find(i->name()) != ign.end())
+            i = i->inputs()[0];
+        return i->name() == "unpack_int4";
+    });
+}
+
 void remove_qdq_pairs(module& m)
 {
     for(auto ins : iterator_for(m))
         }
     }
 }
+
+void add_int4_pack_unpack_pair(module& m)
+{
+    for(auto ins : iterator_for(m))
+    {
+        if(ins->name() != "dequantizelinear")
+            continue;
+
+        for(auto&& inp : ins->inputs())
+        {
+            if((inp->name() == "quantizelinear") and is_any_input_int4(inp))
+            {
+                auto pk   = m.insert_instruction(ins, make_op("pack_int4"), inp);
+                auto unpk = m.insert_instruction(ins, make_op("unpack_int4"), pk);
+                instruction::replace_argument(ins, inp, unpk);
+            }
+        }
+    }
+}
+
 } // namespace

 void simplify_qdq::apply(module& m) const
 {
+    // first step: add pack/unpack pair between qdq for int4 weights
+    add_int4_pack_unpack_pair(m);
     match::find_matches(m, match_find_quantizable_ops{});
     migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
     remove_qdq_pairs(m);
diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp
index b1e1378bc58..2442e2e8aed 100644
--- a/src/targets/gpu/mlir.cpp
+++ b/src/targets/gpu/mlir.cpp
@@ -23,6 +23,7 @@
  */
 #include <migraphx/gpu/mlir.hpp>
 #include <migraphx/gpu/context.hpp>
+#include <migraphx/shape.hpp>
 #include <migraphx/module.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/make_op.hpp>
@@ -325,7 +326,7 @@ struct mlir_program
         {
             MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
         }
-        result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
+        result = mlirIntegerTypeGet(ctx.get(), as.size() * 8); // number of bits
     }
     else
         MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
@@ -623,9 +624,9 @@ struct mlir_program
         if(ins->name() == "@return")
             return "func.return";
         if(ins->name() == "@literal")
-        {
             return "migraphx.literal";
-        }
+        if(ins->name() == "unpack_int4")
+            return "migraphx.unpack";
         return "migraphx." + ins->name();
     }
@@ -648,6 +649,10 @@ struct mlir_program
                 std::copy(padding.begin(), padding.end(), std::back_inserter(v.at("padding")));
             }
         }
+
+        if(op.name() == "unpack_int4")
+            v["axis"] = ins->get_shape().ndim() - 1;
+
         return v;
     }
@@ -697,15 +702,30 @@ struct mlir_program
         ops.add_attribute_value(get_operator_value(ins));
         if(ins->name() != "@return")
             ops.add_results({get_shape(ins)});
+
         if(ins->name() == "@literal")
         {
-            literal r = ins->get_literal();
-            MlirType shaped_type = make_mlir_shaped(ins->get_shape());
+            literal r = ins->get_literal();
+            auto sh   = ins->get_shape();
+
+            // MLIR works only with signed types; change uint4 to (int4 + unsigned flag)
+            if(shape::is_unsigned(sh.type()) and ins->outputs()[0]->name() == "unpack_int4")
+                sh = ins->get_shape().with_type(shape::int8_type);
+
+            MlirType shaped_type = make_mlir_shaped(sh);
             MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
             MlirAttribute mlir_value_attr =
                 mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
             ops.add_attributes({{"value", mlir_value_attr}});
         }
+
+        if(ins->name() == "unpack_int4")
+        {
+            auto sh = get_shape(ins);
+            ops.add_attributes(
+                {{"isUnsigned", shape::is_unsigned(sh.type())}}); // flag for uint4
+        }
+
         if(ins->name() == "convolution" or ins->name() == "dot")
         {
             pp =

diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp
index cf140c16aa4..586b102aafc 100644
--- a/test/gpu/mlir.cpp
+++ b/test/gpu/mlir.cpp
@@ -454,4 +454,57 @@
     EXPECT(verify_mlir(m));
 }

+TEST_CASE(int4_unpack_ir)
+{
+    std::string mlir_output = R"__migraphx__(
+module {
+  func.func @mlir_unpack_int4(%arg0: !migraphx.shaped<2x1xi8, 1x1>) -> !migraphx.shaped<2x2xi8, 2x1> attributes ${attrs} {
+    %0 = migraphx.unpack %arg0 {axis = 1 : i64, isUnsigned = false} : <2x1xi8, 1x1> -> <2x2xi8, 2x1>
+    return %0 : !migraphx.shaped<2x2xi8, 2x1>
+  }
+}
+)__migraphx__";
+    migraphx::module m;
+    auto arg0 = m.add_parameter("arg0", {migraphx::shape::int8_type, {2, 1}});
+    auto unpk = m.add_instruction(migraphx::make_op("unpack_int4"), arg0);
+    m.add_return({unpk});
+    auto s = migraphx::gpu::dump_mlir(m);
+
+    // Skip test if MLIR is not enabled
+    if(s.empty())
+        return;
+
+    auto mlir_output_with_attrs =
+        migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
+
+    CHECK(encode(s) == encode(mlir_output_with_attrs));
+}
+
+TEST_CASE(int4_unpack_conv)
+{
+    std::string mlir_output = R"__migraphx__(
+module {
+  func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> attributes ${attrs} {
+    %0 = migraphx.unpack %arg0 {axis = 3 : i64, isUnsigned = false} : <2x8x2x1xi8, 16x2x1x1> -> <2x8x2x2xi8, 32x4x2x1>
+    %1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x2x2xi8, 32x4x2x1> -> <1x2x3x3xi32, 18x9x3x1>
+    return %1 : !migraphx.shaped<1x2x3x3xi32, 18x9x3x1>
+  }
+}
+)__migraphx__";
+    migraphx::module m;
+    auto x    = m.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 4, 4}});
+    auto pk_w = m.add_parameter("w", {migraphx::shape::int8_type, {2, 8, 2, 1}});
+    auto w    = m.add_instruction(migraphx::make_op("unpack_int4"), pk_w);
+    auto conv = m.add_instruction(migraphx::make_op("quant_convolution"), x, w); // w: {2,8,2,2}
+    m.add_return({conv});
+    auto s = migraphx::gpu::dump_mlir(m);
+    // Skip test if MLIR is not enabled
+    if(s.empty())
+        return;
+    auto mlir_output_with_attrs =
+        migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
+    CHECK(encode(s) == encode(mlir_output_with_attrs));
+    EXPECT(verify_mlir(m));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
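As the conv test above shows ({2,8,2,1} unpacking to {2,8,2,2}), unpack_int4 doubles the packed axis: two int4 nibbles per stored int8 byte. A tiny sketch of that shape rule (the helper name is made up for illustration):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // unpack_int4 doubles the packed (last) axis: two int4 nibbles per int8 byte.
    std::vector<std::size_t> unpacked_lens(std::vector<std::size_t> lens, std::size_t axis)
    {
        lens[axis] *= 2;
        return lens;
    }

    int main()
    {
        auto l = unpacked_lens({2, 8, 2, 1}, 3); // -> {2, 8, 2, 2}, as in the conv test above
        std::printf("%zu %zu %zu %zu\n", l[0], l[1], l[2], l[3]);
        return 0;
    }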
diff --git a/test/int4_test.cpp b/test/int4_test.cpp
new file mode 100644
index 00000000000..71b6908dbda
--- /dev/null
+++ b/test/int4_test.cpp
@@ -0,0 +1,115 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/matcher.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/propagate_constant.hpp>
+#include <migraphx/quantize_int4.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/simplify_qdq.hpp>
+#include "test.hpp"
+
+namespace match = migraphx::match;
+
+TEST_CASE(int4_pass_test)
+{
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
+        auto w = m1.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 6, 6}}));
+        auto conv = m1.add_instruction(
+            migraphx::make_op("convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            w);
+        m1.add_instruction(migraphx::make_op("relu"), conv);
+    }
+    migraphx::run_passes(m1, {migraphx::quantize_int4_pass{}});
+
+    auto chk_1 = match::name("quantizelinear")(
+                     match::output(match::name("pack_int4")(match::output(match::name(
+                         "unpack_int4")(match::output(match::name("dequantizelinear")))))))
+                     .bind("q");
+
+    auto res = find_match(m1, chk_1);
+
+    EXPECT(migraphx::contains(res.instructions, "q"));
+}
+
+TEST_CASE(int4_const_prop_test)
+{
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
+        auto w = m1.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 6, 6}}));
+        auto conv = m1.add_instruction(
+            migraphx::make_op("convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            w);
+        m1.add_instruction(migraphx::make_op("relu"), conv);
+    }
+    migraphx::run_passes(m1,
+                         {migraphx::quantize_int4_pass{},
+                          migraphx::propagate_constant{},
+                          migraphx::dead_code_elimination{}});
+
+    auto chk_1 = match::name("pack_int4").bind("pack_int4");
+    auto res_1 = find_match(m1, chk_1);
+    EXPECT(not migraphx::contains(res_1.instructions, "pack_int4"));
+
+    auto chk_2 = match::name("unpack_int4").bind("unpack_int4");
+    auto res_2 = find_match(m1, chk_2);
+    EXPECT(migraphx::contains(res_2.instructions, "unpack_int4"));
+}
+
+TEST_CASE(int4_simplify_qdq_pass_test)
+{
+    migraphx::module m1;
+    {
+        auto x        = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
+        auto scale    = m1.add_parameter("sc", {migraphx::shape::float_type, {1, 8, 16, 16}});
+        auto zp       = m1.add_parameter("zp", {migraphx::shape::int8_type, {1, 8, 16, 8}});
+        auto un_pk_zp = m1.add_instruction(migraphx::make_op("unpack_int4"), zp);
+        auto q  = m1.add_instruction(migraphx::make_op("quantizelinear"), x, scale, un_pk_zp);
+        auto dq = m1.add_instruction(migraphx::make_op("dequantizelinear"), q, scale, un_pk_zp);
+        m1.add_return({dq});
+    }
+
+    migraphx::run_passes(m1, {migraphx::simplify_qdq{}});
+
+    auto chk_1 = match::name("quantizelinear")(
                     match::output(match::name("pack_int4")(match::output(match::name(
+                         "unpack_int4")(match::output(match::name("dequantizelinear")))))))
+                     .bind("q");
+    auto res_1 = find_match(m1, chk_1);
+    EXPECT(migraphx::contains(res_1.instructions, "q"));
+}
+
+int main(int argc, const char* argv[]) { test::run(argc, argv); }

diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index a1c9d0dd81f..7b05f961b8c 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -5656,6 +5656,173 @@ def instance_norm_val_3d_test():
     return ([node], [], [y], [x_tensor, scale_tensor, bias_tensor])


+@onnx_test()
+def int4_const_identity_qdq_test():
+    # Graph for int4, with an identity op + QDQ
+    zp_values = np.array([0, 0, 0, 0])
+    x_t = helper.make_tensor(name='i_x',
+                             data_type=TensorProto.INT4,
+                             dims=zp_values.shape,
+                             vals=zp_values.flatten().astype(np.int32))
+
+    i_node = onnx.helper.make_node(
+        'Identity',
+        inputs=['i_x'],
+        outputs=['i_y_zp'],
+    )
+
+    data_values = np.array([[-3, -4, -5, 2], [2, 2, 4, 4], [2, -2, 4, 6],
+                            [2, 6, 6, 8]])
+    data_t = helper.make_tensor(name='data',
+                                data_type=TensorProto.FLOAT16,
+                                dims=data_values.shape,
+                                vals=data_values.flatten().astype(np.float16))
+
+    sc_values = np.array([1.0, 0.5, 1.0, 0.25])
+    sc_t = helper.make_tensor(name='sc_q',
+                              data_type=TensorProto.FLOAT16,
+                              dims=sc_values.shape,
+                              vals=sc_values.flatten().astype(np.float16))
+
+    q_node = onnx.helper.make_node(
+        'QuantizeLinear',
+        inputs=['data', 'sc_q', 'i_y_zp'],
+        outputs=['q_y'],
+    )
+
+    # dequantizer uses same scale values as the quantizer:
+    sc_2_t = helper.make_tensor(name='sc_dq',
+                                data_type=TensorProto.FLOAT16,
+                                dims=sc_values.shape,
+                                vals=sc_values.flatten().astype(np.float16))
+
+    dq_node = onnx.helper.make_node(
+        'DequantizeLinear',
+        inputs=['q_y', 'sc_dq', 'i_y_zp'],
+        outputs=['dq_y'],
+    )
+
+    t_node = helper.make_node(
+        'Transpose',
+        inputs=['dq_y'],
+        outputs=['y'],
+        perm=[1, 0],
+    )
+
+    y_t = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [4, 4])
+
+    return ([i_node, q_node, dq_node,
+             t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t])
+
+
+@onnx_test()
+def int4_const_identity_block_sz_1_qdq_test():
+    # Graph for int4, with an identity op + QDQ. Quantization block size = 1
+    zp_values = np.array([[0, 0, 0, 0], [0, 0, 0, 0]])
+    x_t = helper.make_tensor(name='i_x',
+                             data_type=TensorProto.INT4,
+                             dims=zp_values.shape,
+                             vals=zp_values.flatten().astype(np.int32))
+
+    i_node = onnx.helper.make_node(
+        'Identity',
+        inputs=['i_x'],
+        outputs=['i_y_zp'],
+    )
+
+    data_values = np.array([[-3, -4, -5, -6], [2, 3, 4, 5]])
+    data_t = helper.make_tensor(name='data',
+                                data_type=TensorProto.FLOAT16,
+                                dims=data_values.shape,
+                                vals=data_values.flatten().astype(np.float16))
+
+    sc_values = np.array([[0.5, 0.25, 0.5, 0.125], [0.25, 0.5, 0.5, 0.25]])
+    sc_t = helper.make_tensor(name='sc_q',
+                              data_type=TensorProto.FLOAT16,
+                              dims=sc_values.shape,
+                              vals=sc_values.flatten().astype(np.float16))
+
+    q_node = onnx.helper.make_node(
+        'QuantizeLinear',
+        inputs=['data', 'sc_q', 'i_y_zp'],
+        outputs=['q_y'],
+    )
+
+    # dequantizer uses same scale values as the quantizer:
+    sc_2_t = helper.make_tensor(name='sc_dq',
+                                data_type=TensorProto.FLOAT16,
+                                dims=sc_values.shape,
+                                vals=sc_values.flatten().astype(np.float16))
+
+    dq_node = onnx.helper.make_node(
+        'DequantizeLinear',
+        inputs=['q_y', 'sc_dq', 'i_y_zp'],
+        outputs=['dq_y'],
+    )
+
+    t_node = helper.make_node(
+        'Transpose',
+        inputs=['dq_y'],
+        outputs=['y'],
+        perm=[1, 0],
+    )
+
+    y_t = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [4, 2])
+
+    return ([i_node, q_node, dq_node,
+             t_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t])
+
+
+@onnx_test()
+def int4_const_identity_block_sz_2_qdq_test():
+    # Graph for int4, with an identity op + QDQ. Quantization block size = 2
+    zp_values = np.array([[0, 0], [0, 0]])
+    x_t = helper.make_tensor(name='i_x',
+                             data_type=TensorProto.INT4,
+                             dims=zp_values.shape,
+                             vals=zp_values.flatten().astype(np.int32))
+
+    i_node = onnx.helper.make_node(
+        'Identity',
+        inputs=['i_x'],
+        outputs=['i_y_zp'],
+    )
+
+    data_values = np.array([[-3, -4, -6, -8], [2, 3, 4, 6]])
+    data_t = helper.make_tensor(name='data',
+                                data_type=TensorProto.FLOAT16,
+                                dims=data_values.shape,
+                                vals=data_values.flatten().astype(np.float16))
+
+    sc_values = np.array([[0.5, 0.125], [0.5, 0.25]])
+    sc_t = helper.make_tensor(name='sc_q',
+                              data_type=TensorProto.FLOAT16,
+                              dims=sc_values.shape,
+                              vals=sc_values.flatten().astype(np.float16))
+
+    q_node = onnx.helper.make_node(
+        'QuantizeLinear',
+        inputs=['data', 'sc_q', 'i_y_zp'],
+        outputs=['q_y'],
+    )
+
+    # dequantizer uses same scale values as the quantizer:
+    sc_2_t = helper.make_tensor(name='sc_dq',
+                                data_type=TensorProto.FLOAT16,
+                                dims=sc_values.shape,
+                                vals=sc_values.flatten().astype(np.float16))
+
+    dq_node = onnx.helper.make_node(
+        'DequantizeLinear',
+        inputs=['q_y', 'sc_dq', 'i_y_zp'],
+        outputs=['y'],
+    )
+
+    y_t = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 4])
+
+    return ([i_node, q_node, dq_node], [], [y_t], [x_t, data_t, sc_t, sc_2_t])
+
+
 @onnx_test()
 def isinf_half_test():
     t1 = helper.make_tensor_value_info('t1', TensorProto.FLOAT16, [2, 3])

diff --git a/test/onnx/int4_const_identity_block_sz_1_qdq_test.onnx b/test/onnx/int4_const_identity_block_sz_1_qdq_test.onnx
new file mode 100755
index 0000000000000000000000000000000000000000..d561cd69928d160c8659eca9081507eb048057e3
GIT binary patch
literal 387
[binary data]

literal 0
HcmV?d00001
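As a quick smoke check, the ONNX fixtures generated by gen_onnx.py above can be loaded through the usual parser; a sketch (assuming the generated file is on disk at the path below):

    #include <migraphx/onnx.hpp>
    #include <migraphx/program.hpp>
    #include <iostream>

    int main()
    {
        // Parse one of the int4 QDQ graphs added by this patch.
        auto p = migraphx::parse_onnx("test/onnx/int4_const_identity_qdq_test.onnx");
        std::cout << p << std::endl; // prints the resulting IR
        return 0;
    }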