Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

int8 optimizations #1973

Merged
merged 17 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/rewrite_quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant = qt.max();
min_quant = qt.min();
});
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));

auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
auto s = add_zero_point->get_shape();
auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
Expand Down
11 changes: 6 additions & 5 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2);
return (dots >= 2 or convs >= 2 or qdots >= 2);
}

struct find_conv_dot_horiz_fusion
Expand All @@ -1110,15 +1111,15 @@ struct find_conv_dot_horiz_fusion
auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator())
return false;
if(not contains({"dot", "convolution"}, i->name()))
if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true;
auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens();
if(x.size() != y.size())
return false;
// Check that non-axes match
int axis = 1;
if(i->name() == "dot")
if(i->name() == "dot" or i->name() == "quant_dot")
{
axis = x.size() - 1;
}
Expand All @@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if(std::distance(start, last) < 2)
return;
auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name))
if(not contains({"quant_dot", "dot", "convolution"}, name))
return;
auto op = (*start)->get_operator();
int group = 1;
Expand All @@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1;
int concat_axis = 0;
if(name == "dot")
if(name == "dot" or name == "quant_dot")
{
axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis;
Expand Down
68 changes: 38 additions & 30 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4712,14 +4712,16 @@ TEST_CASE(quantizelinear_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), 0);
std::vector<int> max_data(s.elements(), 255);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}});
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
Expand All @@ -4741,14 +4743,16 @@ TEST_CASE(quantizelinear_int32_test)
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), 0);
std::vector<int> max_data(s.elements(), 255);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}});
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
Expand All @@ -4775,13 +4779,15 @@ TEST_CASE(quantizelinear_zero_point_test)
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_mbcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), -128);
std::vector<int> max_data(s.elements(), 127);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}});
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
Expand Down Expand Up @@ -4812,13 +4818,15 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_bcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), -128);
std::vector<int> max_data(s.elements(), 127);
auto min_arg = mm->add_literal(s, min_data);
auto max_arg = mm->add_literal(s, max_data);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}});
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
Expand Down
13 changes: 13 additions & 0 deletions test/rewrite_quantization_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@

bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
bool is_clip_scalar(migraphx::instruction& ins)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be const.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should all of them be changed to const?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yea they should probably all be const.

{
if(ins.name() == "clip")
{
assert(ins.inputs().size() > 1);
return (std::all_of(ins.inputs().begin() + 1, ins.inputs().end(), [](auto input) {
return input->get_shape().scalar();
}));
}
return false;
}

void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::rewrite_quantization{}}); }

Expand Down Expand Up @@ -70,6 +81,8 @@ TEST_CASE(quantizelinear)
EXPECT(eval(p1) == eval(p2));
EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
// ensure clip literals created in quantized program are scalar
EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar));
}

TEST_CASE(dequantizelinear)
Expand Down
14 changes: 9 additions & 5 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2189,16 +2189,16 @@ TEST_CASE(simplify_split_between_add)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(simplify_dot_horiz)
void test_dot_horiz(migraphx::shape::type_t type, const std::string& dot_type)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
auto s = migraphx::shape{type, {3, 2, 2}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto a = m1.add_literal(migraphx::generate_literal(s, 0));
auto b = m1.add_literal(migraphx::generate_literal(s, 1));
auto x = m1.add_instruction(migraphx::make_op("dot"), input, a);
auto y = m1.add_instruction(migraphx::make_op("dot"), input, b);
auto x = m1.add_instruction(migraphx::make_op(dot_type), input, a);
auto y = m1.add_instruction(migraphx::make_op(dot_type), input, b);
auto sum = m1.add_instruction(migraphx::make_op("add"), x, y);
m1.add_instruction(pass_op{}, sum);
}
Expand All @@ -2210,7 +2210,7 @@ TEST_CASE(simplify_dot_horiz)
auto a = m2.add_literal(migraphx::generate_literal(s, 0));
auto b = m2.add_literal(migraphx::generate_literal(s, 1));
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat);
auto dot = m2.add_instruction(migraphx::make_op(dot_type), input, concat);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
auto y = m2.add_instruction(
Expand All @@ -2221,6 +2221,10 @@ TEST_CASE(simplify_dot_horiz)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(simplify_dot_horiz) { test_dot_horiz(migraphx::shape::int32_type, "dot"); }

TEST_CASE(simplify_quant_dot_horiz) { test_dot_horiz(migraphx::shape::int8_type, "quant_dot"); }

TEST_CASE(simplify_dot_horiz_same_constant)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
Expand Down
Loading