Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

int4: disable const_folding for unpack_int4 #3322

Merged
merged 9 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ add_library(migraphx
promote_literals.cpp
quantization.cpp
quantize_fp16.cpp
quantize_int4.cpp
quantize_8bits.cpp
reduce_dims.cpp
register_op.cpp
Expand Down
6 changes: 6 additions & 0 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ struct compiler
bool to_fp16 = false;
bool to_fp8 = false;
bool to_int8 = false;
bool to_int4 = false;

std::vector<std::string> fill0;
std::vector<std::string> fill1;
Expand All @@ -502,6 +503,7 @@ struct compiler
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
}

auto params(const program& p)
Expand Down Expand Up @@ -556,6 +558,10 @@ struct compiler
{
quantize_fp8(p, t, {host_params(p)});
}
if(to_int4)
{
quantize_int4_weights(p);
}
p.compile(t, co);
l.save(p);
return p;
Expand Down
2 changes: 2 additions & 0 deletions src/include/migraphx/quantization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog,
MIGRAPHX_EXPORT void
quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration);

MIGRAPHX_EXPORT void quantize_int4_weights(program& prog);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

Expand Down
50 changes: 50 additions & 0 deletions src/include/migraphx/quantize_int4.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT4_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT4_HPP

#include <string>
#include <vector>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;
struct module;

/**
 * Pass that quantizes eligible constant inputs (weights) of a module to int4.
 * Rewrites fp32/fp16 constant inputs of selected instructions into a
 * quantizelinear -> pack_int4 -> unpack_int4 -> dequantizelinear chain.
 */
struct MIGRAPHX_EXPORT quantize_int4_pass
{
    // Names of instructions whose inputs are considered for quantization.
    // NOTE(review): currently unused by apply(); the op list is hard-coded
    // in the pass implementation — confirm before relying on it.
    std::vector<std::string> ins_names;
    std::string name() const { return "quantize_int4"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
2 changes: 2 additions & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ struct MIGRAPHX_EXPORT shape
static bool is_integral(type_t t);
static bool is_compatible(const shape& actual, const shape& expected);

static bool is_unsigned(type_t t);

shape();
shape(type_t t);
shape(type_t t, std::vector<std::size_t> l);
Expand Down
6 changes: 4 additions & 2 deletions src/propagate_constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)

bool skip_propagate(instruction_ref ins)
{
if(ins->name() == "contiguous")
if(ins->name() == "contiguous" or ins->name() == "dequantizelinear")
return skip_propagate(ins->inputs().front());
if(ins->name() == "unpack_int4")
return true;
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar())
return true;
Expand All @@ -48,7 +50,7 @@ bool skip_propagate(instruction_ref ins)
return false;
}

bool is_const_ins(instruction_ref ins, std::unordered_set<std::string> skip_ops)
bool is_const_ins(instruction_ref ins, const std::unordered_set<std::string>& skip_ops)
{
return ins->can_eval() and not skip_propagate(ins) and
skip_ops.find(ins->name()) == skip_ops.end();
Expand Down
6 changes: 6 additions & 0 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/quantize_int4.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
Expand Down Expand Up @@ -165,6 +166,11 @@
quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
}

// Quantize the constant weights of a program to int4.
// Runs operator normalization and module-level optimization first so that
// constant inputs are in a foldable form, then applies the int4 pass.
void quantize_int4_weights(program& prog)
{
    run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}});
}

Check warning on line 172 in src/quantization.cpp

View check run for this annotation

Codecov / codecov/patch

src/quantization.cpp#L171-L172

Added lines #L171 - L172 were not covered by tests

