Skip to content

Commit

Permalink
int4: disable const_folding for unpack_int4 (#3322)
Browse files Browse the repository at this point in the history
  • Loading branch information
lakhinderwalia committed Sep 26, 2024
1 parent 1cd2854 commit b57b1e4
Show file tree
Hide file tree
Showing 17 changed files with 588 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ add_library(migraphx
promote_literals.cpp
quantization.cpp
quantize_fp16.cpp
quantize_int4.cpp
quantize_8bits.cpp
reduce_dims.cpp
register_op.cpp
Expand Down
6 changes: 6 additions & 0 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ struct compiler
bool to_fp16 = false;
bool to_fp8 = false;
bool to_int8 = false;
bool to_int4 = false;

std::vector<std::string> fill0;
std::vector<std::string> fill1;
Expand All @@ -502,6 +503,7 @@ struct compiler
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
}

auto params(const program& p)
Expand Down Expand Up @@ -556,6 +558,10 @@ struct compiler
{
quantize_fp8(p, t, {host_params(p)});
}
if(to_int4)
{
quantize_int4_weights(p);
}
p.compile(t, co);
l.save(p);
return p;
Expand Down
2 changes: 2 additions & 0 deletions src/include/migraphx/quantization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog,
MIGRAPHX_EXPORT void
quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration);

MIGRAPHX_EXPORT void quantize_int4_weights(program& prog);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

Expand Down
50 changes: 50 additions & 0 deletions src/include/migraphx/quantize_int4.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT4_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT4_HPP

#include <string>
#include <vector>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;
struct module;

/**
* quantize a program to int4
*/
struct MIGRAPHX_EXPORT quantize_int4_pass
{
std::vector<std::string> ins_names;
std::string name() const { return "quantize_int4"; }
void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
2 changes: 2 additions & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ struct MIGRAPHX_EXPORT shape
static bool is_integral(type_t t);
static bool is_compatible(const shape& actual, const shape& expected);

static bool is_unsigned(type_t t);

shape();
shape(type_t t);
shape(type_t t, std::vector<std::size_t> l);
Expand Down
6 changes: 4 additions & 2 deletions src/propagate_constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)

bool skip_propagate(instruction_ref ins)
{
if(ins->name() == "contiguous")
if(ins->name() == "contiguous" or ins->name() == "dequantizelinear")
return skip_propagate(ins->inputs().front());
if(ins->name() == "unpack_int4")
return true;
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar())
return true;
Expand All @@ -48,7 +50,7 @@ bool skip_propagate(instruction_ref ins)
return false;
}

bool is_const_ins(instruction_ref ins, std::unordered_set<std::string> skip_ops)
bool is_const_ins(instruction_ref ins, const std::unordered_set<std::string>& skip_ops)
{
return ins->can_eval() and not skip_propagate(ins) and
skip_ops.find(ins->name()) == skip_ops.end();
Expand Down
6 changes: 6 additions & 0 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/quantize_int4.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
Expand Down Expand Up @@ -165,6 +166,11 @@ void quantize_int8(program& prog,
quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
}

void quantize_int4_weights(program& prog)
{
run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}});
}