void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
{
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
Expand Down
108 changes: 108 additions & 0 deletions src/quantize_int4.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_int4.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Rewrite the fp32/fp16 constant inputs of dot/convolution instructions into a
// quantizelinear -> pack_int4 -> unpack_int4 -> dequantizelinear chain so the
// weights can be stored as packed int4.
static void int4_quantize_module(module& m)
{
    // Only these ops have weight inputs worth quantizing to int4.
    // static const: built once instead of on every call.
    static const std::vector<std::string> int4_instrs{"dot", "convolution"};

    for(auto ins : iterator_for(m))
    {
        if(not contains(int4_instrs, ins->name()))
            continue;

        if(ins->inputs().empty())
            continue;

        auto s = ins->get_shape();

        auto mod_inputs = ins->module_inputs();

        // Convert each of the inputs that are fp32 or fp16 to int4.
        auto inputs = ins->inputs();
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto inp) {
            auto sh = inp->get_shape();
            if(sh.broadcasted())
                return inp;
            auto input_type = sh.type();
            if(input_type != shape::float_type and input_type != shape::half_type)
                return inp;
            auto lens = sh.lens();
            if(lens[lens.size() - 1] % 2)
                return inp; // pack_int4 needs an even-sized innermost dimension

            // Only constant (foldable) inputs, i.e. weights, are quantized.
            if(not inp->can_eval())
                return inp;

            std::vector<float> val;
            inp->eval().visit([&](auto in_data) { val.assign(in_data.begin(), in_data.end()); });
            if(val.empty())
                return inp; // guard: dereferencing minmax_element of an empty range is UB

            // Asymmetric uint4 quantization: widen [min, max] to include 0 so
            // the zero point is representable, then derive scale/zero-point.
            auto [min, max] = std::minmax_element(val.begin(), val.end());
            *min            = *min > 0 ? 0 : *min;
            *max            = *max < 0 ? 0 : *max;
            float fscale4   = (*max - *min) / 15; // INT4 range is [0-15]
            int zp4         = float_equal(fscale4, 0) ? 0 : std::round(-*min / fscale4);

            auto scale = m.add_literal(literal({s.type()}, {fscale4}));
            scale =
                m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
            auto zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
            zp = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), zp);
            auto q_in = m.insert_instruction(ins, make_op("quantizelinear"), inp, scale, zp);

            auto pk   = m.insert_instruction(ins, make_op("pack_int4", {{"axis", -1}}), q_in);
            auto unpk = m.insert_instruction(ins, make_op("unpack_int4", {{"axis", -1}}), pk);

            // Fresh broadcast literals (same values) feed the dequantize side.
            auto dq_scale = m.add_literal(literal({s.type()}, {fscale4}));
            dq_scale      = m.insert_instruction(
                ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);

            auto dq_zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
            dq_zp =
                m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_zp);

            return m.insert_instruction(ins, make_op("dequantizelinear"), unpk, dq_scale, dq_zp);
        });

        // Re-create the instruction with the (possibly) rewritten inputs.
        auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
        m.replace_instruction(ins, converted_ins);
    }
}

void quantize_int4_pass::apply(module& m) const { int4_quantize_module(m); }

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
8 changes: 8 additions & 0 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@
}
MIGRAPHX_THROW("Invalid type");
}

bool shape::is_integral(shape::type_t t)
{
bool result = false;
Expand Down Expand Up @@ -291,6 +292,13 @@
});
}

// Returns true when the element type `t` is an unsigned type.
bool shape::is_unsigned(shape::type_t t)
{
    bool result = false;
    // Dispatch on the runtime type enum; `as` exposes the per-type traits.
    visit(t, [&](auto as) { result = as.is_unsigned(); });
    return result;
}

shape::shape() : impl(shape_impl::default_shape()) {}

shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
Expand Down
42 changes: 41 additions & 1 deletion src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ struct match_find_quantizable_ops
return;

// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
auto qop_args = qop->inputs();
bool is_fp16_model = false;
if(dq1->get_shape().type() != qop->get_shape().type() and
qop->get_shape().type() == migraphx::shape::half_type)
Expand Down Expand Up @@ -379,6 +379,24 @@ bool is_same_scale_zero(instruction_ref a, instruction_ref b)
return is_same_value(a->inputs().at(2), b->inputs().at(2));
}

// When an unpack instruction is inserted, its original input must be an
// int4/uint4. Walk back through shape-only/layout ops on each input of `a`
// and report whether any chain originates at an unpack_int4 operator.
bool is_any_input_int4(instruction_ref a)
{
    static const std::set<std::string> shape_ops = {"unsqueeze",
                                                    "broadcast",
                                                    "multibroadcast",
                                                    "contiguous",
                                                    "transpose",
                                                    "reshape",
                                                    "convert"};
    const auto& args = a->inputs();
    return std::any_of(args.begin(), args.end(), [](auto in) {
        // Skip past ops that only change shape/layout, not values.
        while(shape_ops.count(in->name()) > 0)
            in = in->inputs()[0];
        return in->name() == "unpack_int4";
    });
}

void remove_qdq_pairs(module& m)
{
for(auto ins : iterator_for(m))
Expand All @@ -397,10 +415,32 @@ void remove_qdq_pairs(module& m)
}
}
}

// For every dequantizelinear fed by a quantizelinear whose chain originates at
// an unpack_int4, insert a pack_int4/unpack_int4 pair between the q/dq ops so
// the int4 packing survives later qdq simplification.
void add_int4_pack_unpack_pair(module& m)
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "dequantizelinear")
            continue;

        for(auto&& q_ins : ins->inputs())
        {
            if(q_ins->name() != "quantizelinear" or not is_any_input_int4(q_ins))
                continue;
            auto packed   = m.insert_instruction(ins, make_op("pack_int4"), q_ins);
            auto unpacked = m.insert_instruction(ins, make_op("unpack_int4"), packed);
            instruction::replace_argument(ins, q_ins, unpacked);
        }
    }
}

} // namespace

void simplify_qdq::apply(module& m) const
{
// first step: add pack/unpack pair between qdq for int4 weights
add_int4_pack_unpack_pair(m);
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_qdq_pairs(m);
Expand Down
Loading
Loading