void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
{
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
Expand Down
108 changes: 108 additions & 0 deletions src/quantize_int4.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_int4.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static void int4_quantize_module(module& m)
{
std::vector<std::string> int4_instrs{"dot", "convolution"};

for(auto ins : iterator_for(m))
{
if(not(contains(int4_instrs, ins->name())))
continue;

if(ins->inputs().empty())
continue;

auto s = ins->get_shape();

auto mod_inputs = ins->module_inputs();

// Convert each of the inputs that are fp32 or fp16 to int4
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto inp) {
auto sh = inp->get_shape();
if(sh.broadcasted())
return inp;
auto input_type = sh.type();
if(input_type != shape::float_type and input_type != shape::half_type)
return inp;
auto lens = sh.lens();
if(lens[lens.size() - 1] % 2)
return inp; // even sized dimensions to pack

if(not inp->can_eval())
return inp;

std::vector<float> val;
inp->eval().visit([&](auto in_data) { val.assign(in_data.begin(), in_data.end()); });

auto [min, max] = std::minmax_element(val.begin(), val.end());
*min = *min > 0 ? 0 : *min;
*max = *max < 0 ? 0 : *max;
float fscale4 = (*max - *min) / 15; // INT4 range is [0-15]
int zp4 = float_equal(fscale4, 0) ? 0 : std::round(-*min / fscale4);

auto scale = m.add_literal(literal({s.type()}, {fscale4}));
scale =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
auto zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
zp = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), zp);
auto q_in = m.insert_instruction(ins, make_op("quantizelinear"), inp, scale, zp);

auto pk = m.insert_instruction(ins, make_op("pack_int4", {{"axis", -1}}), q_in);
auto unpk = m.insert_instruction(ins, make_op("unpack_int4", {{"axis", -1}}), pk);

auto dq_scale = m.add_literal(literal({s.type()}, {fscale4}));
dq_scale = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);

auto dq_zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
dq_zp =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_zp);

return m.insert_instruction(ins, make_op("dequantizelinear"), unpk, dq_scale, dq_zp);
});

auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
m.replace_instruction(ins, converted_ins);
}
}

void quantize_int4_pass::apply(module& m) const { int4_quantize_module(m); }

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
8 changes: 8 additions & 0 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ std::string shape::cpp_type(shape::type_t t)
}
MIGRAPHX_THROW("Invalid type");
}

bool shape::is_integral(shape::type_t t)
{
bool result = false;
Expand Down Expand Up @@ -291,6 +292,13 @@ bool shape::is_compatible(const shape& actual, const shape& expected)
});
}

bool shape::is_unsigned(shape::type_t t)
{
bool result = false;
visit(t, [&](auto as) { result = as.is_unsigned(); });
return result;
}

shape::shape() : impl(shape_impl::default_shape()) {}

shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
Expand Down
42 changes: 41 additions & 1 deletion src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ struct match_find_quantizable_ops
return;

// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
auto qop_args = qop->inputs();
bool is_fp16_model = false;
if(dq1->get_shape().type() != qop->get_shape().type() and
qop->get_shape().type() == migraphx::shape::half_type)
Expand Down Expand Up @@ -379,6 +379,24 @@ bool is_same_scale_zero(instruction_ref a, instruction_ref b)
return is_same_value(a->inputs().at(2), b->inputs().at(2));
}

// When an unpack instruction is inserted, its original input must be an int4/uint4.
// Therefore check for an unpack_int4 operator -- while ignoring out shape related ops.
bool is_any_input_int4(instruction_ref a)
{
static std::set<std::string> ign = {"unsqueeze",
"broadcast",
"multibroadcast",
"contiguous",
"transpose",
"reshape",
"convert"};
return std::any_of(a->inputs().begin(), a->inputs().end(), [](auto i) {
while(ign.find(i->name()) != ign.end())
i = i->inputs()[0];
return i->name() == "unpack_int4";
});
}

void remove_qdq_pairs(module& m)
{
for(auto ins : iterator_for(m))
Expand All @@ -397,10 +415,32 @@ void remove_qdq_pairs(module& m)
}
}
}

void add_int4_pack_unpack_pair(module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "dequantizelinear")
continue;

for(auto&& inp : ins->inputs())
{
if((inp->name() == "quantizelinear") and is_any_input_int4(inp))
{
auto pk = m.insert_instruction(ins, make_op("pack_int4"), inp);
auto unpk = m.insert_instruction(ins, make_op("unpack_int4"), pk);
instruction::replace_argument(ins, inp, unpk);
}
}
}
}

} // namespace

void simplify_qdq::apply(module& m) const
{
// first step: add pack/unpack pair between qdq for int4 weights
add_int4_pack_unpack_pair(m);
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_qdq_pairs(m);
Expand Down
Loading

0 comments on commit b57b1e4

Please sign in to comment.